CF1725C Circular Mirror

发布时间 2023-06-24 11:03:25作者: 腾云今天首飞了吗

虽然是一道绿题,但是感觉推式子时的一些细节还是值得学习的,并且还是有点 \(2\) \(hard\) \(4\) \(me\)......

一个圆上有 \(N\) 个可染色的点,编号 \(1\to N\)\(N\) 号点和 \(1\) 号点相邻。
你可以用 \(M\) 种颜色将这些点染色。要求不能出现有三个同色点围成直角三角形。
请求出全部合法方案的总数,输出它模 \(998\ 244\ 353\) 的值。

显然,作为圆直径的点对可以跟非直径的点对分成两类。如何判断直径,只需要记录一下前缀和即可,然后得到直径点对数 \(cnt\)

然后思路就有些混乱了。对于一个以直径为斜边的直角三角形,为了使它不染成同一种颜色,大体上可以归类于两种方法:

  • 把该直径上的两点染成同一种颜色,把其他所有点染成与之不同的颜色。
  • 把该直径上两点染成不同的颜色,其他点随便染。

结论看起来是简洁的,可是当时确实就绕在这两种情况里了,没能捋清楚。。。泵。

对直径上的点,染色有两种选择,自然联想到乘法原理,试着能否把两个步骤的选择数量乘起来以得其结果。但截至到目前,似乎还不太可行,因为不知道在 \(cnt\) 条直径中,哪些选择方案一染色,哪些选择方案二直接枚举
\(i\) 为选择方案一染色的直径点对数,则 \(cnt - i\) 为用方案二染色的直径点对数。
首先由于所有直径互不关联,故从 \(cnt\) 个直径中选 \(i\) 个直径的方案数为\(\begin{pmatrix}cnt\\i\end{pmatrix}\)\(m\) 种颜色,从里面挑 \(i\) 个出来分配到 \(i\) 条直径中,\(i\) 个直径互不影响,因此是有顺序的。挑选、分配的方案为 \(A_m^i\)

对于剩下的 \(cnt - i\) 条直径,已经选了 \(i\) 种颜色,还有 \(m-i\) 种颜色。要求直径两个端点颜色不同,则其中一个端点有 \(m - i\) 种颜色可选,另一个端点有 \(m - i - 1\) 种颜色可选。因为一共有 \(cnt - i\) 条直径,总的方案数就是 \(((m - i - 1) \times (m - i))^{cnt - i}\)

剩下还有 \(n - 2 * cnt\) 个点,每个点有 \(m - i\) 种颜色选,总数就是 \((m - i)^{n - 2 * cnt}\)
全部乘起来,求和:

\[ans = \sum_{i = 0}^{m}C_{cnt}^{i} \times A_{m}^{i} \times ((m - i - 1) \times (m - i))^{cnt - i} \times (m - i)^{n - 2 \times cnt} \]

然后就完了。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MOD = 998244353;
const int MAXN = 2e6;
int fac[MAXN + 5],inv[MAXN + 5],n,m,d[MAXN + 5];
map<int,bool> vis;
int qpow(int a,int n){
    int ret = 1;
    while(n){
        if(n & 1)ret = 1ll * ret * a % MOD;
        a = 1ll * a * a % MOD;
        n >>= 1;
    }
    return ret;
}
int c(int n,int m){
    if(m > n)return 0;
    return 1ll * fac[n] * inv[m] % MOD * inv[n - m] % MOD;
}
int a(int n,int m){
    if(n < m)return 0;
    return 1ll * fac[n] * inv[n - m] % MOD;
}
signed main(){
    fac[0] = 1;
    for(int i = 1; i <= MAXN; i++){
        fac[i] = 1ll * fac[i - 1] * i % MOD;
    }
    inv[MAXN] = qpow(fac[MAXN],MOD - 2);
    for(int i = MAXN; i; i--){
        inv[i - 1] = 1ll * inv[i] * i % MOD; 
    }
    scanf("%lld%lld",&n,&m);
    int sum = 0,tot = 0,cnt = 0;
    for(int i = 1; i <= n; i++){
        scanf("%lld",&d[i]);
        tot += d[i];

    }
    int p = tot;
    tot /= 2;
    vis[0] = 1;
    for(int i = 1; i < n; i++){
        sum += d[i];
        int k = sum - tot;
        if(vis.find(k) != vis.end())++cnt;
        vis[sum] = 1;
    }
    sum = 0;
    int x = tot - d[n];
    if(p % 2 == 1){
        cnt = 0;
    }
    long long ans = 0;
    for(int i = 0; i <= cnt; i++){
        ans = (1ll * ans + 1ll * c(cnt,i) * a(m,i) % MOD * qpow(m - i,n - cnt - i) % MOD * qpow(m - i - 1,cnt - i) % MOD) % MOD;
    }
    cout << ans % MOD;
}