본문 바로가기
프로그래밍

누적합으로 부분합 구하기

by blopz 2025. 8. 6.
# test-input
import sys
sys.stdin = open("./input/2001_input.txt", "r")

T = int(input().strip())


for test_case in range(1, T + 1):
    N, M = map(int, input().split())
    flies = [list(map(int, input().split())) for _ in range(N)]
    max_flies = None
    for j in range(N-M+1):
        for i in range(N-M+1):
            sum_flies = 0
            for j_delta in range(M):
                for i_delta in range(M):
                    sum_flies += flies[j+j_delta][i+i_delta]
            if max_flies is None or max_flies < sum_flies:
                max_flies = sum_flies

    print(f"#{test_case} {max_flies}")



무식하게 다 돌았다 (브루트포스 방법이라고 한다)

 

현재 0 부터 N-M+1 만큼 돌고

안에서 또 M만큼 돌면서 합을 구하고

그게 최대인지 확인한다

 

이 경우 시간복잡도가 O((N-M+1)^2 * M^2)로 나오는데

M과 N이 커지면 무지막지하게 비효율적이 된다.

 

해당 문제를 좀 더 효율적으로 풀 수 있는 방법이 있었는데 기억이 안나와서 검색을 좀 해보니

누적합을 계산하는 방법이 있었다

 

누적합은 배열의 누적합을 배열에 저장해 놓는 걸 의미한다

 

부분합 구하기

 

해당 보라색 부분의 부분합을 구하려면

빨간색 - 노란색 - 초록색 + 파란색 을 하면 구할 수 있다

 

O(M^2) 이 O(1) 가 된것이다

 

예를 들면, 크기가 100000000 x 100000000 (10억 x 10억)인 격자에서

100000 x 100000 (10만 x 10만) 크기의 부분합 중 최대를 구한다고 쳐보자.

 

그럼 브루트포스는 가능한 모든 시작 위치에서 정사각형 영역을 직접 순회하며 합을 구해야 한다.

 

가능한 시작 위치는 (10억 - 10만 + 1)^2 ≈ 10^18개이고, 각 위치마다 10만 x 10만 = 10^10개의 원소를 더해야 한다.

따라서 총 연산량은 10^18 × 10^10 = 10^28이 된다. 이건 현실적으로 불가능한 연산량이다.

 

반면, 누적합(prefix sum)을 사용하면 전체 배열의 누적합을 먼저 한 번 계산하는 데 10^18번의 연산이 필요하다.

그 이후 각 정사각형 영역의 합은 O(1) 시간에 구할 수 있으므로, 가능한 시작 위치마다 한 번씩만 계산하면 된다.

 

따라서 전체 연산량은 10^18 + 10^18 = 2 × 10^18 정도로, 브루트포스보다 약 10^10배, 즉 100억 배 더 효율적이다.

 

# test-input
import sys
sys.stdin = open("./input/2001_input.txt", "r")

T = int(input().strip())


for test_case in range(1, T + 1):
    N, M = map(int, input().split())
    flies = [list(map(int, input().split())) for _ in range(N)]

    prefix_sum = [[ 0 for _ in range(N)] for _ in range(N)]

    for j in range(N):
        for i in range(N):
            prefix_sum[j][i] = flies[j][i]
            if j-1 >= 0: prefix_sum[j][i] += prefix_sum[j-1][i]
            if i-1 >= 0: prefix_sum[j][i] += prefix_sum[j][i-1]
            if j-1 >= 0 and i-1 >= 0 : prefix_sum[j][i] -= prefix_sum[j-1][i-1]

    #print(prefix_sum)

    max_flies = None
    for j in range(N-M+1):
        for i in range(N-M+1):
            sum_flies = prefix_sum[j+M-1][i+M-1]
            if j-1 >= 0: sum_flies -= prefix_sum[j-1][i+M-1]
            if i-1 >= 0: sum_flies -= prefix_sum[j+M-1][i-1]
            if j-1 >= 0 and i-1 >= 0 : sum_flies += prefix_sum[j-1][i-1]
            if max_flies is None or max_flies < sum_flies:
                max_flies = sum_flies

    print(f"#{test_case} {max_flies}")​



누적합으로 다시 풀었다