luogu P3345 [ZJOI2015]幻想乡战略游戏

发布时间 2023-05-10 09:54:48作者: 谭皓猿

P3345 [ZJOI2015]幻想乡战略游戏

这道题还是比较有意思的,做了一个比较长的时间,但是点分树实在是太毒瘤了,所以记录一下线段树的做法。

题面

给一棵树,有边权,每次修改一个点的点权,修改完后输出所有点到这棵树的带权重心的贡献,即\(\sum dis_i\times val_i\)

题解

考虑朴素的暴力找重心,直接在无权重心的基础上将树的大小改为树的权值大小就可以了,正确性还是比较显然的。
想一想怎么优化寻找重心的这一个过程。
我们会发现一个神奇的东西,一个点越接近重心,它的答案就越小
这有什么用呢?这样我们可以像移动指针一样移动根,每移动一步答案都会减少,一直移动到不能移动的点就是重心了。
可以移动到一个点\(x\)当且仅当\(sz_{rt} \leq sz_x\times 2\),这里的\(sz_x\)是子树的点权和。
怎么证明呢?
从一个点\(x\)走到点\(y\),答案的变化为\(e_{x,y}\times(sz_{rt}-sz_y-sz_y)\)
要想答案缩小,那么就必须满足\(sz_{rt} \leq sz_x\times 2\)
怎么做有什么好处呢?一个点是否走只和rt有关系。
那这样我们就可以支持快速移动了,要找的就是满足\(sz_{rt} \leq sz_x\times 2\)且不能移动的点。
我们要怎么找呢?接下来就是一个很妙的操作了。
一个点可以移动是既可以移动到子树外也可以移动到子树内的,这对于我们的维护来说是很麻烦的。
我们直接以\(1\)为初始根,这样因为答案的修改有单调性,只可能移动到子树内,而不可能移动到子树之外。
接下来问题就转化成了我们要寻找满足\(sz_{rt} \leq sz_x\times 2\)且深度最深的点。
这是一个很\(easy\)的操作,只需要在线段树上维护一个最大值,在满足条件的区间上贪心地往右走就行了。
关于修改,只需要在修改时将到\(1\)的路径上全部一起修改了即可。
这样我们就完成了找到重心的操作就做完了,这下考虑怎么统计答案
设当前重心为\(x\),发现这其实是一个点对的贡献,即\(\sum dis(x,y)*val_y\)
画个图,发现换根之后\(x\)的子树距离不变,\(x\)之上的要删去\(lca(x,y)\)的贡献然后再连接到\(x\)上,这样统计答案其实是一件很困难的事情。
我们不妨换一个思路,根据我们一般做树上差分时会用到这样一个式子\(dis(x,y)=dis(x,rt)+dis(y,rt)-2\times dis(lca,rt)\)
把它套到这题来\(dis(x,rt)\times val_y+dis(y,rt)\times val_y-2\times dis(lca_{x,y},rt)\times val_y\)

这个式子的前两项是容易的,关于第三项,我们会发现实际上只有\(x\)\(rt\)所在的子树会产生一个减的贡献,跨越\(rt\)的子树的\(lca\)都是\(rt\),是无意义的。
那么我们我们只需要维护子树的权值与一条边边权之积,也就是在线段树上维护\(sz_i\times w_i,w_i\)是边权。
因为\(x\)所在\(rt\)的子树内的点到\(x\)\(lca\),都是\(x\)\(rt\)路径上的点,我们就直接遍历\(lca\)统计答案。
统计\(x\)到路径上的节点的\(\sum w_i\times sz_i\),这样对于子树内每一个点都都统计了\(dis_{lca_{x,y},rt}\)的贡献,正确性可以分两类讨论一下就好了,这里略。

code

