Graphics Programming

펜윅 트리 본문

Season 1/Problem solving

펜윅 트리

minseoklee 2016. 10. 26. 00:19

펜윅 트리는 정수 ($ a[1], a[2], ..., a[n] $)에 대해 다음 연산을 수행할 수 있는 자료구조다.

1. query(p): ($ a[1] + a[2] + ... + a[p] $) 을 ($ O(log n) $)으로 계산한다.
2. update(i, x): ($ a[i] += x $) 를 ($ O(log n) $)으로 수행한다.

부분합 ($ a[l] + a[l+1] + ... + a[r] $)을 구하려면 ($ sum(r) - sum(l-1) $)을 계산한다. (단 ($ a[0] = 0 $)으로 정의)

다음과 같은 단순 부분합의 경우 1번 연산에는 ($ O(1) $)이 걸리지만 2번 연산에는 ($ O(n) $)의 시간이 걸린다. 1번 쿼리가 m번, 2번 쿼리가 k번 있을 때, 이 방법은 ($ O(m + nk) $)의 시간이 걸린다.

sum[1] = a[1]
for(int i=2; i<=n; i++) sum[i] = a[i] + sum[i-1]

펜윅 트리는 1, 2번 연산을 모두 효율적으로 처리하여 ($ O((m + k) log n) $)의 시간이 걸린다. 자세한 설명과 구현은 다음 글에 잘 설명되어 있다.

https://www.acmicpc.net/blog/view/21

펜윅 트리를 이용해 문제를 몇 개 풀어보자.

문제: 버블 소트

이 문제는 배열 내 반전(inversion)의 개수를 세는 문제다. ($ i < j $) 이고 ($ A[i] > A[j] $) 인 쌍 ($ (i, j) $)을 반전이라고 한다. 즉 각 원소에 대해, 왼쪽에 자기보다 큰 수가 몇 개 있는지를 세서 다 더하면 된다.

총 스왑 횟수가 반전 수와 같은 이유는 이렇다. ($ A[i] > A[j] $)라면 ($ A[i] $)는 스왑에 의해 오른쪽으로 점점 이동하다 결국 ($ A[j] $)와 교환된다. 일단 ($ A[i] $)가 ($ A[j] $)의 오른쪽에 오고 나면 이 둘은 더 이상 스왑할 필요가 없다. ($ A[j] $)의 왼쪽에 있으면서 ($ A[j] $)보다 큰 원소가 ($ x $)개라고 하면, ($ A[j] $)가 관여하는 스왑은 ($ x $)번 일어난다.

자명한 방법은 ($ \Theta(n^2) $)의 시간이 걸리는데 n이 최대 50만이라 제한시간 1초에는 택도 없다. 더 빠른 방법이 필요하다.

($ A[i] $)를 다음과 같이 정의해보자.

($ A[i] $) = 지금까지 ($ i $)가 출현한 횟수

($ a[j] = x $)라 할 때, ($ i < j $) 이고 ($ a[i] > a[j] $)인 ($ i $)의 개수는 ($ A[x+1] + A[x+2] + ... + A[MAX] $) 이다. (MAX는 가장 큰 원소의 값) 펜윅 트리의 정의에 대어보면 이 합은 ($ query(MAX) - query(x) $)으로 표현할 수 있다. 그 다음 ($ A[x] $)에 ($ 1 $)을 더한다. 따라서 다음과 같은 의사 코드를 생각해볼 수 있다.

long long answer = 0;
for(int i=1; i<=n; i++){
    answer += query(MAX) - query(a[i]);
    update(a[i], 1);
}

그런데 ($ A[i] $)들의 값의 범위가 너무 큰 게 문제다. ($ -10^9 \le A[i] \le 10^9 $) 이므로 배열 크기를 ($ 10^{10} $) 으로 잡으면 메모리는 대략 38.147GB가 필요하다. 문제의 메모리 제한은 8MB다.

여기서 펜윅 트리나 세그먼트 트리를 이용해 문제를 풀 때 자주 쓰이는 테크닉이 있다. 이 문제를 풀 때 실제 ($ A[i] $) 값은 전혀 상관이 없고 대소 비교만 필요하다. 그리고 원소 개수는 최대 50만이므로 ($ A[i] $) 들을 크기 순으로 1에서 50만까지 매핑하는 것이 가능하다.

for(int i=0; i<n; i++) temp[i] = ary[i];
sort(temp, temp + n);
int p = 0;
for(int i=0; i<n; i++){
    if(i > 0 && temp[i] == temp[i-1]) continue;
    temp[p++] = temp[i];
}

