【学习笔记】(2) 基础莫队——优美的暴力

发布时间 2023-06-02 15:02:43作者: jiangchenyangsong

莫队,是莫涛发明的一种解决区间查询等问题的离线算法,基于分块思想,复杂度一般为 \(\mathcal{O}(N \sqrt{N})\)

普通莫队

例题:P1972 [SDOI2009] HH的项链

其实这道题用莫队过不了,就仅是用来引入莫队而已

题意:长度为 \(N\) 的序列,\(M\) 次询问,每次询问一段闭区间内有多少个不同的数。

先考虑暴力做法。对于每次操作,从 \(L\)\(R\) 扫一遍,统计个数。但我们想,对于一些左右端点都相近区间,这样的做法显然浪费了很多可以利用的数据。

于是再考虑另一种暴力做法。用两个指针 \(l,r\) 标记左右区间,对于一个新的查询,只需移动这两个指针到新的位置即可。这样保留的可以利用的数据,不用重新扫描。但又很容易发现新的问题:对于两个相距很远的区间,移动指针仍然需要 \(\mathcal{O}(N)\) 。只需稍加构造,复杂度仍为 \(\mathcal{O}(NM)\)。此时,先前区间维护的信息也不能很好地传递给之后相近的区间,而是被中间的其他区间浪费掉了。

若我们将所有区间按右端点排序,则右端点指针仅会移动至多 \(n\) 次。这种单调性给我们启发。如果能再将所有相近的左端点维护在一起,那么不就解决了以上问题吗?

莫队维护相近左端点的方法是分块。设每块长度为 \(S\),共 $\dfrac{N}{S} $ 块。将所有区间按照 \(L\) 所属块排序,\(L\) 所属块相同时再按 \(R\) 递增排序。在这样的顺序下,执行上述暴力做法。然后我们再来分析时间复杂度:

  • 左端点指针:块内移动每次为 \(\mathcal{O}(S)\),移动 \(M\) 次;块间移动每次为 \(\mathcal{O}(S)\),移动 $\dfrac{N}{S} $ 次。共 \(\mathcal{O}(SM+N)\)
  • 右端点指针:对于左端点所属的每个块,每次 \(\mathcal{O}(N)\),移动 \(\dfrac{N}{S}\) 次。共 \(\mathcal{O}(\dfrac{N^2}{S})\)

总时间复杂度为 \(\mathcal{O}(SM+N+\dfrac{N^2}{S})\)。其中 \(N\) 一定小于 \(\dfrac{N^2}{S}\) ,可以忽略。利用基本不等式,\(Sm+\dfrac{N^2}{S} \geq \sqrt{N^2M}\),其中 \(N,M\) 是常数,故最小值在 \(SM\)\(\dfrac{N^2}{S}\) 相等时取得。于是得到 \(S=\sqrt{\dfrac{N^2}{M}}\) 时,复杂度取最小值,值为 \(\mathcal{O}(N\sqrt{M})\)

#include<bits/stdc++.h>
#define N 1000005
using namespace std;
int n,m,res,len;
int a[N],f[N],ans[N],pos[N];
//pos[i]表示a[i]所在的块的编号 
//f[i]表示当前种类的个数 
struct obj{
	int l,r,id;
	bool operator <(const obj &A) const{
		return pos[A.l]==pos[l]?r<A.r:pos[l]<pos[A.l]; //排序 
	}
}q[N];
void add(int x){
	if(!f[a[x]]) res++;
	f[a[x]]++;
}
void del(int x){
	f[a[x]]--;
	if(!f[a[x]]) res--;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	scanf("%d",&m);
	for(int i=1;i<=m;i++){
		int l,r;
		scanf("%d%d",&l,&r);
		q[i]=(obj){l,r,i};
	}
	len=max(1,(int)sqrt((double)n*n/m));
	for(int i=1;i<=n;i++) pos[i]=(i-1)/len+1;
	sort(q+1,q+1+m);
	for(int i=1,l=1,r=0;i<=m;i++){
		while(l>q[i].l) add(--l);
		while(l<q[i].l) del(l++);
		while(r<q[i].r) add(++r);
		while(r>q[i].r) del(r--); 
		ans[q[i].id]=res;
	}
	for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
	return 0;
}

