「解题报告」CF500G New Year Running

发布时间 2023-06-28 15:41:03作者: APJifengc

来个垃圾做法。

首先这个树没啥用,我们只需要找到两条路径的交与方向即可。分类讨论即可得到,略掉了。

把距离找出来之后,我们可以将运动路径写成序列,那么我们现在相当于有两个分成五段的序列,其中有两端为 \(1 \to d\)\(d \to 1\)。我们枚举几种情况,分别是 \(1 \to d, 1 \to d\)\(1 \to d, d \to 1\)\(d \to 1, 1 \to d\)\(d \to 1, d \to 1\) 几种情况。

方向相同的时候,容易发现第一次相遇的位置一定是在两个 \(1 / d\) 的位置相遇。我们直接设两个序列走了 \(p, q\) 次,那么相当于要解一个 \(pm_1 + a_1 = qm_2 + a_2\) 的方程,这个可以直接 exgcd 解出来。

方向不同的时候,位置不能确定了,但是我们一定能确定两个位置的和,所以我们考虑枚举第一个序列的位置 \(x\),设两个位置的和为 \(a\),那么我们就是需要解 \(pm_1 + x = qm_2 + a - x\),即 \(pm_1 = qm_2 + a - 2x\)。我们可以两边模 \(m_2\),这样我们只需要满足 \(qm_2 + a \equiv 2x \pmod {m_1}\) 的最小的正整数 \(q\) 即可。此时 \(x \in [l, r]\)

把问题改成 \(kx+b \equiv 2x \pmod p\)。后面的 \(2x\) 很麻烦,我们先考虑把这个除掉。但是 \(\gcd(p, 2)\) 不一定等于 \(1\),所以不一定存在逆元。那么我们分情况讨论:

  1. \(p\) 为奇数:显然有逆元,直接乘即可。
  2. \(p\) 为偶数:我们讨论 \(k, b\) 的奇偶性。如果 \(k\) 为偶数,那么如果 \(b\) 为奇数,一定无解,否则可以直接全部除以 \(2\)。如果 \(k\) 为奇数,那么我们一定能确定 \(x\) 的奇偶性,所以直接令 \(x = 2k/2k+1\),然后再全部除以 \(2\) 即可。

这样,我们把问题转换成了 \(kx+b \equiv x \pmod {p}\) 了,这相当于在求 \(kx \bmod p \in [l, r]\) 的最小 \(x\)。这个是类欧的经典问题,可以看 ARC127F ±AB

