D. Small GCD

发布时间 2023-11-27 18:32:50作者: onlyblues

D. Small GCD

Let $a$, $b$, and $c$ be integers. We define function $f(a, b, c)$ as follows:

Order the numbers $a$, $b$, $c$ in such a way that $a \le b \le c$. Then return $\gcd(a, b)$, where $\gcd(a, b)$ denotes the greatest common divisor (GCD) of integers $a$ and $b$.

So basically, we take the $\gcd$ of the $2$ smaller values and ignore the biggest one.

You are given an array $a$ of $n$ elements. Compute the sum of $f(a_i, a_j, a_k)$ for each $i$, $j$, $k$, such that $1 \le i < j < k \le n$.

More formally, compute $$\sum_{i = 1}^n \sum_{j = i+1}^n \sum_{k =j +1}^n f(a_i, a_j, a_k).$$

Input

Each test contains multiple test cases. The first line contains the number of test cases $t$ ($1 \le t \le 10$). The description of the test cases follows.

The first line of each test case contains a single integer $n$ ($3 \le n \le 8 \cdot 10^4$) — length of the array $a$.

The second line of each test case contains $n$ integers, $a_1, a_2, \ldots, a_n$ ($1 \le a_i \le 10^5$) — elements of the array $a$.

It is guaranteed that the sum of $n$ over all test cases does not exceed $8 \cdot 10^4$.

Output

For each test case, output a single number — the sum from the problem statement.

Example

input

2
5
2 3 6 12 17
8
6 12 8 10 15 12 18 16

output

24
203

Note

In the first test case, the values of $f$ are as follows:

$i=1$, $j=2$, $k=3$, $f(a_i,a_j,a_k)=f(2,3,6)=\gcd(2,3)=1$;

$i=1$, $j=2$, $k=4$, $f(a_i,a_j,a_k)=f(2,3,12)=\gcd(2,3)=1$;

$i=1$, $j=2$, $k=5$, $f(a_i,a_j,a_k)=f(2,3,17)=\gcd(2,3)=1$;

$i=1$, $j=3$, $k=4$, $f(a_i,a_j,a_k)=f(2,6,12)=\gcd(2,6)=2$;

$i=1$, $j=3$, $k=5$, $f(a_i,a_j,a_k)=f(2,6,17)=\gcd(2,6)=2$;

$i=1$, $j=4$, $k=5$, $f(a_i,a_j,a_k)=f(2,12,17)=\gcd(2,12)=2$;

$i=2$, $j=3$, $k=4$, $f(a_i,a_j,a_k)=f(3,6,12)=\gcd(3,6)=3$;

$i=2$, $j=3$, $k=5$, $f(a_i,a_j,a_k)=f(3,6,17)=\gcd(3,6)=3$;

$i=2$, $j=4$, $k=5$, $f(a_i,a_j,a_k)=f(3,12,17)=\gcd(3,12)=3$;

$i=3$, $j=4$, $k=5$, $f(a_i,a_j,a_k)=f(6,12,17)=\gcd(6,12)=6$.

The sum over all triples is $1+1+1+2+2+2+3+3+3+6=24$.
In the second test case, there are $56$ ways to choose values of $i$, $j$, $k$. The sum over all $f(a_i,a_j,a_k)$ is $203$.

 

解题思路

  首先可以对 $a$ 从小到大排序,不会影响结果。这是因为对原本 $a$ 中所有的三元组逐个排序,与排序后 $a$ 的所有三元组是完全一样的。然后是因为统计答案,所以尝试对所有的三元组按照某个属性分类再分别统计。这里给出两种不同的分类方法。

  先给出我的做法,按照三元组中间的值进行分类,那么一共被分成 $n-2$ 类,即中间值分别为 $a_2 \sim a_{n-1}$ 的三元组。暴力做法是枚举 $a_i$ 作为中间值,那么所有比 $a_i$ 小(或相等)的值就是 $a_j, \, j \in [1,i-1]$,所有比 $a_i$ 大(或相等)的值有 $n-i$ 个,枚举所有的 $a_j$ 分别计算 $\gcd(a_i, a_j) \times (n-i)$,表示所有最小值为 $a_j$,中间值为 $a_i$,最大值比 $a_i$ 大(或相等)的三元组的贡献。那么 $\sum\limits_{j=1}^{i-1}{\gcd(a_i, a_j)} \times (n-i)$ 就是所有中间值为 $a_i$ 的三元组的贡献,最终答案就是 $\sum\limits_{i=2}^{n-1}{\sum\limits_{j=1}^{i-1}{\gcd(a_i, a_j)} \times (n-i)}$,很显然时间复杂度为 $O(n^2 \log{m})$ 会超时,其中 $m = \max\limits_{1 \leq i \leq n}\{ a_i \}$。

  当确定了 $a_i$ 作为中间值后,那么 $a_i$ 与任意数的最大公约数只可能为 $a_i$ 的约数。所以可以反过来考虑,从大到小枚举 $a_i$ 的所有约数 $d$,如果发现有之前没选过的 $a_j$ 满足 $d \mid a_j$,那么对于这些 $a_j$ 有 $\gcd(a_i, a_j) = d$,选择这些 $a_j$。由于在不超过 $10^5$ 的数中一个数的约数个数最多有 $128$ 个,因此这种做法应该是可行的。关键是在于如果维护 $a_i$ 的每个约数对应哪些 $a_j$。

  我们先通过 $O(m \log{m})$ 的时间复杂度预处理出来 $1 \sim m$ 的每个数的约数。维护数组 $\text{cnt}_x$,表示在枚举到 $a_i$ 时 $a_1 \sim a_{i-1}$ 中含约数 $x$ 的数的个数,当枚举完 $a_i$ 后遍历 $a_i$ 每个约数 $d$,有 $\text{cnt}_d \gets \text{cnt}_d + 1$。令 $\text{cnt}'$ 表示 $\text{cnt}$ 的备份,对于 $a_i$,从大到小枚举其约数 $d$,如果 $\text{cnt}'_d > 0$ 说明前面存在没选过且约数为 $d$ 的 $a_j$,有 $\text{cnt}'_d$ 个,那么贡献就是 $d \times \text{cnt}'_d \times (n-i)$。将这些 $a_j$ 标记为选过等价于将这些 $a_j$ 的记录从 $\text{cnt}'$ 抹去,做法是枚举每个 $a_j$ 的约数 $x$,然后 $\text{cnt}'_x \gets \text{cnt}'_x - 1$,然而我们只关心 $a_i$ 的约数,所以在 $a_j$ 的所有约数中只用删去同时是 $a_i$ 的约数即可。而既是 $a_j$ 的约数,又是 $a_i$ 的约数,其实就是 $a_i$ 和 $a_j$ 的所有公约数,即 $\gcd(a_i, a_j) = d$ 的所有约数。所以我们只需枚举 $d$ 的所有约数 $x$,然后 $\text{cnt}'_x \gets \text{cnt}'_x - \text{cnt}'_d$。

  $1 \sim m$ 中平均每个数含有 $O(\log{m})$ 个约数,所以估计时间复杂度为 $O(m \log{m} + n (\log{n} + \log^2{m}))$。这时间复杂度肯定是不准确的,但直觉上感觉不高毕竟一个数含有的约数数量还是很少的,麻烦知道的大佬在评论区留言 qwq。

  AC 代码如下:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int a[N];