例题:P1494 [国家集训队] 小 Z 的袜子

对于\(L,R\) 的询问。

设其中颜色为 \(x,y,z\) 的袜子的个数为 \(a,b,c \dots\)

那么答案即为 \(\dfrac{\dfrac{a*(a-1)}{2}+\dfrac{b*(b-1)}{2}+\dfrac{c*(c-1)}{2} ...}{\dfrac{(R-L+1)*(R-L)}{2}}\)

化简得: \(\dfrac{a^2+b^2+c^2+...-(a+b+c+d+\dots)}{(R-L+1)*(R-L)}\)

即: \(\dfrac{a^2+b^2+c^2+...-(R-L+1)}{(R-L+1)*(R-L)}\)

我们需要解决的一个问题,就是求一个区间内每种颜色数目的平方和。

\(code\)

#include<bits/stdc++.h>
#define N 50005
#define LL long long
using namespace std;
int n,m,res,len;
int a[N],pos[N];
LL ans1[N],ans2[N],f[N],ans;
struct obj{
	int l,r,id;
	bool operator <(const obj &A) const{
		return pos[A.l]==pos[l]?r<A.r:pos[l]<pos[A.l];
	}
}q[N];
void update(int x,int k){
	ans-=f[a[x]]*f[a[x]];
	f[a[x]]+=k;
	ans+=f[a[x]]*f[a[x]];
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<=m;i++){
		scanf("%d%d",&q[i].l,&q[i].r);
		q[i].id=i;
	}
	len=max(1,(int)sqrt((double)n*n/m));
	for(int i=1;i<=n;i++) pos[i]=(i-1)/len+1;
	sort(q+1,q+1+m);
	for(int i=1,l=1,r=0;i<=m;i++){
		int ii=q[i].id;
		while(l>q[i].l) update(--l,1);
		while(l<q[i].l) update(l++,-1);
		while(r<q[i].r) update(++r,1);
		while(r>q[i].r) update(r--,-1);
		if(q[i].l==q[i].r){
			ans1[ii]=0,ans2[ii]=1;
			continue;
		}
		ans1[ii]=ans-(r-l+1);
		ans2[ii]=(r-l+1)*1LL*(r-l);
		LL k=__gcd(ans1[ii],ans2[ii]);
		ans1[ii]/=k,ans2[ii]/=k;
	}
	for(int i=1;i<=m;i++) printf("%lld/%lld\n",ans1[i],ans2[i]);
	return 0;
}

普通莫队奇偶性优化

对于左端点在同一奇数块的区间,右端点按升序排列。

对于左端点在同一偶数块的区间,右端点按降序排列。

这样我们的 \(r\) 指针在处理完这个奇数块的问题后,将在返回的途中处理偶数块的问题,再向 \(n\) 移动处理下一个奇数块的问题,优化了 \(r\) 指针的移动次数,一般情况下,这种优化能让程序快 \(30 \%\) 左右。

bool operator <(const obj &A) const{
	return pos[A.l]==pos[l]?(pos[l]&1?r<A.r:r>A.r):pos[l]<pos[A.l];
}

P4396 [AHOI2013]作业

求区间 \([l,r]\) 值在值域 \([a,b]\) 的数的个数和数的种类。

