树上前缀和

发布时间 2023-08-12 21:12:30作者: ljfyyds

树上前缀和

模板传送门

#include <algorithm>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
inline int read()
{
    int x = 0, f = 1;
    char s = getchar();
    while (s < '0' || s > '9')
    {
        if (s == '-')
            f = -f;
        s = getchar();
    }
    while (s >= '0' && s <= '9')
    {
        x = (x << 3) + (x << 1) + (s ^ 48);
        s = getchar();
    }
    return x * f;
}
const int N = 3e5 + 10, mod = 998244353;

LL fa[N][22]; // fa[v][2]:v向上走2^2次方步的祖先
LL dep[N];    // dep[v]:点v的深度
LL mi[60];    // mi[j]:表示dep[v]的j次幂
LL s[N][60];  // s[v][j]表示从根节点到v节点路径的深度的j次幂之和
LL e[2 * N], ne[2 * N], h[2 * N], idx = 0;
int n, m;
void add(int a, int b) //加边函数
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u, int f)
{
    for (int i = 1; i <= 20; i++)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    for (int i = h[u]; ~i; i = ne[i])
    {
        int v = e[i];
        if (v == f)
            continue;
        // cout << u << ' ' << v << endl;
        dep[v] = dep[u] + 1;
        fa[v][0] = u;
        for (int j = 1; j <= 50; j++)
            mi[j] = mi[j - 1] * dep[v] % mod;
        for (int j = 1; j <= 50; j++)
            s[v][j] = (s[u][j] + mi[j]) % mod;
        dfs(v, u);
    }
}

int lca(int u, int v) //最近公共祖先
{
    if (dep[u] < dep[v])
        swap(u, v);
    for (int i = 20; ~i; i--)
        if (dep[fa[u][i]] >= dep[v])
            u = fa[u][i];
    if (u == v)
        return u;
    for (int i = 20; ~i; i--)
        if (fa[u][i] != fa[v][i])
            u = fa[u][i], v = fa[v][i];
    // if (u == 1)
    // return 0;
    return fa[u][0];
}

int main()
{
    memset(h, -1, sizeof(h));
    n = read();
    for (int i = 1; i < n; i++)
    {
        int a = read(), b = read();
        add(a, b), add(b, a);
    }
    mi[0] = 1;
    dfs(1, 0); //预处理每个点可能的的信息
    cin >> m;
    while (m--)
    {
        int u = read(), v = read(), k = read();
        int l = lca(u, v);
        cout << (LL)((s[u][k] + s[v][k]) % mod - (s[l][k] + s[fa[l][0]][k]) % mod + mod) % mod << endl;
        // cout << (LL)(s[u][k] + s[v][k] - s[l][k] - s[fa[l][0]][k] + 2 * mod) % mod << endl;
    }
    return 0;
}