疯狂动物城 题解

发布时间 2023-08-11 17:28:56作者: TKXZ133

疯狂动物城

题目大意

给定一颗 \(n\) 个点的树,第 \(i\) 个点的点权为 \(a_i\),需要维护三种操作:

    1. \(x\)\(y\) 路径加 \(c\)
    1. 查询 \(x\)\(y\) 的路径权值,一条路径的权值被定义为

\[\sum_{i\in\text{road}(x,y)}a_i\times \frac{dis(i,y)(dis(i,y)+1)}{2} \]

其中,\(\text{road}(x,y)\) 表示由 \(x\)\(y\) 的路径上的点构成的集合,\(dis(x,y)\) 表示 \(x\)\(y\) 的距离。

    1. 回退到某一版本。

思路分析

(以下记 \(l=\text{lca}(x,y)\)\(dep_i\) 表示点 \(i\) 在以 \(1\) 为根时的深度)

如果只看 \(1,3\) 操作,我们可以用树链剖分套可持久化线段树比较轻松的完成,因此主要看 \(2\) 操作。

我们发现 \(a_i\times dis\) 不好处理,考虑将其转化为 \(a_i\times dep\) 的形式:

首先可以将 \(x\)\(y\) 的路径拆分成两部分,即 \(x\)\(l\) (包含 \(l\))和 \(l\)\(y\)(不包含 \(l\)),分别记这两部分为 \(up\)\(down\)

那么 \(dis(i,y),i\in up\) 就等于 \(dep_i+dep_y-2dep_{l}\),同理,\(dis(i,y),i\in down\) 就等于 \(dep_y-dep_i\)

所以现在式子变成了这样:

\[\frac{1}{2}\Bigg(\sum_{i\in up}a_i(dep_i+dep_y-2dep_l)(dep_i+dep_y-2dep_l+1)\Bigg)+\frac{1}{2}\Bigg(\sum_{i\in down} a_i(dep_y-dep_i)(dep_y-dep_i+1)\Bigg) \]

再化简一下:

\[\begin{aligned}\frac{1}{2}\Bigg(&(dep_y^2+4dep_l^2-4dep_ydep_l+dep_y-2dep_l)\sum_{i\in up}a_i+(2dep_y-4dep_l+1)\sum_{i\in up}dep_ia_i+\sum_{i\in up}dep_i^2a_i+\\&(dep_y^2+dep_y)\sum_{i\in down}a_i+(-2dep_y-1)\sum_{i\in down}dep_ia_i+\sum_{i\in down} dep_i^2a_i\Bigg)\end{aligned} \]

这样不太好看,设 \(c_1=dep_y-2dep_l,c_2=dep_y\),再因式分解一下就可以得到一个比较好看的式子:

\[\frac{1}{2}\Bigg((c_1^2+c_1)\sum_{i\in up}a_i+(2c_1-1)\sum_{i\in up}dep_ia_i+\sum_{i\in up}dep_i^2a_i+(c_2^2+c_2)\sum_{i\in down}a_i-(2c_2-1)\sum_{i\in down}dep_ia_i+\sum_{i\in down}dep_i^2a_i\Bigg) \]

我们发现,\(c_1,c_2\) 都是常数,可以 \(O(1)\) 计算,我们只需要在线段树上维护 \(\sum a,\sum dep\times a,\sum dep^2\times a\) 就可以用树剖套线段树维护链查询和链加。

具体的说,当区间加 \(k\) 时,观察三者的变化:

\[\begin{cases}\sum a\to\sum (a+k)=\sum a+k\sum 1\\\sum dep\times a\to\sum dep\times (a+k)=\sum dep\times a+k\sum dep\\\sum dep^2\times a\to \sum dep^2\times (a+k)=\sum dep^2\times a+k\sum dep^2\end{cases} \]

因此只需要额外再维护一个 \(dep\)\(dep^2\) 的前缀和就可以做到 \(O(1)\) 更新区间,而区间查询只需要将所有值累加即可。