莫队套值域分块。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 51;
inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            f = -f;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    return x * f;
}
int n, m, len, cnt;
int a[N], pos[N], L[N], R[N], ans[N], res[N];
// ans[i] 是不同数的种类
// res[i] 是数的数量
int sumr[N], sum[N], suma[N];
// sum[i] 表示 i 这个值的数量
// sumr[i] 表示 pos[i] 这个块的 res
// suma[i] 表示 pos[i] 这个块的 ans
struct Query {
    int id, l, r, a, b;
    bool operator<(const Query &A) const { return pos[l] ^ pos[A.l] ? pos[l] < pos[A.l] : r < A.r; }
} q[N];
inline void add(int x) {
    ++sum[a[x]], ++sumr[pos[a[x]]];
    if (sum[a[x]] <= 1)
        ++suma[pos[a[x]]];
}
inline void del(int x) {
    --sum[a[x]], --sumr[pos[a[x]]];
    if (sum[a[x]] <= 0)
        --suma[pos[a[x]]];
}
inline int get_ans(int l, int r) {
    int ss = 0;
    if (l > r)
        return 0;
    if (pos[l] == pos[r]) {
        for (int i = l; i <= r; ++i) ss += (sum[i] >= 1);
        return ss;
    }
    for (int i = l; i <= R[pos[l]]; ++i) ss += (sum[i] >= 1);
    for (int i = pos[l] + 1; i <= pos[r] - 1; ++i) ss += suma[i];
    for (int i = L[pos[r]]; i <= r; ++i) ss += (sum[i] >= 1);
    return ss;
}
inline int get_res(int l, int r) {
    int ss = 0;
    if (l > r)
        return 0;
    if (pos[l] == pos[r]) {
        for (int i = l; i <= r; ++i) ss += sum[i];
        return ss;
    }
    for (int i = l; i <= R[pos[l]]; ++i) ss += sum[i];
    for (int i = pos[l] + 1; i <= pos[r] - 1; ++i) ss += sumr[i];
    for (int i = L[pos[r]]; i <= r; ++i) ss += sum[i];
    return ss;
}
int main() {
    n = read(), m = read();
    len = sqrt(n);
    for (int i = 1; i <= n; ++i) a[i] = read(), pos[i] = (i - 1) / len + 1;
    for (int i = 1; i <= m; ++i) q[i] = (Query){ i, read(), read(), read(), read() };
    sort(q + 1, q + 1 + m);
    for (int i = 1; i <= n; ++i) L[i] = R[i - 1] + 1, R[i] = min(n, L[i] + len - 1);
    int l = 1, r = 0;
    for (int i = 1; i <= m; ++i) {
        while (r < q[i].r) add(++r);
        while (l > q[i].l) add(--l);
        while (l < q[i].l) del(l++);
        while (r > q[i].r) del(r--);
        res[q[i].id] = get_res(q[i].a, q[i].b);
        ans[q[i].id] = get_ans(q[i].a, q[i].b);
    }
    for (int i = 1; i <= m; ++i) printf("%d %d\n", res[i], ans[i]);
    return 0;
}

带修莫队

例题:P1903 [国家集训队] 数颜色 / 维护队列

这题可以理解为 HH的项链\(+\)单点修改 。有了修改之后,我们就不能只考虑左右边界,而还要加入另外一个量:时间 \(t\) 。从几何角度理解,这从一个二维问题变成了一个三维问题。从之前的只对 \(L\) 分块,变成要对 \(L, R\) 分块,并且加入关于 \(t\) 的排序。排序优先度为 \(L\) 分块 \(\rightarrow R\) 分块 \(\rightarrow t\) ,从小到大排序。

此时设三个指针:左指针 \(l\),右指针 \(r\),时间指针 \(t\)。指针移动一单位均为 \(\mathcal{O}(1)\)。分析复杂度:

  • \(l\):与普通莫队相似,\(\mathcal{O}(SM+N)\)
  • \(r\):块内移动\(\mathcal{O}(S)\)\(M\) 次;块间移动 \(\mathcal{O}(N)\)\(\dfrac{N}{S}\) 次。共$ \mathcal{O}(SM+\dfrac{N^2}{S})$。
  • \(t\):设共修改 \(T\) 次。\(l\) 块间移动 \(\dfrac{N}{S}\) 次,\(r\) 块间移动 \(\dfrac{N}{S}\) 次,每次 t 移动为 \(O(T)\)。共 \(\mathcal{O}(\dfrac{N ^2T}{S^2})\)
    相加,忽略 \(N\) 及常数,则为 \(\mathcal{O}(MS+\dfrac{N^2}{S}+\dfrac{N^2T}{S^2})\)。结论是,当 \(M\)\(N\) 处于同一个数量级时,最小值为 \(\mathcal{O}(\sqrt[3]{N^4T})\),当 \(S=\sqrt[3]{NT}\) 时取得。
    另外,\(t\) 指针会上下移动,故要维护好修改操作,并支持向下移动。