가장 작은 원소를 1에 매핑할 때, ($ a[i] $)가 몇 번째 원소인지는 다음과 같이 계산할 수 있다.

int x = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;

나는 x를 매번 계산했지만 전처리할 때 모든 ary[i]를 x로 덮어씌운 다음 문제를 풀 수도 있다.

원소 50만개인 int 배열 3개가 차지하는 메모리는 약 5.72MB이다. 이 외에는 메모리를 쓸 일이 별로 없으므로 8MB 제한에 들어갈 것이다. 다음은 완전한 코드다.

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;

const int MAX = 500001;
int ary[MAX], temp[MAX];
int BIT[MAX], n;

int query(int i) {
    int answer = 0;
    while(i > 0){
        answer += BIT[i];
        i -= i & -i;
    }
    return answer;
}

void update(int i, int delta) {
    while(i <= n){
        BIT[i] += delta;
        i += i & -i;
    }
}

int main() {
    // input
    scanf("%d", &n);
    for(int i=0; i<n; i++) scanf("%d", ary + i);

    // map input values into [1:500000]
    for(int i=0; i<n; i++) temp[i] = ary[i];
    sort(temp, temp + n);
    int p = 0;
    for(int i=0; i<n; i++){
        if(i > 0 && temp[i] == temp[i-1]) continue;
        temp[p++] = temp[i];
    }

    // solve
    memset(BIT, 0, sizeof(BIT));
    long long answer = 0;
    for(int i=0; i<n; i++){
        int x = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;
        answer += query(n) - query(x);
        update(x, 1);
    }

    // output
    printf("%lld\n", answer);

    return 0;
}

실제 채점 결과 메모리 사용량은 6.8MB, 실행 시간은 292ms가 나왔다.


문제: Pashmak and Parmida's problem

문제 서술은 간단하지만 함수 ($ f $)의 의미를 해석하기가 조금 어렵다. 다음과 같이 정의해보자.

$$ L_i = f(1, i, a_i) \\ R_j = f(j, n, a_j) $$

그러면 다음과 같이 문제를 다시 쓸 수 있다.

($ i < j $) 이고 ($ L_i > R_j $)인 쌍 ($ (i, j) $)은 총 몇 개 있는가?

첫 문제와 비슷한 형태가 되었다. 이번에도 대소 비교만 하면 되므로 일단 ($ a_i $)들을 1 ~ n 사이의 값으로 매핑해놓는다. 그리고 ($ L_i $)와 ($ R_j $) 자체는 각각 ($ \Theta(n) $)으로 쉽게 구할 수 있다.

int main() {
    // input
    scanf("%d", &n);
    for(int i=0; i<n; i++) scanf("%d", ary + i);

    // map values of ary[1..n] into [1..n]
    for(int i=0; i<n; i++) temp[i] = ary[i];
    sort(temp, temp + n);
    int p = 0;
    for(int i=0; i<n; i++){
        if(i > 0 && temp[i] == temp[i-1]) continue;
        temp[p++] = temp[i];
    }
    for(int i=0; i<n; i++){
        ary[i] = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;
    }

    // calculate L[i] and R[i]
    memset(cnt, 0, sizeof(cnt));
    for(int i=1; i<=n; i++){
        int x = ary[i-1];
        cnt[x]++;
        L[i] = cnt[x];
    }
    memset(cnt, 0, sizeof(cnt));
    for(int i=n; i>=1; i--){
        int y = ary[i-1];
        cnt[y]++;
        R[i] = cnt[y];
    }

    // solve the problem
    // ...
}

시간은 3초, 메모리는 256MB나 주므로 펜윅 트리를 이용해 답만 올바르게 구한다면 시간/메모리 제한에 걸릴 걱정은 하지 않아도 될 것 같다.

($ A[i] $)를 다음과 같이 정의하자.

($ A[i] $) = 지금까지의 ($ L_k = i $) 인 ($ k $)의 개수

그러면 ($ R_j $)보다 큰 ($ L_i $)의 개수는 ($ A[R_j + 1] \; + \; ... \; + \; A[n] $) = ($ query(n) - query(R_j) $)가 된다. 개수를 구했으면 ($ A[L_i] $)에 1을 더해준다.

   // the number of pairs (i, j) s.t. i < j and L[i] > R[j]
    memset(BIT, 0, sizeof(BIT));
    long long answer = 0;
    for(int i=1; i<=n; i++){
        answer += query(n) - query(R[i]);
        update(L[i], 1);
    }

완전한 코드는 여기서 볼 수 있다.

Comments