因为要支持回退版本,所以我们的线段树需要用可持久化线段树,但这会带来一个新的问题,当进行一次区间加时,我们会遍历到线段树上的 \(O(\log n)\) 个区间,如果每一次都新建节点并下放区间加的懒标记,考虑到我们需要维护一堆东西,空间复杂度直接爆炸。

因此,我们可以采用标记永久化的方法,区间加时对每一个访问到的区间直接更新区间值,在当前区间被完全覆盖时直接给当前区间打上标记并停止,不进行下放懒标记操作和上传区间值操作;区间查询时对每个访问到的区间累加标记,在当前区间被完全覆盖时直接累加上当前节点的值。

我们发现,如果这样做的话,在区间加和区间查询时都不需要新建任何节点,只有生成一个新版本时需要消耗 \(O(\log n)\) 的空间,而且常数较小,可以通过本题。

总时间复杂度为 \(O(n\log ^2n)\),空间复杂度为 \(O(n\log n)\)

代码

(只有 4.8k)

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>

using namespace std;
const int N=200100,mod=20160501;
typedef long long ll;
#define mid ((l+r)>>1)
#define D(l,r) (d[r]-d[l-1]+mod)//dep 的前缀和
#define D2(l,r) (d2[r]-d2[l-1]+mod)//dep^2 的前缀和
#define Empty (PSTn{0,0,0,0,{0,0},{0,0}})
#define ls a[p].ch[0]
#define rs a[p].ch[1]

int n,m,idx=1,op,in1,in2,in3,cnt,nrt,num,tot;
int to[N],nxt[N],head[N],rt[N];
int dfn[N],rnk[N],dep[N],siz[N],fa[N],son[N],top[N];
ll d[N],d2[N],w[N],lastans;

void add(int u,int v){
    idx++;to[idx]=v;nxt[idx]=head[u];head[u]=idx;
}

struct PSTn{
    ll sa,sda,sd2a,tag;
    int ch[2];
    bool v[2];//v 表示是否需要新建子节点
};
struct PST{
    PSTn a[N<<5];
    void merge(PSTn &p,PSTn a,PSTn b){//将两个区间值累加进一个区间
        p.sa=(a.sa+b.sa)%mod;
        p.sda=(a.sda+b.sda)%mod;
        p.sd2a=(a.sd2a+b.sd2a)%mod;
    }
    void add_t(PSTn &p,int l,int r,ll k){//更新一个节点的值
        p.sa=(p.sa+k*(r-l+1)%mod)%mod;
        p.sda=(p.sda+k*D(l,r)%mod)%mod;
        p.sd2a=(p.sd2a+k*D2(l,r)%mod)%mod;
    }
    void built(int p,int f){//判断是否需要新建子节点
        if(!a[p].v[f]) return ;
        a[++tot]=a[a[p].ch[f]];a[p].ch[f]=tot;
        a[a[p].ch[f]].v[0]=a[a[p].ch[f]].v[1]=1;
        a[p].v[f]=0;
    }
    void build(int &p,int l,int r){
        p=++tot;
        if(l==r){
            int u=rnk[l];
            a[p].sa=w[u];
            a[p].sda=1ll*w[u]*dep[u]%mod;
            a[p].sd2a=1ll*(1ll*w[u]*dep[u]%mod)*dep[u]%mod;
            return ;//初始化
        }
        build(ls,l,mid);build(rs,mid+1,r);
        merge(a[p],a[ls],a[rs]);
    }
    void add(int p,int l,int r,int x,int y,ll k){
        add_t(a[p],max(l,x),min(r,y),k);
        if(x<=l&&r<=y){a[p].tag+=k;return ;}
        if(y<=mid){built(p,0),add(ls,l,mid,x,y,k);return ;}
        if(x>mid){built(p,1),add(rs,mid+1,r,x,y,k);return ;}
        built(p,0);built(p,1);add(ls,l,mid,x,y,k);add(rs,mid+1,r,x,y,k);
    }
    PSTn ask(int p,int l,int r,int x,int y,PSTn now){
        if(x<=l&&r<=y){merge(now,now,a[p]);return now;}//合并当前区间信息
        add_t(now,max(l,x),min(r,y),a[p].tag);//以此区间的标记对其进行更新
        if(y<=mid) return ask(ls,l,mid,x,y,now);
        if(x>mid) return ask(rs,mid+1,r,x,y,now);
        return ask(ls,l,mid,x,y,ask(rs,mid+1,r,x,y,now));
    }
}tree;