#include<bits/stdc++.h>
#define N 200010
#define M 1000010
using namespace std;
void read(int &n){
	int f=1,x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	n=x*f;
}
int n,m,res,toti,tott=1;
int a[N],c[N],p[N],pos[N],f[M],ans[N];
struct obj{
int id,l,r,t;
bool operator <(const obj &A) const{
    return pos[l]!=pos[A.l]?pos[l]<pos[A.l]:(pos[r]!=pos[A.r]?pos[r]<pos[A.r]:t<A.t); 
    //增加时间维度 
}
}q[N];
void add(int x){
	if(!f[x]) res++;
	f[x]++;
	}
void del(int x){
	f[x]--;
	if(!f[x]) res--;
}
int main(){
	read(n),read(m);
	for(int i=1;i<=n;i++) read(a[i]);
	while(m--){
    	char opt[5];
    	int l,r;
    	scanf("%s",opt);
    	read(l),read(r);
    	if(opt[0]=='Q') q[++toti]=(obj){toti,l,r,tott};
    	else p[++tott]=l,c[tott]=r;
	}
	int len=max(1,(int)(ceil(pow(n,2.0/3.0))));
	for(int i=1;i<=n;i++) pos[i]=(i-1)/len+1;
	sort(q+1,q+1+toti);
	for(int i=1,l=1,r=0,t=1;i<=toti;i++){
    	while(r<q[i].r) add(a[++r]);
    	while(r>q[i].r) del(a[r--]);
    	while(l>q[i].l) add(a[--l]);
    	while(l<q[i].l) del(a[l++]);
    	while(t<q[i].t){
        	t++;
        	if(l<=p[t]&&p[t]<=r) del(a[p[t]]),add(c[t]);
        	//如果这个修改的值在[l,r]区间内,则其变化将对答案造成影响
        	swap(a[p[t]],c[t]);
        	//这里有个很巧妙的操作
        	//对于一个操作,下一次需要为的颜色是本次被改变的颜色
        	//比如,我把颜色3改为了7,那么再进行这次修改的时候就是把7改为3
        	//所以直接交换两种颜色就好 
    	}
    	while(t>q[i].t){ //同上 
        	if(l<=p[t]&&p[t]<=r) del(a[p[t]]),add(c[t]);
        	swap(a[p[t]],c[t]);
        	--t;
    	}
    	ans[q[i].id]=res;
	}
	for(int i=1;i<=toti;i++) printf("%d\n",ans[i]);
	return 0;
}  

回滚莫队

有些题目在区间转移时,可能会出现增加或者删除无法实现的问题。在只有增加不可实现或者只有删除不可实现的时候,就可以使用回滚莫队在 \(\mathcal{O}(n\sqrt{m})\) 的时间内解决问题。回滚莫队的核心思想就是既然我只能实现一个操作,那么我就只使用一个操作,剩下的交给回滚解决。

回滚莫队分为只使用增加操作的回滚莫队和只使用删除操作的回滚莫队。

例题:AT1219 歴史の研究

给你一个长度为 \(n\) 的数组 \(A\)\(Q\) 个询问 \((1\leq Q,N \leq 10^5)\),每次询问一个区间 \([L,R]\) 内重要度最大的数字,要求 输出其重要度。一个数字 \(i\) 重要度的定义为 \(i\) 乘上 \(i\) 在区间内出现的次数。

在这个问题中,在增加的过程中更新答案是很好实现的,但是在删除的过程中更新答案是不好实现的。因为如果增加会影响答案,那么新答案必定是刚刚增加的数字的重要度,而如果删除过后区间重要度最大的数字改变,我们很难确定新的重要度最大的数字是哪一个。所以,普通的莫队很难解决这个问题。

