【题解】[2023牛客多校] Distance

发布时间 2023-08-04 20:00:56作者: marti88414

题目传送门:[2023牛客多校] Distance

题意

对于任意两个元素个数相同的set:A、B,每次可以执行以下两种操作之一:

  • 将A中的任意元素加一
  • 将B中的任意元素加一

\(C(A, B)\) 含义为将 \(A、B\) 改变为完全相同的 set 所需要花费的最小次数;

初始给你两个set:\(S、T\) ,计算 \(\sum_{A \subseteq S} \sum_{B \subseteq T} C(A,B)\)

其中 \(1 \leq n \leq 2 \times 10^3\)

分析

数据量不大,可以暴力枚举,但是如果枚举每一种集合划分方式不太现实;

有个简单的一目了然的性质:在将两个集合 \(sort\) 之后,任意两个集合的 \(C\) 一定是对应位置的元素弄成相同所需要的花费;利用这个性质我们可以枚举每个点对会在多少种划分方式中出现即可,对答案的贡献就是点对的差

题解

\(sort\) 之后暴力枚举每两个点 \(x,y\),计算其会被几次划分进集合,公式如下:

\[\sum_{i=0}^{min(x-1,y-1)}\binom{x-1}{i}\binom{y-1}{i} \times \sum_{i=0}^{min(n-x,n-y)}\binom{n-x}{i}\binom{n-y}{i} \]

由组合数学定理,该公式可简化为:

\[\binom{x+y-2}{min(x-1,y-1)} \times \binom{2n-x-y}{min(n-x,n-y)} \]

再乘上一个点对差的绝对值即可,时间复杂度 \(O(n^2)\)

AC代码

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include<bits/stdc++.h>
#define int long long
#define cin std::cin
#define cout std::cout
#define fastio ios::sync_with_stdio(0), cin.tie(nullptr)
using namespace std;
const int N = 2e5 + 10;
const int mod = 998244353;
const int inf = 0x3fffffffffffffff;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;
}
inline int read(){
    int ret = 0,f = 0;char ch = getc();
    while (!isdigit (ch)){
        if (ch == '-') f = 1;
        ch = getc();
    }
    while (isdigit (ch)){
        ret = ret * 10 + ch - 48;
        ch = getc();
    }
    return f?-ret:ret;
}
int frac[N], inv[N];
int qpow(int a, int b) {
    int s = 1;
    for (; b; a = 1ll * a * a % mod, b >>= 1) if (b & 1) s = 1ll * s * a % mod;
    return s;
}
void set_up() {
    frac[0] = inv[0] = 1;
    for (int i = 1; i <= N - 5; i++) frac[i] = 1ll * frac[i - 1] * i % mod;
    inv[N - 5] = qpow(frac[N - 5], mod - 2);
    for (int i = N - 6; i; i--)
    inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}
inline int C(int n, int m) {
    if (n < m) return 0;
    return 1ll * frac[n] * inv[m] % mod * inv[n - m] % mod;
}
inline int calc(int x, int y) {
    return C(x + y, min(x, y));
}
int n, m;
int s[N], t[N];
inline void solve() {
    n = read();
    for(int i = 1; i <= n; ++i) {
        s[i] = read();
    }
    for(int i = 1; i <= n; ++i) {
        t[i] = read();
    }
    sort(s + 1, s + n + 1);
    sort(t + 1, t + n + 1);
    int ans = 0;
    for(int i = 1; i <= n; ++i) {
        for(int j = 1; j <= n; ++j) {
            ans += calc(i - 1, j - 1) % mod * calc(n - i, n - j) % mod * abs(s[i] - t[j]) % mod;
            ans %= mod;
            // cout << ans << endl;
        }
    }
    printf("%lld\n", ans);
}
signed main() {
    // fastio;
    set_up();
    int T;
    // cin >> T;
    T = 1;
    while(T --) {
        solve();
    }
    return 0;
}