Graphics Programming
펜윅 트리 본문
펜윅 트리는 정수 ($ a[1], a[2], ..., a[n] $)에 대해 다음 연산을 수행할 수 있는 자료구조다.
1. query(p): ($ a[1] + a[2] + ... + a[p] $) 을 ($ O(log n) $)으로 계산한다.
2. update(i, x): ($ a[i] += x $) 를 ($ O(log n) $)으로 수행한다.
부분합 ($ a[l] + a[l+1] + ... + a[r] $)을 구하려면 ($ sum(r) - sum(l-1) $)을 계산한다. (단 ($ a[0] = 0 $)으로 정의)
다음과 같은 단순 부분합의 경우 1번 연산에는 ($ O(1) $)이 걸리지만 2번 연산에는 ($ O(n) $)의 시간이 걸린다. 1번 쿼리가 m번, 2번 쿼리가 k번 있을 때, 이 방법은 ($ O(m + nk) $)의 시간이 걸린다.
sum[1] = a[1]
for(int i=2; i<=n; i++) sum[i] = a[i] + sum[i-1]
펜윅 트리는 1, 2번 연산을 모두 효율적으로 처리하여 ($ O((m + k) log n) $)의 시간이 걸린다. 자세한 설명과 구현은 다음 글에 잘 설명되어 있다.
https://www.acmicpc.net/blog/view/21
펜윅 트리를 이용해 문제를 몇 개 풀어보자.
문제: 버블 소트
이 문제는 배열 내 반전(inversion)의 개수를 세는 문제다. ($ i < j $) 이고 ($ A[i] > A[j] $) 인 쌍 ($ (i, j) $)을 반전이라고 한다. 즉 각 원소에 대해, 왼쪽에 자기보다 큰 수가 몇 개 있는지를 세서 다 더하면 된다.
총 스왑 횟수가 반전 수와 같은 이유는 이렇다. ($ A[i] > A[j] $)라면 ($ A[i] $)는 스왑에 의해 오른쪽으로 점점 이동하다 결국 ($ A[j] $)와 교환된다. 일단 ($ A[i] $)가 ($ A[j] $)의 오른쪽에 오고 나면 이 둘은 더 이상 스왑할 필요가 없다. ($ A[j] $)의 왼쪽에 있으면서 ($ A[j] $)보다 큰 원소가 ($ x $)개라고 하면, ($ A[j] $)가 관여하는 스왑은 ($ x $)번 일어난다.
자명한 방법은 ($ \Theta(n^2) $)의 시간이 걸리는데 n이 최대 50만이라 제한시간 1초에는 택도 없다. 더 빠른 방법이 필요하다.
($ A[i] $)를 다음과 같이 정의해보자.
($ A[i] $) = 지금까지 ($ i $)가 출현한 횟수
($ a[j] = x $)라 할 때, ($ i < j $) 이고 ($ a[i] > a[j] $)인 ($ i $)의 개수는 ($ A[x+1] + A[x+2] + ... + A[MAX] $) 이다. (MAX는 가장 큰 원소의 값) 펜윅 트리의 정의에 대어보면 이 합은 ($ query(MAX) - query(x) $)으로 표현할 수 있다. 그 다음 ($ A[x] $)에 ($ 1 $)을 더한다. 따라서 다음과 같은 의사 코드를 생각해볼 수 있다.
long long answer = 0;
for(int i=1; i<=n; i++){
answer += query(MAX) - query(a[i]);
update(a[i], 1);
}
그런데 ($ A[i] $)들의 값의 범위가 너무 큰 게 문제다. ($ -10^9 \le A[i] \le 10^9 $) 이므로 배열 크기를 ($ 10^{10} $) 으로 잡으면 메모리는 대략 38.147GB가 필요하다. 문제의 메모리 제한은 8MB다.
여기서 펜윅 트리나 세그먼트 트리를 이용해 문제를 풀 때 자주 쓰이는 테크닉이 있다. 이 문제를 풀 때 실제 ($ A[i] $) 값은 전혀 상관이 없고 대소 비교만 필요하다. 그리고 원소 개수는 최대 50만이므로 ($ A[i] $) 들을 크기 순으로 1에서 50만까지 매핑하는 것이 가능하다.
for(int i=0; i<n; i++) temp[i] = ary[i];
sort(temp, temp + n);
int p = 0;
for(int i=0; i<n; i++){
if(i > 0 && temp[i] == temp[i-1]) continue;
temp[p++] = temp[i];
}
가장 작은 원소를 1에 매핑할 때, ($ a[i] $)가 몇 번째 원소인지는 다음과 같이 계산할 수 있다.
int x = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;
나는 x를 매번 계산했지만 전처리할 때 모든 ary[i]를 x로 덮어씌운 다음 문제를 풀 수도 있다.
원소 50만개인 int 배열 3개가 차지하는 메모리는 약 5.72MB이다. 이 외에는 메모리를 쓸 일이 별로 없으므로 8MB 제한에 들어갈 것이다. 다음은 완전한 코드다.
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int MAX = 500001;
int ary[MAX], temp[MAX];
int BIT[MAX], n;
int query(int i) {
int answer = 0;
while(i > 0){
answer += BIT[i];
i -= i & -i;
}
return answer;
}
void update(int i, int delta) {
while(i <= n){
BIT[i] += delta;
i += i & -i;
}
}
int main() {
// input
scanf("%d", &n);
for(int i=0; i<n; i++) scanf("%d", ary + i);
// map input values into [1:500000]
for(int i=0; i<n; i++) temp[i] = ary[i];
sort(temp, temp + n);
int p = 0;
for(int i=0; i<n; i++){
if(i > 0 && temp[i] == temp[i-1]) continue;
temp[p++] = temp[i];
}
// solve
memset(BIT, 0, sizeof(BIT));
long long answer = 0;
for(int i=0; i<n; i++){
int x = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;
answer += query(n) - query(x);
update(x, 1);
}
// output
printf("%lld\n", answer);
return 0;
}
실제 채점 결과 메모리 사용량은 6.8MB, 실행 시간은 292ms가 나왔다.
문제: Pashmak and Parmida's problem
문제 서술은 간단하지만 함수 ($ f $)의 의미를 해석하기가 조금 어렵다. 다음과 같이 정의해보자.
$$ L_i = f(1, i, a_i) \\ R_j = f(j, n, a_j) $$
그러면 다음과 같이 문제를 다시 쓸 수 있다.
($ i < j $) 이고 ($ L_i > R_j $)인 쌍 ($ (i, j) $)은 총 몇 개 있는가?
첫 문제와 비슷한 형태가 되었다. 이번에도 대소 비교만 하면 되므로 일단 ($ a_i $)들을 1 ~ n 사이의 값으로 매핑해놓는다. 그리고 ($ L_i $)와 ($ R_j $) 자체는 각각 ($ \Theta(n) $)으로 쉽게 구할 수 있다.
int main() {
// input
scanf("%d", &n);
for(int i=0; i<n; i++) scanf("%d", ary + i);
// map values of ary[1..n] into [1..n]
for(int i=0; i<n; i++) temp[i] = ary[i];
sort(temp, temp + n);
int p = 0;
for(int i=0; i<n; i++){
if(i > 0 && temp[i] == temp[i-1]) continue;
temp[p++] = temp[i];
}
for(int i=0; i<n; i++){
ary[i] = (lower_bound(temp, temp + p, ary[i]) - temp) + 1;
}
// calculate L[i] and R[i]
memset(cnt, 0, sizeof(cnt));
for(int i=1; i<=n; i++){
int x = ary[i-1];
cnt[x]++;
L[i] = cnt[x];
}
memset(cnt, 0, sizeof(cnt));
for(int i=n; i>=1; i--){
int y = ary[i-1];
cnt[y]++;
R[i] = cnt[y];
}
// solve the problem
// ...
}
시간은 3초, 메모리는 256MB나 주므로 펜윅 트리를 이용해 답만 올바르게 구한다면 시간/메모리 제한에 걸릴 걱정은 하지 않아도 될 것 같다.
($ A[i] $)를 다음과 같이 정의하자.
($ A[i] $) = 지금까지의 ($ L_k = i $) 인 ($ k $)의 개수
그러면 ($ R_j $)보다 큰 ($ L_i $)의 개수는 ($ A[R_j + 1] \; + \; ... \; + \; A[n] $) = ($ query(n) - query(R_j) $)가 된다. 개수를 구했으면 ($ A[L_i] $)에 1을 더해준다.
// the number of pairs (i, j) s.t. i < j and L[i] > R[j]
memset(BIT, 0, sizeof(BIT));
long long answer = 0;
for(int i=1; i<=n; i++){
answer += query(n) - query(R[i]);
update(L[i], 1);
}
완전한 코드는 여기서 볼 수 있다.