具体算法:

  • 我们可以对原序列进行分块,对询问按以左端点所属块编号升序为第一关键字,右端点升序为第二关键字的方式排序。

  • 按顺序处理询问:

    • 如果询问左端点所属块 \(pos\) 和上一个询问左端点所属块的不同,那么将莫队区间的左端点初始化为 \(pos\) 的右端点加 \(1\), 将莫队区间的右端点初始化为 \(pos\) 的右端点;
    • 如果询问的左右端点所属的块相同,或者左右端点距离不超过一个块的大小,那么直接扫描区间回答询问;
    • 如果询问的左右端点所属的块不同:
      • 如果询问的右端点大于莫队区间的右端点,那么不断扩展右端点直至莫队区间的右端点等于询问的右端点;
      • 不断扩展莫队区间的左端点直至莫队区间的左端点等于询问的左端点;
      • 回答询问;
      • 撤销莫队区间左端点的改动,使莫队区间的左端点回滚到 的右端点加 。

这样我们就做到了只增加不删除(其实还是有删除的,因为要回滚嘛)

复杂度证明:

假设回滚莫队的分块大小是 \(S\)

  • 对于左、右端点在同一个块内的询问,可以在 \(\mathcal{O}(S)\) 时间内计算;
  • 对于其他询问,考虑左端点在相同块内的询问,它们的右端点单调递增,移动右端点的时间复杂度是 \(\mathcal{O}(n)\),而左端点单次询问的移动不超过 S,因为有 \(\dfrac{n}{S}\)
    个块,所以总复杂度是 \(\mathcal{O}(mS+\dfrac{n^2}{S})\),取 \(S=\dfrac{n}{\sqrt{m}}\) 最优,时间复杂度为 \(\mathcal{O}(n\sqrt{m})\)
