题意:给定一颗树以及一个带权路径集合 \(U\),定义 \(W(S)\) 为 \(S\) 路径两两不交的权值最大的子集的权值和。定义 \(f(i,j)\) 为最小的 \(w\) 满足 \(W(U\cup \{(i,j,w+1)\})>W(U)\),求 \(\sum_{i=1}^n \sum_{j=1}^n f(i,j)\)。
其实可做题。但是实力菜。
首先考虑如何计算 \(W(U)\),我们对于树上问题依然考虑 DP,设 \(f_u\) 表示只有 \(u\) 子树内的点是可经过的,权最大的路径不交子集。转移我们只需要考虑 \(u\) 这个点经没经过,没经过那么最大的子集是 \(\sum_{v\in \text{son}(u)} f_v\),我们记这个值为 \(g_u\)。如果有新的路径经过了 \(u\),由于 \(u\) 是可经过的最高的点,那么它肯定是路径的 LCA,所以 tarjan 一遍求出 LCA 然后把路径挂上去。
现在相当于是说钦定了一条路径不能走,求最大不交子集。我们发现一颗子树差上一条路径依然是若干颗独立的子树,那么钦定一条路径不能走相当于求这些子树的 \(f\) 值之和。我们发现一条路径 \(P\) 的形成的子树的 \(f\) 值之和可以用 \(g_{\text{LCA}_P}+\sum_{v\in P/\{\text{LCA}_P\}} g_v-f_v\) 表示。而 \(\sum_{v\in P/\{\text{LCA}_P\} g_v-f_v\) 可以用边带权并查集轻松维护。
考虑题目中让我们计算的 \(f(i,j)\) 究竟是什么。相当于我们钦定 \((i,j)\) 必须不能经过,问答案减少了多少,我们考虑如果能求出一个子树的补集的 DP 值 \(h_u\),那么这个东西相当于 \(f(P)=f_{root}-h_{\text{LCA}_P}-g_{\text{LCA}_P}-\sum_{v\in P/\{\text{LCA}_P\}} g_v-f_v\)。
发现如果求出 \(h\),贡献的计算将是平凡的。所以我们考虑推导 \(h\)。\(h\) 显然应该从上往下 \(DP\) 得来,如果 \(h_u\) 已知,依然是考虑 \(u\) 经没经过转移到点 \(v\),如果没经过转移是 \(h_u+g_u-f_v\),如果经过了就把经过了 \(u\) 但不经过 \(v\) 的路径的权值贡献到 \(v\) 的 \(h\) 值上去。然而经过 \(u\) 不经过 \(v\) 的总路径个数之和并不是 \(O(n)\) 的,所以我们需要每计算出一条路径的权值就利用树剖动态修改贡献到这条路径的邻域上去。
时间复杂度 \(O(n\log^2 n)\),显然可以优化到 \(O(n\log n)\) 但没必要。
#include <bits/stdc++.h>
using namespace std;
const int P=998244353;
namespace myheader{
#ifndef ONLINE_JUDGE
#define ENABLE_DEBUG
#define ENABLE_FILEREDIR
#endif
#define ENABLE_FREAD
#ifdef ENABLE_FREAD
char buf[1<<20],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin))?EOF:*p1++)
#endif
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<int,int> pii;
#ifdef ENABLE_FILEREDIR
struct fredir{
fredir(const string str){
freopen((str+(string)".in").c_str(),"r",stdin);
freopen((str+(string)".out").c_str(),"w",stdout);
}
}redir("emotion");
#endif
#define fi first
#define se second
template<typename T>
ostream& operator<<(ostream& os,vector<T> cur){
if(cur.empty()) return os<<"{}";
os<<'{'<<cur.front();
for(size_t i=1;i<cur.size();++i) os<<','<<cur[i];
return os<<'}';
}
template<typename T>
void chmn(T& x,T v){if(x>v) x=v;}
template<typename T>
void chmx(T& x,T v){if(x<v) x=v;}
template<typename T>
void _debug(const char *str,T x){
#ifdef ENABLE_DEBUG
cerr<<str<<'='<<x<<endl;
#endif
}
template<typename T,typename... Ts>
void _debug(const char *str,T x,Ts... y){
#ifdef ENABLE_DEBUG
while(*str!=',') cerr<<*str++;
cerr<<'='<<x<<',';
_debug(str+1,y...);
#endif
}
#define debug(...) _debug(#__VA_ARGS__,__VA_ARGS__)
template<typename T=int>
T read(){
char c=getchar();bool f=0;T x=0;
while(c<48||c>57) f|=(c=='-'),c=getchar();
do x=(x<<1)+(x<<3)+(c^48),c=getchar();
while(c>=48&&c<=57);
if(f) return -x;
return x;
}
template<typename T>
void read(T& x){
char c=getchar();bool f=0;x=0;
while(c<48||c>57) f|=(c=='-'),c=getchar();
do x=(x<<1)+(x<<3)+(c^48),c=getchar();
while(c>=48&&c<=57);
if(f) x=-x;
}
template<typename T,typename... Ts>
void read(T& x,Ts&... y){read(x);read(y...);}
}
using namespace myheader;
const int N=300003;
const ll INF=0x3f3f3f3f3f3f3f3f;
int n,m;
int hd[N],ver[N<<1],nxt[N<<1],tot=1;
vector<int> vec[N],path[N];
void add(int u,int v){
nxt[++tot]=hd[u];hd[u]=tot;ver[tot]=v;
}
int pu[N],pv[N],pw[N],lca[N];
ll f[N],g[N],h[N];
__int128 pr[N];
int pre[N];
int ft[N];
int rt(int x){
if(ft[x]==x) return x;
return ft[x]=rt(ft[x]);
}
bool vis[N];
void calc(int u,int fa){
for(int x:vec[u]){
int v=pu[x]^pv[x]^u;
if(vis[x]) lca[x]=rt(v);
else vis[x]=1;
}
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
calc(v,u);
ft[v]=u;
}
}
int lnk[N];ll dis[N];
void jump(int x){
if(lnk[x]==x) return;
jump(lnk[x]);
dis[x]+=dis[lnk[x]];
lnk[x]=lnk[lnk[x]];
}
void getf(int u,int fa){
g[u]=0;
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
getf(v,u);g[u]+=f[v];
lnk[v]=u;dis[v]=g[v]-f[v];
}
f[u]=g[u];
for(int x:path[u]){
jump(pu[x]);jump(pv[x]);
chmx(f[u],g[u]+dis[pu[x]]+dis[pv[x]]+pw[x]);
}
}
int dfn[N],sn[N],sz[N],de[N],ff[N],top[N],od[N],num;
void dfs(int u,int fa){
sz[u]=1;ff[u]=fa;
pre[u]=(pre[fa]+g[u]-f[u])%P;
pr[u]=pr[fa]+g[u]-f[u];
if(pre[u]<0) pre[u]+=P;
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
de[v]=de[u]+1;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[sn[u]]) sn[u]=v;
}
}
void split(int u,int fa){
od[dfn[u]=++num]=u;top[u]=fa;
if(sn[u]) split(sn[u],fa);
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==ff[u]||v==sn[u]) continue;
split(v,v);
}
}
#define lc (p<<1)
#define rc (p<<1|1)
ll mx[N<<2];
void update(int sl,int sr,ll val,int p=1,int l=1,int r=n){
if(sl>sr) return;
if(sl<=l&&r<=sr) return chmx(mx[p],val);
int mid=(l+r)>>1;
if(sl<=mid) update(sl,sr,val,lc,l,mid);
if(sr>mid) update(sl,sr,val,rc,mid+1,r);
}
ll query(int x,int p=1,int l=1,int r=n){
if(l==r) return mx[p];
int mid=(l+r)>>1;
if(x<=mid) return max(query(x,lc,l,mid),mx[p]);
else return max(query(x,rc,mid+1,r),mx[p]);
}
void build(int p=1,int l=1,int r=n){
mx[p]=-INF;
if(l==r) return;
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
}
#undef lc
#undef rc
int upd(int x,int y,ll val){
while(top[x]!=top[y]){
if(de[top[x]]>de[y]+1) update(dfn[top[x]],dfn[x],val);
else{update(dfn[top[x]]+1,dfn[x],val);return top[x];}
x=ff[top[x]];
}
update(dfn[y]+2,dfn[x],val);
return od[dfn[y]+1];
}
ll qw[N],qry[N];
multiset<ll> st[N],cur;
void geth(int u,int fa){
for(int x:path[u]){
qw[x]=g[u]+pr[pu[x]]+pr[pv[x]]-pr[u]-pr[u]+h[u]+pw[x];
if(pu[x]!=lca[x]) st[upd(pu[x],u,qw[x])].emplace(qw[x]);
if(pv[x]!=lca[x]) st[upd(pv[x],u,qw[x])].emplace(qw[x]);
cur.emplace(qw[x]);
}
ll mx=h[u]+g[u];
for(int x:vec[u]){
int v=pu[x]^pv[x]^u;
if(dfn[v]>dfn[u]&&dfn[v]<dfn[u]+sz[u]) continue;
chmx(mx,qw[x]);
}
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
qry[v]=query(dfn[v]);
if(qry[v]!=-INF) cur.emplace(qry[v]);
}
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
if(qry[v]!=-INF) cur.erase(cur.find(qry[v]));
for(ll t:st[v]) cur.erase(cur.find(t));
if(cur.empty()) h[v]=mx-f[v];
else h[v]=max(mx,*prev(cur.end()))-f[v];
if(qry[v]!=-INF) cur.emplace(qry[v]);
for(ll t:st[v]) cur.emplace(t);
st[v].clear();
}
cur.clear();
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
geth(v,u);
}
}
int res;
void statis(int u,int fa){
int val=(g[u]+h[u]-pre[u]-pre[u])%P;
if(val<0) val+=P;
res-=(ll)sz[u]*val%P;
if(res<0) res+=P;
for(int i=hd[u];i;i=nxt[i]){
int v=ver[i];
if(v==fa) continue;
res-=(ll)sz[v]*(sz[u]-sz[v])%P*val%P;
if(res<0) res+=P;
statis(v,u);
}
res-=(ll)pre[u]*(n<<1)%P;
if(res>=P) res-=P;
if(res<0) res+=P;
}
int main(){
read(n,m);
for(int i=1;i<n;++i){
int u=read(),v=read();
add(u,v);add(v,u);
}
for(int i=1;i<=m;++i){
read(pu[i],pv[i],pw[i]);
vec[pu[i]].emplace_back(i);
vec[pv[i]].emplace_back(i);
}
for(int i=1;i<=n;++i) ft[i]=lnk[i]=i;
calc(1,0);
for(int i=1;i<=m;++i) path[lca[i]].emplace_back(i);
getf(1,0);dfs(1,0);split(1,0);build();geth(1,0);
res=(ll)f[1]%P*n%P*n%P;
statis(1,0);
printf("%d\n",res%P);
return 0;
}