#include <bits/stdc++.h>
#define int long long
#define rg register
#define pc putchar
#define gc getchar
#define pf printf
#define space pc(' ')
#define enter pc('\n')
#define me(x,y) memset(x,y,sizeof(x))
#define pb push_back
#define FOR(i,k,t,p) for(rg int i(k) ; i <= t ; i += p)
#define ROF(i,k,t,p) for(rg int i(k) ; i >= t ; i -= p)
using namespace std ;
bool s_gnd ;
inline void read(){}
template<typename T,typename ...T_>
inline void read(T &x,T_&...p)
{
    x = 0 ;rg int f(0) ; rg char c(gc()) ;
    while(!isdigit(c)) f |= (c=='-'),c = gc() ;
    while(isdigit(c)) x = (x<<1)+(x<<3)+(c^48),c = gc() ;
    x = (f?-x:x) ;
    read(p...);
}
int stk[30],tp ;
inline void print(){}
template<typename T,typename ...T_>
inline void print(T x,T_...p)
{
    if(x < 0) pc('-'),x = -x ;
    do stk[++tp] = x%10,x /= 10 ; while(x) ;
    while(tp) pc(stk[tp--]^48) ; space ;
    print(p...) ;
}
const int N = 2e5+5 ;
int n,q,top,sum1,sum2 ;
int dep[N],fa[N],dis[N],wh[N] ;
int sz[N],vis[N],eal[N],son[N],pe[N],id[N] ;
struct Edge{int v,w ;} ;
struct Node{int l,r,f,val,w,si ;}tr[N<<2] ;
vector<Edge>e[N] ;
bool S_GND ;
void Dfs1(int x)
{
    vis[x] = sz[x] = 1 ;
    for(auto [v,w]:e[x]) if(!vis[v])
    {
        dis[v] = dis[x]+w,Dfs1(v),sz[x] += sz[v] ;
        if(sz[v] > sz[son[x]]) son[x] = v ; eal[v] = w,fa[v] = x ;
    }
}
void Dfs2(int x,int tops)
{
    vis[x] = 1,id[x] = ++top,wh[top] = x,pe[x] = tops ;
    if(!son[x]) return ; Dfs2(son[x],tops) ;
    for(auto [v,w]:e[x]) if(!vis[v] && v != son[x]) Dfs2(v,v) ;
}
#define f(x) tr[x].f
#define w(x) tr[x].w
#define si(x) tr[x].si
#define val(x) tr[x].val
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define mid (tr[x].l+tr[x].r>>1)
inline void update(int x)
{
    w(x) = w(ls(x))+w(rs(x)) ;
    si(x) = max(si(ls(x)),si(rs(x))) ;
}
inline void pushup(int x)
{
    if(!f(x)) return ;
    f(ls(x)) += f(x),f(rs(x)) += f(x),si(ls(x)) += f(x),si(rs(x)) += f(x) ;
    w(ls(x)) += f(x)*val(ls(x)),w(rs(x)) += f(x)*val(rs(x)),f(x) = 0 ;
}
void Build(int x,int le,int ri)
{
    tr[x].l = le,tr[x].r = ri ; if(le == ri) {val(x) = eal[wh[le]] ; return ;}
    Build(ls(x),le,mid),Build(rs(x),mid+1,ri),val(x) = val(ls(x))+val(rs(x)) ;
}
void Modify(int x,int le,int ri,int k)
{
    if(tr[x].l >= le && tr[x].r <= ri)
    {
        si(x) += k,f(x) += k,w(x) += val(x)*k ;
        return ;
    } pushup(x) ;
    if(le <= mid) Modify(ls(x),le,ri,k) ; 
    if(mid < ri) Modify(rs(x),le,ri,k) ; update(x) ;
}
int Query(int x,int le,int ri)
{
    if(tr[x].l >= le && tr[x].r <= ri) return w(x) ; pushup(x) ;
    int res = 0 ; if(le <= mid) res = Query(ls(x),le,ri) ; if(mid < ri) res += Query(rs(x),le,ri) ; update(x) ;
    return res ; 
}
void TMY(int u,int k)
{
    while(pe[u] != 1) Modify(1,id[pe[u]],id[u],k),u = fa[pe[u]] ;
    Modify(1,1,id[u],k) ;
}
int TQY(int u)
{
    int res = sum1+sum2*dis[u] ;
    while(pe[u] != 1) res -= 2*Query(1,id[pe[u]],id[u]),u = fa[pe[u]] ;
    res -= 2*Query(1,1,id[u]) ; return res ;
}
int get_rt(int x)
{
    if(tr[x].l == tr[x].r) return wh[tr[x].l] ; pushup(x) ;
    return si(rs(x))*2 >= si(1)?get_rt(rs(x)):get_rt(ls(x)) ;
}
signed main()
{
//cerr<<(double)(&s_gnd-&S_GND)/1024.0/1024.0 ;
//	freopen(".in","r",stdin) ;
//	freopen(".out","w",stdout) ;
    read(n,q) ;
    FOR(i,2,n,1)
    {
        int u,v,w ; read(u,v,w) ;
        e[u].pb({v,w}),e[v].pb({u,w}) ;
    }
    Dfs1(1),me(vis,0),Dfs2(1,1),Build(1,1,n) ;
    while(q--)
    {
        int x,y ; read(x,y) ;
        sum1 += dis[x]*y,sum2 += y,TMY(x,y) ;
        print(TQY(get_rt(1))),enter ;     
        // print(get_rt(1)),enter ;
    }
    return 0 ;
}