题解 CF1857G【Counting Graphs】

发布时间 2023-08-09 18:17:18作者: rui_er

一个非常显然的事情是:总方案数即为每条边方案数之积。

树边已经确定,考察每条非树边 \((u,v)\) 可以怎么取。给定的树 \(T\) 是唯一最小生成树,这意味着非树边 \((u,v)\) 要么不存在,要么权值大于 \(T\)\((u,v)\) 之间任意一条边的权值。设 \(T\)\((u,v)\) 间的最大边权为 \(k\),则 \((u,v)\) 对答案的贡献为 \(S-k+1\)

但我们无法枚举每条非树边计算贡献,因为复杂度为 \(O(n^2)\)。考虑将“一类”非树边放到一起计算。

考虑 Kruskal 算法的过程,每次取出权值最小的边 \((u,v,w)\) 加入最小生成树,并将两个连通块 \(B_1,B_2\) 合并。当一条边 \((u,v,w)\) 加入最小生成树时,它就是跨越两个连通块 \(B_1,B_2\) 的任意一对点间的最大权值。这就意味着对于每一对 \(B_1\times B_2\) 中的点对(\(\times\) 是集合直积,\((u,v)\) 除外),这条边要么不存在,要么权值大于 \(w\)

我们用桶统计出对于每个 \(w\),有多少条边的要求是“要么不存在,要么权值大于 \(w\)”,并用快速幂计算即可。

时间复杂度 \(O(n\log n+n\log S)\)

// Problem: G. Counting Graphs
// Contest: Codeforces - Codeforces Round 891 (Div. 3)
// URL: https://codeforces.com/contest/1857/problem/G
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(ll x=(y);x<=(z);x++)
#define per(x,y,z) for(ll x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
using namespace std;
typedef long long ll;

mt19937 rnd(std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
ll randint(ll L, ll R) {
    uniform_int_distribution<ll> dist(L, R);
    return dist(rnd);
}

template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}

const ll N = 2e5+5, mod = 998244353;

ll T, n, S;

struct Edge {
    ll u, v, w;
}e[N];

struct Dsu {
    ll fa[N], sz[N];
    void init(ll x) {rep(i, 1, x) fa[i] = i, sz[i] = 1;}
    ll find(ll x) {return x == fa[x] ? x : fa[x] = find(fa[x]);}
    bool same(ll x, ll y) {return find(x) == find(y);}
    bool merge(ll x, ll y) {
        if(same(x, y)) return false;
        x = find(x); y = find(y);
        fa[x] = y;
        sz[y] += sz[x];
        return true;
    }
}dsu;

map<ll, ll> cnt;

ll qpow(ll x, ll y) {
    ll ans = 1;
    for(; y; y >>= 1, x = x * x % mod) if(y & 1) ans = ans * x % mod;
    return ans;
}

int main() {
    for(scanf("%lld", &T); T; T--) {
        map<ll, ll>().swap(cnt);
        scanf("%lld%lld", &n, &S);
        rep(i, 1, n-1) scanf("%lld%lld%lld", &e[i].u, &e[i].v, &e[i].w);
        sort(e+1, e+n, [](const Edge& a, const Edge& b) {return a.w < b.w;});
        dsu.init(n);
        rep(i, 1, n-1) {
            ll u = e[i].u, v = e[i].v, w = e[i].w;
            cnt[w] += dsu.sz[dsu.find(u)] * dsu.sz[dsu.find(v)] - 1;
            dsu.merge(u, v);
        }
        ll ans = 1;
        for(auto [key, val] : cnt) if(val) ans = ans * qpow(S-key+1, val) % mod;
        printf("%lld\n", ans);
    }
    return 0;
}