BZOJ #3784. 树上的路径

发布时间 2023-07-14 20:24:46作者: 谭皓猿

BZOJ #3784. 树上的路径

题意

给一颗树,求所有路径长度中前 \(k\) 大。

题解

首先对于前 \(k\) 大,我们有一个常见的方法,二分。
二分第 \(k\) 大的路径长度,然后使用点分治统计,点分治内部还要二分,所以时间复杂度 \(O(nolg^3n)\)
二分显然是行不通了,想一下就会发现外层和内层的二分都不好去掉。
我们考虑求前 \(k\) 小要怎么做,显然可以将所有的边丢到优先队列中,然后取出队头,扩展然后又丢回去。
但是由于总方案数太多,我们无法通过取小的将大的推出来,所以我们考虑修改一下这个做法,超级钢琴这道题就是一个好方法。
但是扩展到树上不太方便,我们考虑一个和前缀和类似的方法。
已知 \(dis_{x,y} = dis_{x,lca}+dis_{lca,y}\) ,然后我们可以枚举 \(x\) ,和 \(lca\),然后再通过 \(dfn\)转化为区间,记录 \(x,lca\) 拆两个元组然后丢进去。
注意到枚举 \(lca\) 这个东西是和深度有关的,所以我们可以使用点分治来枚举 \(lca\)
然后对于每一个分治点,将其子树全部记录下来,用类似 \(dfn\) 的方式,然后我们会发现每一个 \(x\),对应的都是一个区间。
然后就和超级钢琴一模一样了,记录点数是 \(nlogn\) 级别的,然后算上优先队列,时间复杂度 \(O(nlog^2n)\)

code

点击查看代码
#include <bits/stdc++.h>
#define rg register
#define pc putchar
#define gc getchar
#define pf printf
#define space pc(' ')
#define enter pc('\n')
#define me(x,y) memset(x,y,sizeof(x))
#define pb push_back
#define FOR(i,k,t,p) for(rg int i(k) ; i <= t ; i += p)
#define ROF(i,k,t,p) for(rg int i(k) ; i >= t ; i -= p)
using namespace std ;
inline void read(){}
template<typename T,typename ...T_>
inline void read(T &x,T_&...p)
{
    x = 0 ;rg int f(0) ; rg char c(gc()) ;
    while(!isdigit(c)) f |= (c=='-'),c = gc() ;
    while(isdigit(c)) x = (x<<1)+(x<<3)+(c^48),c = gc() ;
    x = (f?-x:x) ;
    read(p...);
}
int stk[30],tp ;
inline void print(){}
template<typename T,typename ...T_>
inline void print(T x,T_...p)
{
    if(x < 0) pc('-'),x = -x ;
    do stk[++tp] = x%10,x /= 10 ; while(x) ;
    while(tp) pc(stk[tp--]^48) ; space ;
    print(p...) ;
}
const int N = 1e6+5 ;
int n,m,sum,top,le,ri,rt ;
int L[N],R[N] ;
int sz[N],vis[N],a[N],mx[N][30],val[N] ;
struct Edge{int v,w ;} ;
vector<Edge>po[N],e[N] ;
void Get_sz(int x,int fa)
{
    sz[x] = 1 ;
    for(auto [v,w]:e[x]) if(!vis[v] && v != fa)
        Get_sz(v,x),sz[x] += sz[v] ;
}
void Get_rt(int x,int fa)
{
    sz[x] = 1 ; int mix = 0 ;
    for(auto [v,w]:e[x]) if(!vis[v] && v != fa)
        Get_rt(v,x),sz[x] += sz[v],mix = max(sz[v],mix) ;
    mix = max(mix,sum-sz[x]) ;
    if(mix <= sum/2) rt = x ;    
}
void Get(int x,int fa,int dis)
{
    val[++top] = dis,L[top] = le,R[top] = ri ;
    for(auto [v,w]:e[x]) if(!vis[v] && v != fa) Get(v,x,dis+w) ;
}
void Solve(int x)
{
    vis[x] = 1,le = ++top ;
    for(auto [v,w]:e[x]) if(!vis[v])
        ri = top,Get(v,x,w) ;
    for(auto [v,w]:e[x]) if(!vis[v])
        rt = 0,Get_sz(v,x),sum = sz[v],Get_rt(v,x),Solve(rt) ;
}
inline int Max(int x,int y){return val[x]>val[y]?x:y ;}
struct Node
{
    int l,r,posi,posj ;
    bool operator < (const Node &A) const
    {
        return val[posi]+val[posj] < val[A.posi]+val[A.posj] ;
    }
} ;
priority_queue<Node>q ;
inline int Query(int l,int r)
{
    // print(l,r),enter ;
    int k = log2(r-l+1) ;
    return Max(mx[l][k],mx[r-(1<<k)+1][k]) ;
}
signed main()
{
//	freopen(".in","r",stdin) ;
//	freopen(".out","w",stdout) ;
    read(n,m) ;
    FOR(i,2,n,1)
    {
        int u,v,w ; read(u,v,w) ;
        e[u].pb({v,w}),e[v].pb({u,w}) ;
    }
    sz[rt = 0] = n+1,sum = n,Get_rt(1,0),Solve(rt) ;
    FOR(i,1,top,1) mx[i][0] = i ;
    FOR(j,1,20,1) for(int i = 1 ; i+(1<<j)-1 <= top ; ++i)
        mx[i][j] = Max(mx[i][j-1],mx[i+(1<<j-1)][j-1]) ;
    FOR(i,1,top,1) if(L[i]) q.push({L[i],R[i],i,Query(L[i],R[i])}) ;
    // FOR(i,1,top,1) print(val[i]) ; print(top),cerr<<"!",enter ;
    // FOR(i,1,top,1) if(L[i]) print(L[i],R[i],i,Query(L[i],R[i])),enter ;
    // FOR(i,1,top,1) print(L[i],R[i]),enter ;
    FOR(i,1,m,1)
    {
        auto [l,r,posi,posj] = q.top() ; q.pop() ;
        print(val[posi]+val[posj]),enter ;
        if(posj > l) q.push({l,posj-1,posi,Query(l,posj-1)}) ;
        if(posj < r) q.push({posj+1,r,posi,Query(posj+1,r)}) ;
    }
    return 0 ;
}

...

求前 \(k\) 大,这种类型的题,贪心和二分都是可以使用的技巧,但是要注意数据规模。
看到统计路径长度我们可以往中转点和点分治的方向去想。