vector<int> ds[N];
int cnt[N], bp[N];

void init() {
    for (int i = 1; i < N; i++) {
        for (int j = i; j < N; j += i) {
            ds[j].push_back(i);
        }
    }
    for (int i = 1; i < N; i++) {
        reverse(ds[i].begin(), ds[i].end());
    }
}

void solve() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    sort(a + 1, a + n + 1);
    LL ret = 0;
    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i < n; i++) {
        for (auto &x : ds[a[i]]) {
            bp[x] = cnt[x];
        }
        for (auto &x : ds[a[i]]) {
            int t = bp[x];
            if (t) {
                ret += 1ll * x * t * (n - i);
                for (auto &y : ds[x]) {
                    bp[y] -= t;
                }
            }
        }
        for (auto &x : ds[a[i]]) {
            cnt[x]++;
        }
    }
    printf("%lld\n", ret);
}

int main() {
    init();
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

  另外一种是官方题解给出的做法,按照三元组中最小的两个数的最大公约数进行分类,那么一共被分成 $m$ 类,$m = \max\limits_{1 \leq i \leq n}\{ a_i \}$。大致思路是求出最大公约数恰好为 $i$ 的三元组数量 $f(i)$,那么最终答案就是 $\sum\limits_{i=1}^{m}{i \times f(i)}$。

  关键在于如何求 $f(i)$,做法与 D. Counting Rhyme 类似,这题求的是所有最大公约数恰好为 $i$ 的二元组的数量。

  首先如果三元组的最大公约数要为 $d$,那么三元组中最小值和中间值必然是 $d$ 的倍数,我们将所有满足 $d \mid a_i$ 的下标 $i$ 筛出来,假设有 $k$ 个这样的下标。那么从这 $k$ 个下标中任意选择两个出来,这两个下标对应的数的最大公约数必然是 $d$ 的倍数。先求最大公约数是 $d$ 的倍数的三元组的数量。从小到大枚举这些下标,假设第 $i$ 个下标是 $p_i$,将 $a_{p_i}$ 作为中间值,前 $i-1$ 个下标的数作为最小值,最大值可选的数有 $n - p_i$ 个,则最大公约数是 $d$ 的倍数且中间值为 $a_{p_i}$ 的三元组的数量就是 $(i-1) \times (n - p_i)$,因此最大公约数是 $d$ 的倍数的三元组的数量就是 $s = \sum\limits_{i=2}^{k}{(i-1) \times (n - p_i)}$。

  可以知道最大公约数是 $d$ 的倍数的三元组数量本质是由最大公约数恰好为 $d$ 的三元组的数量、最大公约数恰好为 $2d$ 的三元组的数量、......构成的,即 $s = \sum\limits_{i=1}^{\left\lfloor \frac{m}{d} \right\rfloor }{f(i \cdot d)}$,从而推出 $f(d) = s - \sum\limits_{i=2}^{\left\lfloor \frac{m}{d} \right\rfloor }{f(i \cdot d)}$。对此我们可以从大到小倒着枚举 $d$ 来求 $f(d)$。

  AC 代码如下,时间复杂度为 $O(n \log n + m \log m)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int a[N];
vector<int> p[N];
LL f[N];

void solve() {
    int n, m = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
        m = max(m, a[i]);
    }
    sort(a + 1, a + n + 1);
    for (int i = 1; i <= m; i++) {
        p[i].clear();
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j * j <= a[i]; j++) {
            if (a[i] % j == 0) {
                p[j].push_back(i);
                if (a[i] / j != j) p[a[i] / j].push_back(i);
            }
        }
    }
    LL ret = 0;
    for (int i = m; i; i--) {
        f[i] = 0;
        for (int j = 1; j < p[i].size(); j++) {
            f[i] += 1ll * j * (n - p[i][j]);
        }
        for (int j = i + i; j <= m; j += i) {
            f[i] -= f[j];
        }
        ret += i * f[i];
    }
    printf("%lld\n", ret);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

 

参考资料

  Codeforces Round 911 (Div. 2) Editorial:https://codeforces.com/blog/entry/122677