cdq分治

发布时间 2023-07-21 18:44:35作者: Rose_Lu

cdq分治是用分治来解决偏序(多是三维偏序)问题 其实我也不知道具体解决什么

这里就先以P3870陌上花开【三维偏序】为例

要求的就是\(\sum_{i=0}^{n}\)\(a_i<a_j\)\(b_i<b_j\)\(c_i<c_j\)\(i\not=j\)的个数

这里就是典型的三维偏序

那思路肯定都是先找\(a_i<a_j\),再找\(b_i<b_j\),最后找\(c_i<c_j\),cdq也是这样想的

  • 芝士思路

先将整个序列以\(a_i\)的值从大到小排序,这样我们就可以基本解决 \(a\) 数组的问题了
\(b\) 数组呢?我们考虑将排好序的 \(a\) 数组进行二分,那么我们是可以保证每次二分之后 \(l-mid\) 里的 \(a\) 是肯定小于 \((mid+1)-r\) 的,那么这时我们就可以放心的求 \(b\) 数组了。
我们求 \(b\) 的方法就是搞两个指针,分别从\(l\)\((mid+1)\)开始往后跑,来比较他们的 \(b\) ,这时就可以找到对于每个 \(j\) 小于等于它 \(b\)\(i\) 了。
\(c\) 如何比较?我们考虑用树状数组维护一下这谁能想得出来啊www
那这事就好办了,比完 \(b\) 之后就把它以 \(c\) 为下标出现次数为值存到树状数组里不就得了吗,OK,大功告成!

  • 芝士代码
#include <iostream>
#include <algorithm> 

using namespace std;
const int Max = 2e5+10;

int n , k , maxx , len , tot;
int ans[Max] , t[4*Max];

struct zx {
	int a , b , c , cnt , ans;
}p[Max] , s[Max];

bool cmp (zx x , zx y) {
	if(x.a == y.a) {
		if(x.b == y.b) return x.c < y.c;
		else return x.b < y.b;
	}
	else return x.a<y.a;
}

bool cmp1 (zx x , zx y) {
	if(x.b == y.b) return x.c < y.c;
	return x.b < y.b;
}

int lowbit(int x) {
	return x & (-x);
}

void add(int x , int y) {
	for(int i = x; i <= k; i += lowbit(i)) t[i] += y;
}

int ask(int x) {
	int res = 0;
	for(int i = x; i >= 1; i -= lowbit(i)) res += t[i];
	return res;
}

void cdq(int l , int r) {
	if(l == r) return;
	int mid = l + r >> 1;
	cdq(l , mid);
	cdq(mid+1 , r);
	sort(s+l , s+mid+1 , cmp1);
	sort(s+mid+1 , s+r+1 , cmp1);
	int j , i = l;
	for(int j = mid+1 ; j <= r; j++) {
		while(i <= mid && s[j].b >= s[i].b) {
			add(s[i].c , s[i].cnt);
			i++;
		}
		s[j].ans += ask(s[j].c);
	}
	for(int kk = l; kk < i; kk++) add(s[kk].c , -s[kk].cnt); 
}

int main() {
	cin >> n >> k;
	maxx = k;
	for(int i = 1; i <= n; i++) cin >> p[i].a >> p[i].b >> p[i].c;
	sort(p+1 , p+n+1 , cmp);
	for(int i = 1; i <= n; i++) {
		tot++;
		if(p[i+1].a != p[i].a || p[i+1].b != p[i].b || p[i+1].c != p[i].c) {
			s[++len].a = p[i].a;
			s[len].b = p[i].b;
			s[len].c = p[i].c;
			s[len].cnt = tot;
			tot = 0;
		}
	}
	cdq(1 , len);
	for(int i = 1; i <= len; i++) ans[s[i].ans+s[i].cnt-1] += s[i].cnt;
	for(int i = 0; i < n; i++) cout << ans[i] << endl;
	return 0;
}