细节特别多。不保证代码不存在细节问题。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
int n, q;
vector<int> e[MAXN];
int siz[MAXN], dep[MAXN], fa[MAXN], son[MAXN], top[MAXN], dfn[MAXN], idf[MAXN], dcnt;
void dfs1(int u, int pre) {
    fa[u] = pre, dep[u] = dep[pre] + 1, siz[u] = 1;
    for (int v : e[u]) if (v != pre) {
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int pre, int t) {
    top[u] = t, dfn[u] = ++dcnt, idf[dcnt] = u;
    if (son[u]) dfs2(son[u], u, t);
    for (int v : e[u]) if (v != pre && v != son[u]) {
        dfs2(v, u, v);
    }
}
int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
        else v = fa[top[v]];
    }
    return dep[u] > dep[v] ? v : u;
}
int dis(int u, int v) { return dep[u] + dep[v] - 2 * dep[lca(u, v)]; }
void exgcd(long long a, long long b, long long &x, long long &y) {
    if (!b) x = 1, y = 0;
    else {
        exgcd(b, a % b, x, y);
        int t = x; x = y; y = t - (a / b) * y;
    }
}
long long solve1(long long a, long long b, long long c) {
    long long g = __gcd(a, b);
    if (c % g != 0) return -1;
    a /= g, b /= g, c /= g;
    long long x, y; exgcd(a, b, x, y);
    x *= c, y *= c;
    if (x < 0 || y > 0) {
        long long t = max((-x) / b, y / a);
        x += t * b, y -= t * a;
        if (x < 0 || y > 0) x += b, y -= a;
    }
    if (x >= b || y <= -a) {
        long long t = min(x / b, (-y) / a);
        x -= t * b, y += t * a;
    }
    return x;
}
// find the minimum k that ak mod b \in [l, r]
long long solve(long long a, long long b, long long l, long long r) {
    a %= b;
    if (a == 0) {
        if (l != 0) return -1;
        return 0;
    }
    if (l == 0 || (l - 1) / a != r / a) return (l + a - 1) / a;
    long long t = solve(b, a % b, (a - r % a) % a, (a - l % a) % a);
    if (t == -1) return -1;
    return (1ll * t * b + l + a - 1) / a;
}
pair<long long, long long> solve3(long long k, long long b, long long l, long long r, long long p) {
    if (l > r) return { -1, -1 };
    l -= b, r -= b;
    if (r < 0) l += p, r += p;
    if (l < 0) {
        long long t1 = solve(k, p, l + p, p - 1);
        long long t2 = solve(k, p, 0, r);
        long long ans;
        if (t1 == -1) ans = t2;
        else if (t2 == -1) ans = t1;
        else ans = min(t1, t2);
        if (ans != -1) return { ans, (k * ans + b) % p };
        else return { -1, -1 };
    }
    long long ans = solve(k, p, l, r);
    if (ans == -1) return { -1, -1 };
    return { ans, (k * ans + b) % p };
}
pair<long long, long long> solve2(long long k, long long b, long long l, long long r, long long p) {
    k %= p, b %= p;
    long long ansk = 1, ansb = 0;
    if (p % 2 == 0) {
        if (k % 2 == 0) {
            if (b % 2 == 1) return { -1, -1 };
            k /= 2, b /= 2, p /= 2;
        } else {
            if (b % 2 == 0) {
                ansk = 2, b /= 2, p /= 2;
            } else {
                ansk = 2, ansb = 1, b = (b + k) / 2, p /= 2;
            }
        }
    } else {
        k = (k * (p + 1) / 2) % p, b = (b * (p + 1) / 2) % p;
    }
    long long w = l / p;
    l -= w * p, r -= w * p;
    auto t = solve3(k, b, l, r, p);
    if (t.first != -1) return { ansk * t.first + ansb, t.second + w * p };
    return { -1, -1 };
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        e[u].push_back(v), e[v].push_back(u);
    }
    dfs1(1, 0), dfs2(1, 0, 1);
    scanf("%d", &q);
    while (q--) {
        int u, v, x, y; scanf("%d%d%d%d", &u, &v, &x, &y);
        int a1, b1, c1, a2, b2, c2, m1, m2, d;
        int p, q, r;
        vector<int> a = { lca(u, x), lca(u, y), lca(v, x), lca(v, y) };
        sort(a.begin(), a.end(), [&](int a, int b) { return dep[a] > dep[b]; });
        p = a[0], q = a[1];
        if (p == q && dep[p] < max(dep[lca(u, v)], dep[lca(x, y)])) {
            printf("-1\n");
            continue;
        }
        if (dis(u, p) > dis(u, q)) swap(p, q);
        long long ans = LLONG_MAX;
        if (dis(x, p) <= dis(x, q)) {
            d = dis(p, q) + 1;
            a1 = dis(u, p), b1 = 2 * dis(v, q) - 1, c1 = a1 - 1, m1 = a1 + b1 + c1 + 2 * d;
            a2 = dis(x, p), b2 = 2 * dis(y, q) - 1, c2 = a2 - 1, m2 = a2 + b2 + c2 + 2 * d;
            long long t, a;
            pair<long long, long long> p;
            // case 1:
            t = solve1(m1, m2, a2 - a1);
            if (t != -1) ans = min(ans, t * m1 + a1);
            // case 2:
            t = solve1(m1, m2, (a2 + d + b2) - (a1 + d + b1));
            if (t != -1) ans = min(ans, t * m1 + a1 + d + b1);
            // case 3:
            a = a1 + a2 + d + b2 + d - 1;
            p = solve2(m2, a, a1, a1 + d - 1, m1);
            if (p.first != -1) ans = min(ans, p.first * m2 + a - p.second);
            // case 4:
            a = a2 + a1 + d + b1 + d - 1;
            p = solve2(m2, a, a1 + d + b1, a1 + d + b1 + d - 1, m1);
            if (p.first != -1) ans = min(ans, p.first * m2 + a - p.second);
        } else {
            d = dis(p, q) + 1;
            a1 = dis(u, p), b1 = 2 * dis(v, q) - 1, c1 = a1 - 1, m1 = a1 + b1 + c1 + 2 * d;
            a2 = dis(x, q), b2 = 2 * dis(y, p) - 1, c2 = a2 - 1, m2 = a2 + b2 + c2 + 2 * d;
            long long t, a;
            pair<long long, long long> p;
            // case 1:
            t = solve1(m1, m2, (a2 + d + b2) - a1);
            if (t != -1) ans = min(ans, t * m1 + a1);
            // case 2:
            t = solve1(m1, m2, a2 - (a1 + d + b1));
            if (t != -1) ans = min(ans, t * m1 + a1 + d + b1);
            // case 3:
            a = a1 + a2 + d - 1;
            p = solve2(m2, a, a1, a1 + d - 1, m1);
            if (p.first != -1) ans = min(ans, p.first * m2 + a - p.second);
            // case 4:
            a = a1 + d + b1 + a2 + d + b2 + d - 1;
            p = solve2(m2, a, a1 + d + b1, a1 + d + b1 + d - 1, m1);
            if (p.first != -1) ans = min(ans, p.first * m2 + a - p.second);
        }
        if (ans == LLONG_MAX) {
            printf("-1\n");
        } else {
            printf("%lld\n", ans);
        }
    }
    return 0;
}