## HDU7328 Snake

发布时间 2023-08-05 21:06:33作者: touchfishman

HDU7328 Snake

tag: 容斥,生成函数

题目链接

题意:

1到n个数,分成m组,组队元素排列顺序不同则为不同的组,且每组元素个数不能超过k,问有多少种方案。

容斥做法:

  • 暂且不管排列的事情,先把问题看成求n个球,m个盒子,盒子不能为空,且每个盒子中球数不能超过k。
  • 要保证每盒球数不超过k,我们可通过容斥来解决,定义f(i)为至少i个盒子球数超过k了,最终答案ans即为: f(0)-f(1)+f(2)-....f(m)
  • 保证有i个盒子超过k,我们可以先给这x个盒子放k个,排列组合为$C_{m}^{i}$
    ,那么现在还剩n-xk个球,再用隔板法把他们分成m份,排列组合为$C_{n-ki-1}^{m-1}$
  • 最后再将 ans*(n!)/(m!),算上排序的组合数,而多重集之间没有排序关系,所以要除m的阶乘
  • 这题可以说是2021威海M的简单版本,推荐去写。

代码:

#define fst std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout << std::fixed << std::setprecision(20)
#define le "\n"
#define ll long long 
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+50;
const int mod=998244353;
ll fac[N],inv[N];

ll qsm(ll n,ll m){
    ll res = 1;;
    while(m){
        if(m&1) res = res*n%mod;
        n = n*n%mod;
        m >>= 1;
    }
    return res;
}

void init(){
    int mz =  1e6;
    fac[0] = inv[0] = 1;
    for(int i=1;i<=mz;i++) fac[i] = fac[i-1]*i%mod;
    inv[mz] = qsm(fac[mz],mod-2);
    for(int i = mz-1;i>=1;i--) inv[i] = inv[i+1]*(i+1)%mod;
}

ll C(int a,int b){
    if(a<b||b<0||a<0) return 0; 
    return (fac[a]*inv[b]%mod)*inv[a-b]%mod;
}

void solve(){
    int n,m,k; cin>>n>>m>>k;
    ll ans = 0;
    bool f = 1;
    for(int i=0;i<=m;i++){
       ll tmp = f ? 1 : mod-1;
       f ^= 1;
       ans = (ans+tmp*C(m,i)%mod*C(n-1-i*k,m-1)%mod)%mod;
    }
    ans = ans*fac[n]%mod*inv[m]%mod;
    cout<<ans<<le;
}
int main(){
    fst;
    init();
    int t; cin>>t;
    while(t--){
        solve();
    }
    return 0;
}   

生成函数做法:

  • 同上面一样先将问题看成小球分盒子问题,这回我们用生成函数求解ans
  • 因为首先不考虑内部排序,且盒子不能为空,那么一个盒子的生成函数为: $x+x{2}+x+...+x^{n} = \frac {x}{x-1}$
  • 同时不能超过k个,那么生成函数应该是:
    $x+x{2}+x+...+x^{k} = \frac {(x^k-1)x}{x-1}$
  • 我们要求的ans为$(\frac{(xk-1)x}{x-1})$的第n项,展开得到答案为 $(x^{k} − 1)^{m}(1 − x)^{-m}$的第 n − m 项, $(x^{k} − 1)^{m}$用二项式定理展开即可,$(1 − x){-m}$则需要用到广义二项式定理:$(1+x) = \sum_{i=0}{+\infty}C_{n+i-1}x^{i}$
  • 剩下步骤同上

代码:

#define fst std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout << std::fixed << std::setprecision(20)
#define le "\n"
#define ll long long 
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+50;
const int mod=998244353;
ll fac[N],inv[N];

ll qsm(ll n,ll m){
    ll res = 1;;
    while(m){
        if(m&1) res = res*n%mod;
        n = n*n%mod;
        m >>= 1;
    }
    return res;
}

void init(){
    int mz =  1e6;
    fac[0] = inv[0] = 1;
    for(int i=1;i<=mz;i++) fac[i] = fac[i-1]*i%mod;
    inv[mz] = qsm(fac[mz],mod-2);
    for(int i = mz-1;i>=1;i--) inv[i] = inv[i+1]*(i+1)%mod;
}

ll C(int a,int b){
    if(a<b||b<0||a<0) return 0; 
    return (fac[a]*inv[b]%mod)*inv[a-b]%mod;
}

void solve(){
    int n,m,k; cin>>n>>m>>k;
    vector<ll> a(n+1),b(n+1);
    for(int i=0;i*k<=n-m;i++){
        ll tmp = (i&1) ? mod-1: 1;
        a[i*k] = C(m,i)*tmp%mod;
    }

    for(int i=0;i<=n-m;i++){
        b[i] = C(m+i-1,i)%mod;
    }
    ll ans = 0;
    for(int i=0;i<=n-m;i++) ans = (ans+a[i]*b[n-m-i]%mod)%mod;
    ans = ans*fac[n]%mod*inv[m]%mod;
    cout<<ans<<le;
}
int main(){
    fst;
    init();
    int t; cin>>t;
    while(t--){
        solve();
    }
    return 0;
}