AtCoder Beginner Contest 235 Ex Painting Weighted Graph

发布时间 2023-06-20 22:23:52作者: zltzlt

洛谷传送门

AtCoder 传送门

为啥洛谷唯一一篇题解那么长啊。其实没这么复杂的吧。

考虑边权从小到大排序后建 Kruskal 重构树。那么每次操作相当于,选择一个点,把它子树内的叶子全部染色。注意因为边权相等的边的存在,这里建树要把边权相等的边放一起考虑,建的树也可能是二叉树。

考虑把选点的方案一一对应到最终被染色的叶子集合。不难发现只需要限制:

  • 选的点不能存在祖先 - 后代关系;
  • 一个点的所有儿子不能同时被选。

考虑树形 dp,设 \(f_{u, i, 0/1}\) 表示 \(u\) 子树内进行了 \(i\) 次操作,目前儿子有或没有同时被选,且 \(u\) 不被选的方案数。添加儿子 \(v\) 时把 \(f_{v, j, 1}\) 暴力树形背包合并进来即可。注意因为 \(f_{v, j, 1}\) 规定了 \(v\) 不被选,所以还要考虑选 \(v\) 的情况。

时间复杂度 \(O(m \log n + nk)\)

code
// Problem: Ex - Painting Weighted Graph
// Contest: AtCoder - HHKB Programming Contest 2022(AtCoder Beginner Contest 235)
// URL: https://atcoder.jp/contests/abc235/tasks/abc235_h
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const int maxm = 510;
const int mod = 998244353;

int n, m, K, sz[maxn], f[maxn][maxm][2], g[maxm][2], deg[maxn];
vector<int> G[maxn], P[maxn];

struct node {
	int u, v, d;
} E[maxn];

inline void upd(int &x, int y) {
	x += y;
	(x >= mod) && (x -= mod);
}

void dfs(int u) {
	f[u][0][0] = sz[u] = 1;
	for (int v : G[u]) {
		dfs(v);
		for (int i = 0; i <= K; ++i) {
			g[i][0] = f[u][i][0];
			g[i][1] = f[u][i][1];
			f[u][i][0] = f[u][i][1] = 0;
		}
		for (int i = min(sz[u], K); ~i; --i) {
			for (int j = min(sz[v], K - i); ~j; --j) {
				if (j == 1) {
					upd(f[u][i + j][0], g[i][0]);
					upd(f[u][i + j][1], g[i][1]);
				}
				upd(f[u][i + j][1], 1LL * (g[i][0] + g[i][1]) * f[v][j][1] % mod);
			}
		}
		sz[u] += sz[v];
	}
	if (G[u].empty()) {
		upd(f[u][0][1], 1);
	}
}

struct DSU {
	int fa[maxn];
	
	int find(int x) {
		return fa[x] == x ? x : fa[x] = find(fa[x]);
	}
	
	inline void merge(int x, int y) {
		x = find(x);
		y = find(y);
		if (x != y) {
			fa[x] = y;
		}
	}
} d1, d2;

void solve() {
	scanf("%d%d%d", &n, &m, &K);
	for (int i = 1; i <= n * 2; ++i) {
		d1.fa[i] = i;
	}
	for (int i = 1; i <= m; ++i) {
		scanf("%d%d%d", &E[i].u, &E[i].v, &E[i].d);
	}
	sort(E + 1, E + m + 1, [&](node a, node b) {
		return a.d < b.d;
	});
	int ntot = n;
	for (int i = 1, j = 1; i <= m; i = (++j)) {
		while (j < m && E[j + 1].d == E[i].d) {
			++j;
		}
		set<int> st;
		for (int k = i; k <= j; ++k) {
			int x = d1.find(E[k].u), y = d1.find(E[k].v);
			d2.fa[x] = x;
			d2.fa[y] = y;
			st.insert(x);
			st.insert(y);
		}
		for (int k = i; k <= j; ++k) {
			int x = d1.find(E[k].u), y = d1.find(E[k].v);
			d2.merge(x, y);
		}
		for (int x : st) {
			P[d2.find(x)].pb(x);
		}
		for (int x : st) {
			if ((int)P[x].size() > 1) {
				int k = ++ntot;
				for (int u : P[x]) {
					d1.fa[u] = k;
					++deg[u];
					G[k].pb(u);
				}
			}
			vector<int>().swap(P[x]);
		}
	}
	for (int i = 1; i <= ntot; ++i) {
		if (!deg[i]) {
			G[0].pb(i);
		}
	}
	dfs(0);
	int ans = 0;
	for (int i = 0; i <= K; ++i) {
		upd(ans, f[0][i][0]);
		upd(ans, f[0][i][1]);
	}
	printf("%d\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}