CF1868C

发布时间 2023-09-12 16:22:09作者: Linxrain

问题链接

题意:\(n\)个点,每个点的点权在\([1,m]\)之间,求所有方案的所有路径的最大值的总和

首先,对于一条长度为\(x\)的路径,设它的贡献为\(pre_x\),他的最大值取值有\(m\)种,其中最大值为\(i\)的取值有\(i^x-i^{x-1}\)种,而除了该路径外的所有点的取值一共能构造出\(m^{n-x}\)种方案,那么可以得出

\[pre _ x = m ^ { n - x } \sum _ { i = 1 } ^ m ( i ^ x - i ^ { x - 1 } ) \cdot i \]

接着我们可以预处理出深度为\(y\)子树,经过根节点的贡献\(w_y\)和子树总贡献\(sub_i\)

\[w _ y = \sum _ { i = 1 } ^ { y } \sum _ { j = 1 } ^ { y } 2 ^ { i + j - 2 } \times pre _ { i + j + 1 } + 2 \times \sum _ { i = 1 } ^ y 2 ^ { i - 1 } \times pre _ { i + 1 }\\ \]

\[sub_i = 2 \times sub_{i-1} + w_i \]

最后对所有节点进行类似线段树的遍历,如果遍历到的节点是一棵完全二叉树,那么直接统计答案,否则继续向下传递,并单独统计以自己为根节点的子树中经过自己的路径贡献和即可

总复杂度\(mlogn+log^3n\)

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int p=998244353;
ll cnt[65],pre[100005];
ll n,m,ans;
int now;
int qpow(int x,ll y){
    if(y<0)return 0;
    int res=1;
    while(y){
        if(y&1)res=(ll)res*x%p;
        x=(ll)x*x%p;
        y>>=1;
    }
    return res;
}
void init(){
    cnt[0]=(1+m)*m/2%p*qpow(m,n-1)%p;
    for(int i=1;i<=now*2+1;i++){
        pre[i]=0;
        for(int j=1;j<=m;j++){
            pre[i]=(pre[i]+(ll)(qpow(j,i)+p-qpow(j-1,i))%p*j%p)%p;
        }
        pre[i]=pre[i]*qpow(m,n-i)%p;
    }
    for(int i=1;i<=now+1;i++){
        cnt[i]=cnt[i-1]*2%p;
        for(int j=0;j<=i;j++)
            for(int k=0;k<=i;k++){
                ll prej,prek;
                prej=1ll<<max(0,j-1);
                prek=1ll<<max(0,k-1);
                cnt[i]=(cnt[i]+(pre[j+k+1])*(prej%p)%p*(prek%p))%p;
            }
    }
}
void dfs(ll u,ll l,ll r,int x){
    ll mid=l+r>>1;
    if(r<=n){
        ans=(ans+cnt[x])%p;
        return;
    }
    if(l>n){
        if(x)ans=(ans+cnt[x-1])%p;
        return;
    }
    dfs(u*2,l,mid,x-1);
    dfs(u*2+1,mid+1,r,x-1);
    ans=(ans+cnt[x-1]-(x-2>=0?cnt[x-2]:0)*2+p)%p;
    ll s=n-l+1;
    for(int i=0;i<x;i++){
        ll prei=1ll<<max(0,i-1);
        ans=(ans+pre[x+i+1]*(s%p)%p*(prei%p)%p)%p;
    }
    if(n>mid){
        ll le,ri;
        ri=n-mid;
        le=mid-l+1;
        le%=p;
        ri%=p;
        ans=(ans+pre[2*x+1]*le%p*ri%p)%p;
    }
}
void solve(){
    ans=0;
    now=0;
    scanf("%lld%lld",&n,&m);
    while(n>(1ll<<now+1)-1)now++;
    init();
    dfs(1,1ll<<now,(1ll<<now+1)-1,now);
    printf("%lld\n",ans);
}
int main(){
    int t=1;
    scanf("%d",&t);
    while(t--)solve();
    return 0;
}