#include<bits/stdc++.h>
#define LL long long
#define N 100005
#define int long long
using namespace std;
int read(){
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-f;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,Q,len,cnt,l,r,lq=1;
LL res;
LL a[N],b[N];
LL CC[N];
int L[N],R[N],pos[N];
LL ans[N];
struct Query{
	int l,r,id;
	bool operator <(const Query &A) const{
		return pos[l]!=pos[A.l]?pos[l]<pos[A.l]:r<A.r;
	}
}q[N];
void add(int x){
	CC[a[x]]++;
	res=max(res,CC[a[x]]*b[a[x]]);
}
signed main(){
	n=read(),Q=read();
	for(int i=1;i<=n;i++) b[i]=a[i]=read();
	for(int i=1;i<=Q;i++){
		q[i].l=read(),q[i].r=read();
		q[i].id=i;
	}
	sort(b+1,b+1+n);
	int tot=unique(b+1,b+1+n)-b-1;
	for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
	len=sqrt(n),cnt=(n-1)/len+1;
	for(int i=1;i<=cnt;i++) L[i]=R[i-1]+1,R[i]=min(n,L[i]+len-1);
	for(int i=1;i<=n;i++) pos[i]=(i-1)/len+1;
	sort(q+1,q+1+Q);
	for(int i=1;i<=cnt;i++){
		memset(CC,0,sizeof(CC));
		r=R[i],res=0;
		while(pos[q[lq].l]==i){
			l=R[i]+1;
			if(q[lq].r-q[lq].l<=len){ //处理长度不超过 S 的情况 
				LL tmp=0;
				int c[N]={0};
				for(int j=q[lq].l;j<=q[lq].r;j++) c[a[j]]++;
				for(int j=q[lq].l;j<=q[lq].r;j++) tmp=max(tmp,c[a[j]]*b[a[j]]);
				for(int j=q[lq].l;j<=q[lq].r;j++) c[a[j]]--;
				ans[q[lq].id]=tmp;
				++lq;
				continue;
			}
			while(q[lq].r>r) add(++r);
			int last=res;
			while(l>q[lq].l) add(--l);
			ans[q[lq].id]=res;
			res=last;
			while(l<R[i]+1) --CC[a[l++]];  //回滚 
			++lq;
		}
	}
	for(int i=1;i<=Q;i++) printf("%lld\n",ans[i]);
	return 0;
}

例题:P4137 Rmq Problem / mex

可以发现删除操作十分简单,但是添加操作难以在 \(\mathcal{O}(1)\) 复杂度做到

考虑只减不加的回滚莫队
,先遍历整一个序列,算出整个序列的 \(mex\)
,与只加不减的回滚莫队相似,将询问的左端点所在的块从小到大,右端点从大到小排序

这样的话,对于每一个不同的块,右指针会从整个序列的末尾不断向左删除,同时左指针不断地回滚到块的左端点,处理完一个块的询问之后,左指针要滚到下一个块的左端点

这时,类似只加不减的回滚莫队,先暴力地把右指针遍历到整个序列的最右端,然后左指针向右删除到下一个块的左端点

复杂度是 \(O(n \sqrt n)\),证明略

#include<bits/stdc++.h>
#define N 200010
using namespace std;
void read(int &n){
	int f=1,x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	n=x*f;}
int n,m,res,minn;
int a[N],pos[N],f[N],ans[N],cnt[N];
int L[N],R[N];
struct obj{
	int id,l,r;
	bool operator <(const obj &A) const{
		return pos[l]!=pos[A.l]?pos[l]<pos[A.l]:r>A.r;
	}
}q[N];
int main(){
	read(n),read(m);
	for(int i=1;i<=n;i++){
		read(a[i]);
		if(a[i]<=n) f[a[i]]++;  //大于n的点必然不会成为mex 
	} 
	while(f[minn]) minn++;
	for(int i=1;i<=m;i++){
		int l,r;
		read(l),read(r);
		q[i]=(obj){i,l,r};
	}
	int len=max(1,(int)(sqrt((double)n*n/m)));
	for(int i=1;i<=n;i++) pos[i]=(i-1)/len+1;
	for(int i=1;i<pos[n];i++) L[i]=R[i-1]+1,R[i]=i*len;
	L[pos[n]]=R[pos[n]-1]+1,R[pos[n]]=n;
	sort(q+1,q+1+m);
	for(int k=1,i=1;k<=pos[n];k++){
		int l=L[k],r=n,p=minn;
		for(int v;pos[q[i].l]==k;i++) {  //左端点在同一个块内 
			if(pos[q[i].r]==k){  
				int t=0;
				for(int j=q[i].l;j<=q[i].r;j++){
					if(a[j]<=n) cnt[a[j]]++;
				}
				while(cnt[t]) t++;
				ans[q[i].id]=t;
				for(int j=q[i].l;j<=q[i].r;j++){
					if(a[j]<=n) cnt[a[j]]--;
				}
			}else{
				while(r>q[i].r){
					if(a[r]<=n&&!(--f[a[r]])) p=min(p,a[r]);
					r--;
				}
				v=p;
				while(l<q[i].l){
					if(a[l]<=n&&!(--f[a[l]])) v=min(v,a[l]);
					l++;
				}
				ans[q[i].id]=v;
				while(l>L[k]){  
					l--;
					if(a[l]<=n) f[a[l]]++;
				}
			}
		}
		while(r<n){
			r++;
			if(a[r]<=n) f[a[r]]++;  //r滚回n 
		}
		while(l<L[k+1]){
			if(a[l]<=n&&!(--f[a[l]])) minn=min(minn,a[l]); //l滚到下一个块的左端点 
			l++;
		}
	}
	for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
	return 0;
} 

树上莫队

一般的莫队只能处理线性问题,我们要把树强行压成序列。

我们这里有两种方法:

欧拉序

其实dfs序也能将树转为区间,但在这里不行,因为它无法处理 lca 的贡献。
dfs时,第一次到达该结点时记录一次,随后每访问完该结点的一棵子树返回时就再记录一次。
以下图举个例子:
图中欧拉序为 : 1 2 3 4 4 5 5 3 7 8 8 7 2 9 9 1

我们可以发现,如果我们要求一个关于子树的信息,在欧拉序中是连续的,例如 3 就将 4 和 5 包在里面。
而对于一条路径 \((x,y)\) 的话,我们也可以解决。
这里我们设 \(st[x]\) 表示 \(x\) 第一次出现在欧拉序中的时间, \(ed[x]\) 表示 \(x\) 第二次出现在欧拉序中的时间。
我们先假设 \(st[x]<st[y]\),分情况讨论:

  • 如果 \(lca(x,y)=x\),此时 \(x,y\) 在一条链上,那么在 \(st[x]\)\(st[y]\) 这段区间中,有一些点出现了两次,有一些点出现了一次,有一些点没有出现,对于我们来说,只有出现一次的点才是链上的点。例如\((1,5)\) 在区间上就是 \(1,2,3,4,4,5\)
  • 如果 \(lca(x,y) \neq x\),此时 \(x,y\) 在不同的子树中,那么我们需要统计区间 \(ed[x]\)\(st[y]\) 这段区间的点,统计方式如上,也是统计只出现一次的点。例如 \((4,8)\)在区间上就是 \(4,5,5,3,7,8\),但这样没有把 \(lca\) 给统计进去,那我们直接特判 \(lca\) ,将它统计进去即可。
    此处有个证明关于为什么只统计出现一次的点。

SP10707 COT2 - Count on a tree II

给定一个\(n\)个节点的树,每个节点表示一个整数,问\(u\)\(v\)的路径上有多少个不同的整数。
两次出现的点可以用 \(uesd[]\) 数组来表示又没有出现过,从而决定是 ADD 还是 DEL。
记得欧拉序长度是原来的2倍,预处理时不要忘记。

点击查看代码
#include<bits/stdc++.h>
#define N 200005
using namespace std;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, m, block, tot, ola, tmp;
int a[N], b[N], pos[N], d[N], fa[N], id[N];
int ed[N], st[N], ans[N];
int mxs[N], top[N], f[N], sz[N];
bool used[N];
int Head[N], to[N], Next[N];
struct Query{
	int id, l, r, lca;
	bool operator < (const Query& A)const{return pos[l] == pos[A.l] ? r < A.r : pos[l] < pos[A.l];}
}q[N];
void add(int u, int v){
	to[++tot] = v, Next[tot] = Head[u], Head[u] = tot;
}
void lsh(){
	sort(b + 1, b + 1 + n);
	int nn = unique(b + 1, b + 1 + n) - b - 1;
	for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + nn, a[i]) - b;
}
void dfs1(int x, int ff){
	fa[x] = ff, sz[x] = 1, d[x] = d[ff] + 1;
	st[x] = ++ola, id[ola] = x;
	for(int i = Head[x]; i; i = Next[i]){
		int y = to[i]; if(d[y]) continue;
		dfs1(y, x); sz[x] += sz[y];
		if(sz[y] > sz[mxs[x]]) mxs[x] = y;
	}
	ed[x] = ++ola, id[ola] = x;
}
void dfs2(int x, int topf){
	top[x] = topf;
	if(!mxs[x]) return ;
	dfs2(mxs[x], topf);
	for(int i = Head[x]; i; i = Next[i]){
		int y = to[i]; if(top[y]) continue;
		dfs2(y, y);
	}
}
int GetLca(int x, int y){
	while(top[x] != top[y]){
		if(d[top[x]] < d[top[y]]) swap(x, y);
		x = fa[top[x]];
	}
	return d[x] < d[y] ? x : y;
}
void DealAsk(){
	for(int i = 1; i <= m; ++i){
		q[i].id = i;
		int x = read(), y = read();
		if(st[x] > st[y]) swap(x, y);
		int lca = GetLca(x, y);
		if(lca == x) q[i].l = st[x], q[i].r = st[y];
		else q[i].l = ed[x], q[i].r = st[y], q[i].lca = lca;
	}
}
void ADD(int x){return tmp += (!f[x]++), void();}
void DEL(int x){return tmp -= (!--f[x]), void();}
void Add(int x){
	used[x] ? DEL(a[x]) : ADD(a[x]); used[x] ^= 1;
}
void MC(){
	sort(q + 1, q + 1 + m);
	int l = 1, r = 0;
	for(int i = 1; i <= m; ++i){
		while(l < q[i].l) Add(id[l++]);
		while(r < q[i].r) Add(id[++r]);
		while(r > q[i].r) Add(id[r--]);
		while(l > q[i].l) Add(id[--l]);
		if(q[i].lca) Add(q[i].lca);
		ans[q[i].id] = tmp;
		if(q[i].lca) Add(q[i].lca); 
	}
}
signed main(){
	n = read(), m = read();
	block = sqrt(n);
	for(int i = 1; i <= n; ++i) b[i] = a[i] = read();
	for(int i = 1; i <= n * 2; ++i) pos[i] = (i - 1) / block + 1;
	lsh();
	for(int i = 1; i < n; ++i){
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	dfs1(1, 0), dfs2(1, 1);
	DealAsk(), MC();
	for(int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
	return 0;
}

P4074 [WC2013] 糖果公园

题意

就是带修莫队+书上莫队

点击查看代码
#include<bits/stdc++.h>
#define N 200005
#define int long long
using namespace std;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, m, Q, tot, ola, cntt, cntq, block, tmp;
int V[N], W[N], C[N], f[N][20], id[N];
int Head[N], to[N], Next[N];
int st[N], ed[N], d[N], pos[N];
int ans[N], p[N], c[N], tt[N];
bool used[N];
struct Query{
	int l, r, t, lca, id;
	bool operator < (const Query& A) const{return pos[l] == pos[A.l] ? (pos[r] == pos[A.r] ? t < A.t : r < A.r) : l < A.l;}
}q[N];
void add(int u, int v){
	to[++tot] = v, Next[tot] = Head[u], Head[u] = tot; 
}
void dfs(int x){
	d[x] = d[f[x][0]] + 1, st[x] = ++ola, id[ola] = x;
	for(int i = 1; i < 20; ++i) f[x][i] = f[f[x][i - 1]][i - 1];
	for(int i = Head[x], y; i; i = Next[i]) if((y = to[i]) != f[x][0]) f[y][0] = x, dfs(y);
	ed[x] = ++ola, id[ola] = x;
}
int GetLca(int x, int y){
	if(d[x] > d[y]) swap(x, y);
	for(int i = 19; i >= 0; --i) if(d[f[y][i]] >= d[x]) y = f[y][i];
	if(x == y) return x;
	for(int i = 19; i >= 0; --i) if(f[y][i] != f[x][i]) y = f[y][i], x = f[x][i];
	return f[x][0];
}
void Add(int x){
	used[x] ? tmp -= V[C[x]] * W[tt[C[x]]--] : tmp += V[C[x]] * W[++tt[C[x]]];
	used[x] ^= 1;
}
void work(int t){
	if(used[p[t]]){
		Add(p[t]);
		swap(C[p[t]], c[t]);
		Add(p[t]);
	}else swap(C[p[t]], c[t]);
}
signed main(){
	n = read(), m = read(), Q = read();
	for(int i = 1; i <= m; ++i) V[i] = read();
	for(int i = 1; i <= n; ++i) W[i] = read();
	for(int i = 1; i < n; ++i){
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	dfs(1);
	for(int i = 1; i <= n; ++i) C[i] = read();
	for(int i = 1; i <= Q; ++i){
		int opt = read(), x = read(), y = read();
		if(opt){
			if(st[x] > st[y]) swap(x, y);
			int lca = GetLca(x, y);
			q[++cntq].id = cntq, q[cntq].t = cntt;
			if(lca == x) q[cntq].l = st[x], q[cntq].r = st[y];
			else q[cntq].l = ed[x], q[cntq].r = st[y], q[cntq].lca = lca;
		}else p[++cntt] = x, c[cntt] = y;
	}
	block = pow(2.0 * n, 2.0 / 3.0);
	for(int i = 1; i <= 2 * n; ++i) pos[i] = (i - 1) / block + 1;
	sort(q + 1, q + 1 + cntq);
	int l = 1, r = 0, t = 0;
	for(int i = 1; i <= cntq; ++i){
		while(l < q[i].l) Add(id[l++]);
		while(r < q[i].r) Add(id[++r]);
		while(r > q[i].r) Add(id[r--]);
		while(l > q[i].l) Add(id[--l]);
		while(t < q[i].t) work(++t);
		while(t > q[i].t) work(t--);
		if(q[i].lca) Add(q[i].lca);
		ans[q[i].id] = tmp;
		if(q[i].lca) Add(q[i].lca);
	}
	for(int i = 1; i <= cntq; ++i) printf("%lld\n", ans[i]);
	return 0;
}

树分块

二次离线莫队

参考资料:OI Wiki