Maximum Diameter 题解

发布时间 2023-09-01 18:58:45作者: TKXZ133

Maximum Diameter

题目大意

定义长度为 \(n\) 的序列 \(a\) 的权值为:

  • 所有的 \(n\) 个点的第 \(i\) 个点的度数为 \(a_i\) 的树的直径最大值,如果不存在这样的树,其权值为 \(0\)

给定 \(n\),求所有长度为 \(n\) 的序列的权值和。

思路分析

\(n\) 个点的树的边数为 \(n-1\),总度数为 \(2n-2\),故序列 \(a\) 的权值不为 \(0\) 当且仅当 \(\sum a=2n-2\)\(a_i>0\),因此我们只需要考虑这样的序列即可。

考虑如何根据给定序列构造出直径最大的树,设 \(a\) 中有 \(k\)\(1\),也就是树上有 \(k\) 个叶子节点,那么我们可以将剩下的 \(n-k\) 个节点全部串在一起,再在两端放上两个叶子节点,用 \(n-k+2\) 个点构造出一条长 \(n-k+1\) 的链,其余的叶子节点挂在链上,显然这是最优方案,直径为 \(n-k+1\)

考虑计数。枚举 \(k\),那么叶子节点的选择方案数为 \({n \choose k}\)。而非叶子节点的度数必须大于 \(1\),且有 \(n-k\) 个,又因为剩余的可用度数为 \(2n-2-k\),所以这个问题等价于将 \(2n-2-k\) 个相同的球放在 \(n-k\) 个盒子里,且每个盒子的球必须大于 \(1\),由插板法易得其方案数为:

\[{(2n-2-k)-2(n-k)+(n-k)-1\choose (2n-2-k)-2(n-k)}={n-3\choose k-2} \]

再算上直径产生的贡献,故我们所求式即:

\[\sum_{k=1}^n{n\choose k}{n-3\choose k-2}(n-k+1) \]

这个式子可以 \(O(n)\) 计算,但这显然不够,我们需要继续化简。

我们有以下两个式子:

  • 吸收恒等式:\(k{n\choose k}=n{n-1\choose k-1}\)

  • 范德蒙德卷积:\(\sum\limits_{i=0}^k{n\choose i}{m\choose k-i}={n+m\choose k}\)

一式可以直接拆组合数简单证明,二式通过组合意义显然成立。

然后我们就可以通过以上两个式子对所求式进行化简了:

\[\begin{aligned} \sum_{k=1}^n{n\choose k}{n-3\choose k-2}(n-k+1)&= -\sum_{k=1}^n{n\choose k}{n-3\choose k-2}(k-2+1-n)\\&= -\sum_{k=1}^n{n\choose k}{n-3\choose k-2}(k-2)+(n-1)\sum_{k=1}^n{n\choose k}{n-3\choose k-2}\\&= (n-1)\sum_{k=1}^n{n\choose k}{n-3\choose k-2}-(n-3)\sum_{k=1}^n{n\choose k}{n-4\choose k-3}\\&= (n-1)\sum_{k=1}^n{n\choose k}{n-3\choose n-k-1}-\sum_{k=1}^n{n\choose k}{n-4\choose n-k-1}\\&= (n-1){2n-3\choose n-1}-(n-3){2n-4\choose n-3} \end{aligned}\]

化到这样就可以 \(O(1)\) 计算了,只需要 \(O(n)\) 预处理组合数就行了。

代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>

using namespace std;
const int N = 2002000, L = 2000000, mod = 998244353;
#define int long long

int fac[N], inv[N];
int T, n;

int q_pow(int a, int b){
    int res = 1;
    while (b) {
        if (b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}

int C(int n, int m){
    if(n < m || n < 0 || m < 0) return 0;
    return fac[n] * (inv[m] * inv[n - m] % mod) % mod;
}

signed main(){
    fac[0] = 1;
    for (int i = 1; i <= L; i ++) fac[i] = fac[i - 1] * i % mod;
    inv[L] = q_pow(fac[L], mod - 2);
    for (int i = L; i >= 1; i --) inv[i - 1] = inv[i] * i % mod;
    scanf("%lld", &T);
    while (T --) {
        scanf("%lld", &n);
        int res1 = (n - 1) * C(2 * n - 3, n - 1) % mod;
        int res2 = (n - 3) * C(2 * n - 4, n - 3) % mod;
        int ans = (res1 - res2 + mod) % mod;
        cout << ans << '\n';
    }
    return 0;
}