CF932E Team Work 题解

发布时间 2023-08-08 11:22:25作者: CCComfy

CF932E Team Work 题解

题目链接CF932E

题面翻译

给定 $ n,k $,求:

\[\sum_{i=1}^n\binom n i \times i^k \]

$ 1 \leq k \leq 5000,1 \leq n \leq 10^9 $

思路

看一眼数据范围,\(n\leq 1e9\) 不能直接枚举\(i\)\(n\)
但是\(k\)的范围较小,可以\(k^2\)做,考虑对原式进行转化,不断期望消去枚举\(i\)的影响

\[\begin{aligned} \sum_{i=1}^n\mathrm{C}_n^i\times i^k & = \sum_{i=1}^n\mathrm{C}_n^i\times \sum_{j=0}^k\{_j^k\}\times i^{\underline{k}}\\ & = \sum_{i=1}^n\mathrm{C}_n^i\times \sum_{j=0}^k\{_j^k\} \times j!\mathrm{C}_i^j \\ & = \sum_{i=1}^n\dfrac{n!}{i!\times (n-i)!} \times \sum_{j=0}^k\{_j^k\}\times j!\dfrac{i!}{j!\times (i-j)!} \\ & = \sum_{j=0}^k\{_j^k\}\sum_{i=0}^n\dfrac{n!}{(n-i)!(i-j)!} \\ & = \sum_{j=0}^k\{_j^k\}\dfrac{n!}{(n-j)!}\sum_{i=0}^n\times \dfrac{(n-j)!}{(n-i)!(i-j)!} \\ & = \sum_{j=0}^k\{_j^k\}\dfrac{n!}{(n-j)!}\sum_{i=0}^n\times \mathrm{C}_{n-j}^{n-i} \\ & = \sum_{j=0}^k\{_j^k\}\dfrac{n!}{(n-j)!}2^{n-j} \\ \end{aligned} \]

式子第一步用到了第二类斯特林数转化下降幂
第二步将下降幂转为组合数形式
至此答案只需从\(1\)\(k\)枚举j,可以\(k^2\)递推预处理出\(\{_j^k\}\)
预处理代码:

st[0][0]=1;
for(int i=1;i<=k;i++){
    for(int j=1;j<=i;j++){
        st[i][j]=(st[i-1][j-1]+j*st[i-1][j]%mod)%mod;
    }
}

对于\(\dfrac{n!}{(n-j)!}\)我们显然不能直接预处理出阶乘逆元,但我们可以从\(n-j+1\)累乘到\(n\)
最终答案为

\[Ans=\sum_{j=1}^k(S2[k][j]\times \prod_{x=n-j+1}^nx \times 2^{n-j}) \]

直接套公式求和即可
总时间复杂度不超过\(O(k^2+k(k+logn))\)

Code

#include <bits/stdc++.h>
using namespace std;
#define il inline
#define ll long long
#define int long long
il ll read(){
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
il void write(int x){
    if(x<0) putchar('-'),x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
}


const int mod=1e9+7;
#define M 5020
ll st[M][M];
il ll fast_pow(ll x,ll a){
    ll ans=1;
    while(a){
        if(a&1)ans=ans*x%mod;
        x=x*x%mod;
        a>>=1;
    }
    return ans%mod;
}
signed main(){
    ll n=read(),k=read();
    st[0][0]=1;
    for(int i=1;i<=k;i++){
        for(int j=1;j<=i;j++){
            st[i][j]=(st[i-1][j-1]+j*st[i-1][j]%mod)%mod;
        }
    }
    ll ans=0;
    for(int i=0;i<=min(n,k);i++){
        ll comfy=1;
        for(ll j=n-i+1;j<=n;j++)comfy=(comfy*j)%mod;
        ll tmp=((st[k][i]*comfy%mod)%mod*fast_pow(2,n-i)%mod)%mod;
        ans=(ans+tmp)%mod;
    }
    write(ans);
    return 0;
}