「解题报告」ARC123F Insert Addition

发布时间 2023-03-28 20:55:18作者: APJifengc

我啥都不会啊唔。

我们考虑不使用数来刻画这个东西,而是使用一个系数对来刻画这个东西,即 \((x, y)\) 表示 \(ax+by\)。那么我们相当于是初始有 \((1, 0), (0, 1)\),每次相邻的两个二元组对应位置相加,即 \((a, b), (a+c, b+d), (c, d)\)

发现这个过程与 Stern-Brocot 树的构建过程是一模一样的,那么我们其实就是要找 Stern-Brocot 树上满足 \(ax+by \le n\)\(\frac{y}{x}\) 中前 \([L, R]\) 大的数。

求前 \([L, R]\) 大的问题是比较经典的,我们可以先找出第 \(L\) 大的,然后再递推找出 \([L, R]\) 大的。放在 Stern-Brocot 树上,就是先找出第 \(L\) 大的节点位置,然后接着在树上 DFS 找出 \([L, R]\) 大的数。由于 Stern-Brocot 树高 \(O(n)\),那么我们最多只会经过 \(O(n)\) 个无用节点,其他访问到的节点必定是一个答案,所以这部分的复杂度为 \(O(R - L + n)\)

我们考虑如何找第 \(L\) 大的节点。我们可以在 Stern-Brocot 树上二分,那么我们只需要快速求得每个子树有多少节点即可。

那么我们需要解决的问题为:

求有多少对 \((x, y)\),满足:

  • \(\gcd(x, y) = 1\)
  • \(ax+by \le n\)

\(\gcd(x, y) = 1\) 的限制显然可以莫反搞掉,简单推式子可以得到答案为:

\[\sum_{d=1}^n\mu(d)\sum_{i=1}^{\lfloor\frac{n}{da}\rfloor}\lfloor\frac{\lfloor\frac nd\rfloor-ia}{b}\rfloor \]

数论分块套类欧可以做到 \(O(\sqrt n \log n)\) 计算。

然后我们就跑即可。注意 Stern-Brocot 树上跳相同方向的链时需要倍增跳,复杂度 \(O(n + \sqrt{n} \log^3 n)\)

另一种做法是实数二分,然后再找出分数表示。maspy 的代码中有一个奇妙的递推式子,但是递推式子属实没有看懂,咕了。

类欧实在不想写了,贺了一份 atcoder 的板子。

正解跑的果然比乱搞快,直接最优解了。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 300005;
int a, b, n;
long long l, r;
int mu[MAXN];
int pri[MAXN], vis[MAXN], pcnt;
long long f[MAXN];

namespace atcoder {
    namespace internal {
        // @param m `1 <= m`
        // @return x mod m
        constexpr long long safe_mod(long long x, long long m) {
            x %= m;
            if (x < 0) x += m;
            return x;
        }
        // @param n `n < 2^32`
        // @param m `1 <= m < 2^32`
        // @return sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)
        unsigned long long floor_sum_unsigned(unsigned long long n,
                                              unsigned long long m,
                                              unsigned long long a,
                                              unsigned long long b) {
            unsigned long long ans = 0;
            while (true) {
                if (a >= m) {
                    ans += n * (n - 1) / 2 * (a / m);
                    a %= m;
                }
                if (b >= m) {
                    ans += n * (b / m);
                    b %= m;
                }

                unsigned long long y_max = a * n + b;
                if (y_max < m) break;
                n = (unsigned long long)(y_max / m);
                b = (unsigned long long)(y_max % m);
                std::swap(m, a);
            }
            return ans;
        }
    }
    long long floor_sum(long long n, long long m, long long a, long long b) {
        assert(0 <= n && n < (1LL << 32));
        assert(1 <= m && m < (1LL << 32));
        unsigned long long ans = 0;
        if (a < 0) {
            unsigned long long a2 = internal::safe_mod(a, m);
            ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
            a = a2;
        }
        if (b < 0) {
            unsigned long long b2 = internal::safe_mod(b, m);
            ans -= 1ULL * n * ((b2 - b) / m);
            b = b2;
        }
        return ans + internal::floor_sum_unsigned(n, m, a, b);
    }
}
long long count(long long a, long long b) {
    long long ans = 0;
    for (int l = 1, r; l <= n; l = r + 1) {
        r = n / (n / l);
        if (n / l < a + b) break;
        ans += (mu[r] - mu[l - 1]) * atcoder::floor_sum(n / l / a, b, -a, n / l - a);
    }
    return ans;
}
vector<int> ans;
typedef pair<long long, long long> fraction;
fraction operator+(fraction x, fraction y) {
    return { x.first + y.first, x.second + y.second };
}
fraction operator*(long long x, fraction y) {
    return { x * y.first, x * y.second };
}
long long calc(fraction x) {
    return 1ll * a * x.first + 1ll * b * x.second;
}
void insert(fraction x) {
    ans.push_back(calc(x));
}
void solve2(fraction x, fraction y, long long l, long long r) {
    if (ans.size() >= r - l + 1) return;
    long long mid = calc(x + y);
    if (mid > n) return;
    solve2(x, x + y, l, r);
    if (ans.size() < r - l + 1) {
        insert(x + y);
    }
    solve2(x + y, y, l, r);
}
void solve1(fraction x, fraction y, long long l, long long r) {
    if (ans.size() >= r - l + 1) return;
    long long a = calc(x), b = calc(y), mid = calc(x + y);
    if (mid > n) return;
    long long cnt = count(calc(x), mid) + 1;
    if (l == cnt) {
        insert(x + y);
        solve2(x + y, y, l, r);
        return;
    } else if (l < cnt) {
        int dep = 0;
        for (int i = 20; i >= 0; i--) {
            if (l < count(a, calc((dep + (1 << i)) * x + y)) + 1) {
                dep += (1 << i);
            }
        }
        solve1(x, dep * x + y, l, r);
        for (int i = dep; i >= 1; i--) {
            if (ans.size() < r - l + 1) {
                insert(i * x + y);
                solve2(i * x + y, (i - 1) * x + y, l, r);
            } else break;
        }
    } else {
        int dep = 0;
        long long tot = count(a, b);
        for (int i = 20; i >= 0; i--) {
            if (tot - l + 1 < count(calc(x + (dep + (1 << i)) * y), b) + 1) {
                dep += (1 << i);
            }
        }
        cnt = tot - count(calc(x + dep * y), b);
        solve1(x + dep * y, y, l - cnt, r - cnt);
    }
}
int main() {
    scanf("%d%d%d%lld%lld", &a, &b, &n, &l, &r);
    mu[1] = 1;
    for (int i = 2; i <= n; i++) {
        if (!vis[i]) {
            pri[++pcnt] = i;
            mu[i] = -1;
        }
        for (int j = 1; j <= pcnt && i * pri[j] <= n; j++) {
            vis[i * pri[j]] = 1;
            if (i % pri[j] == 0) {
                break;
            } else {
                mu[i * pri[j]] = -mu[i];
            }
        }
        mu[i] += mu[i - 1];
    }
    long long cnt = count(a, b);
    if (l != cnt + 2 && r != 1) {
        solve1({ 1, 0 }, { 0, 1 }, max(l - 1, 1ll), min(r - 1, cnt));
    }
    if (l == 1) printf("%d ", a);
    for (int i : ans) {
        printf("%d ", i);
    }
    if (r == cnt + 2) printf("%d ", b);
    printf("\n");
    return 0;
}