void dfs_1(int s,int gr){
    dep[s]=dep[gr]+1;
    siz[s]=1;fa[s]=gr;
    for(int i=head[s];i;i=nxt[i]){
        int v=to[i];
        if(v==gr) continue;
        dfs_1(v,s);
        siz[s]+=siz[v];
        if(siz[son[s]]<siz[v]) son[s]=v;
    }
}

void dfs_2(int s,int tp){
    top[s]=tp;dfn[s]=++cnt;rnk[cnt]=s;
    if(!son[s]) return ;
    dfs_2(son[s],tp);
    for(int i=head[s];i;i=nxt[i]){
        int v=to[i];
        if(v==fa[s]||v==son[s]) continue;
        dfs_2(v,v);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y]?y:x;
}

void add_all(int x,int y,int k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        tree.add(rt[nrt],1,n,dfn[top[x]],dfn[x],k);
        x=fa[top[x]];
    }
    tree.add(rt[nrt],1,n,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),k);
}

void update(ll &ans,int &x,ll c,int sign,int l,int r){//更新答案
    PSTn now=tree.ask(rt[nrt],1,n,dfn[l],dfn[r],Empty);
    ans=(ans+now.sa*(c*c%mod+c)%mod)%mod;
    ans=(ans+sign*now.sda*(2*c+1)%mod+now.sd2a+mod)%mod;
}

ll ask_all(int x,int y){
    int lca=LCA(x,y);
    ll ans=0,c1=(dep[y]-2*dep[lca]+mod)%mod,c2=dep[y];//计算常数
    while(top[x]!=top[y]){
        if(dep[top[x]]>dep[top[y]]){update(ans,x,c1,1,top[x],x);x=fa[top[x]];}
        else{update(ans,y,c2,-1,top[y],y);y=fa[top[y]];}
    }
    if(dep[x]>=dep[y]) update(ans,x,c1,1,y,x);
    else update(ans,y,c2,-1,x,y);
    return ans*((mod+1)/2)%mod;//最后要乘 2 的逆元
}

int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++){
        scanf("%d%d",&in1,&in2);
        add(in1,in2);add(in2,in1);
    }
    for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
    dfs_1(1,0);dfs_2(1,1);
    for(int i=1;i<=n;i++){
        d[i]=(d[i-1]+dep[rnk[i]])%mod;//注意,dep 的前缀和是在 dfs 序上累加的
        d2[i]=(d2[i-1]+1ll*dep[rnk[i]]*dep[rnk[i]]%mod)%mod;
    }
    tree.build(rt[0],1,n);
    while(m--){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%d",&in1,&in2,&in3);
            in1^=lastans;in2^=lastans;
            rt[++num]=++tot;//生成一个新版本
            tree.a[rt[num]]=tree.a[rt[nrt]];//继承上一个版本
            nrt=num;
            tree.a[rt[nrt]].v[0]=tree.a[rt[nrt]].v[1]=1;//子节点需要新建
            add_all(in1,in2,in3);
        }
        if(op==2){
            scanf("%d%d",&in1,&in2);
            in1^=lastans;in2^=lastans;
            cout<<(lastans=ask_all(in1,in2))<<'\n';
        }
        if(op==3){
            scanf("%d",&in1);
            in1^=lastans;
            nrt=in1;
        }
    }
    return 0;
}