G2. Magic Triples (Hard Version)

发布时间 2023-04-26 16:01:50作者: onlyblues

G2. Magic Triples (Hard Version)

This is the hard version of the problem. The only difference is that in this version, $a_i \le 10^9$.

For a given sequence of $n$ integers $a$, a triple $(i, j, k)$ is called magic if:

  • $1 \le i, j, k \le n$.
  • $i$, $j$, $k$ are pairwise distinct.
  • there exists a positive integer $b$ such that $a_i \cdot b = a_j$ and $a_j \cdot b = a_k$.

Kolya received a sequence of integers $a$ as a gift and now wants to count the number of magic triples for it. Help him with this task!

Note that there are no constraints on the order of integers $i$, $j$ and $k$.

Input

The first line contains a single integer $t$ ($1 \le t \le 10^4$) — the number of test cases. The description of the test cases follows.

The first line of the test case contains a single integer $n$ ($3 \le n \le 2 \cdot 10^5$) — the length of the sequence.

The second line of the test contains $n$ integers $a_1, a_2, a_3, \dots, a_n$ ($1 \le a_i \le 10^9$) — the elements of the sequence $a$.

The sum of $n$ over all test cases does not exceed $2 \cdot 10^5$.

Output

For each test case, output a single integer — the number of magic triples for the sequence $a$.

Example

input

7
5
1 7 7 2 7
3
6 2 18
9
1 2 3 4 5 6 7 8 9
4
1000 993 986 179
7
1 10 100 1000 10000 100000 1000000
8
1 1 2 2 4 4 8 8
9
1 1 1 2 2 2 4 4 4

output

6
1
3
0
9
16
45

Note

In the first example, there are $6$ magic triples for the sequence $a$ — $(2, 3, 5)$, $(2, 5, 3)$, $(3, 2, 5)$, $(3, 5, 2)$, $(5, 2, 3)$, $(5, 3, 2)$.

In the second example, there is a single magic triple for the sequence $a$ — $(2, 1, 3)$.

 

解题思路

  先给出G1. Magic Triples (Easy Version)的做法。

  暴力的做法还是很容易想到的,我们枚举三元组$(i,j,k)$中的$i$,然后$b$从$1$开始枚举,看一下数组中是否存在$a_i \cdot b$和$a_i \cdot b^2$。因此还需要先开个哈希表统计数组中每个元素出现的次数,记作$\text{cnt}[x]$,表示$x$在数组中出现了$\text{cnt}[x]$次。

  如果$b=1$,那么有$a_i = a_j = a_k$,因此如果$\text{cnt}[a_i] \geq 3$,那么根据乘法原理,值均为$a_i$的三元组数量就是$\text{cnt}[a_i] \times (\text{cnt}[a_i]-1) \times (\text{cnt}[a_i] - 2)$。因此可以根据元素的不同种类来分别计算答案。为了方便,对于$b=1$的情况,我们可以枚举每一个元素,然后求$\sum\limits_{i=1}^{n}{(\text{cnt}[a_i] - 1) \times (\text{cnt}[a_i] - 2)}$,得到的结果与前一种方法是一样的,注意到同一类元素的数量为$\text{cnt}[a_i] \times \left( {(\text{cnt}[a_i]-1) \cdot (\text{cnt}[a_i] - 2)} \right)$,就是把$\text{cnt}[a_i]$分解成若干个$1$累加而已。

  如果$b \geq 2$,我们要保证$a_i \cdot b^2 \leq M$,这里的$M = {10}^6$。即$b \le \sqrt{M /a_i} \le \sqrt{M}$,因此$b$最大枚举到$\sqrt{M}$。然后根据乘法原理,满足条件的三元组数量就是$\text{cnt}[a_i\cdot b] \times \text{cnt}[a_i\cdot b^2]$。

  然后比较坑的地方是由于是多组测试数据,因此每次都 memset 一个值域数组肯定会超时的,而用 std::unordered_map 肯定会被卡,用 std::map 时间复杂度就达到$O(\sqrt{M}\cdot n \log{n})$也有可能会超时,就很难办了。比赛的时候在这个地方卡了很久,最后还是交了 std::unordered_map 的做法,当然也不出意外的fst了。

  其实要用数组实现哈希表也非常的简单,一般情况下都是直接对整个数组清零,实际上在一组数据中有很多地方都没有用到,如果直接全部清零很明显浪费了很多的时间。做法是在跑完一组数据后,只对用过的地方清零就好了。在这题中做法就是令一组数据中所有的$\text{cnt}[a_i]=0$。

  AC代码如下,时间复杂度为$O(\sqrt{M}\cdot n)$:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 
 6 const int N = 2e5 + 10, M = 1e6 + 10;
 7 
 8 int n;
 9 int a[N];
10 int cnt[M];
11 
12 void reset() {
13     for (int i = 0; i < n; i++) {
14         cnt[a[i]] = 0;
15     }
16 }
17 
18 void solve() {
19     scanf("%d", &n);
20     for (int i = 0; i < n; i++) {
21         scanf("%d", a + i);
22         cnt[a[i]]++;
23     }
24     LL ret = 0;
25     for (int i = 0; i < n; i++) {
26         if (cnt[a[i]] >= 3) ret += (cnt[a[i]] - 1ll) * (cnt[a[i]] - 2);
27     }
28     for (int i = 0; i < n; i++) {
29         int t = 1000000 / a[i];
30         for (int j = 2; j <= t / j; j++) {
31             ret += 1ll * cnt[a[i] * j] * cnt[a[i] * j * j];
32         }
33     }
34     printf("%lld\n", ret);
35     reset();
36 }
37 
38 int main() {
39     int t;
40     scanf("%d", &t);
41     while (t--) {
42         solve();
43     }
44     
45     return 0;
46 }

  然后就是Hard版本了,$M$扩大到了${10}^9$,很明显上面的做法已经不适用了。上面的做法是枚举三元组中的$i$,这里的做法是枚举中间的元素$j$(想不到就真的做不出来了)。然后更妙的是还要把$a_j$分成两种情况,即$a_j \ge M ^ \frac{2}{3}$和$a_j < M ^ \frac{2}{3}$。

  首先对于对于$b=1$的情况做法与上面的一样,下面来讨论$b \geq 2$的情况。

  如果$a_j \ge M ^ \frac{2}{3}$,那么很明显对于$a_k = a_j \cdot b$,应该有$a_j \cdot b \leq M$,即$b \leq M / a_j \leq M^\frac{1}{3}$,因此$b$最大枚举到$M^\frac{1}{3}$。最后如果满足$a_j \bmod b = 0$,那么满足条件的三元组数量就是$\text{cnt}[a_j / b] \times \text{cnt}[a_j \cdot b]$。

  如果$a_j < M ^ \frac{2}{3}$,因为有$a_i \cdot b = a_j$,因此$b$必然是$a_j$的一个约数,意味着我们可以枚举出$a_j$的所有约数$d$,如果满足$a_j \cdot d \leq M$,那么满足条件的三元组数量就是$\text{cnt}[a_j / d] \times \text{cnt}[a_j \cdot d]$。其中分解约数的时间复杂度为$O(M^\frac{1}{3})$。

  因此整个做法的时间复杂度就是$O(n\cdot M^\frac{1}{3})$。

  再补充一下debug记录。这里我是直接手写哈希表来实现,STL是真不敢用了。然后呢我把哈希表开到了${10}^6+3$的大小,结果T麻了,我百思不得其解。然后我试着把哈希表大小开到${10}^7+19$就过了。这是因为表明上看起来只用映射$2 \cdot {10}^5$的数据量,但实际上还有$a_i \cdot b$和$a_i / b$这些数据,因为还是需要在哈希表中查找的,如果哈希表比较小那么哈希冲突就会很明显。如果遇到数据$a_i = i$,那么在哈希查找的过程中就会TLE。因此可以尝试把哈希表大小多开几倍,但也不能开太大,一方面是有空间限制,另一方面是如果大小过大,那么在内存中申请空间所需要的时间也会变大,也是有可能会TLE的。

  AC代码如下:

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 
 6 const int N = 2e5 + 10, M = 1e7 + 19;
 7 
 8 int n;
 9 int a[N];
10 int h[M], cnt[M];
11 
12 int find(int x) {
13     int k = x % M;
14     while (h[k] && h[k] != x) {
15         if (++k == M) k = 0;
16     }
17     return k;
18 }
19 
20 void reset() {  // 只把用过的位置清零
21     for (int i = 0; i < n; i++) {
22         a[i] = find(a[i]);
23     }
24     for (int i = 0; i < n; i++) {
25         h[a[i]] = cnt[a[i]] = 0;
26     }
27 }
28 
29 void solve() {
30     scanf("%d", &n);
31     for (int i = 0; i < n; i++) {
32         scanf("%d", a + i);
33         int t = find(a[i]); // 把a[i]映射到t
34         h[t] = a[i];
35         cnt[t]++;
36     }
37     LL ret = 0;
38     for (int i = 0; i < n; i++) {   // b=1的情况
39         int t = find(a[i]);
40         if (cnt[t] >= 3) ret += (cnt[t] - 1ll) * (cnt[t] - 2);
41     }
42     for (int i = 0; i < n; i++) {   // b>=2的情况
43         if (a[i] >= 1000000) {  // a[j]>=M^{2/3}的情况,暴力枚举b
44             int t = 1000000000ll / a[i];
45             for (int j = 2; j <= t; j++) {  // b最大枚举到M^{1/3}
46                 if (a[i] % j == 0) {    // a[j] mod b 要等于0,这样才有a[i]
47                     int t1 = find(a[i] / j), t2 = find(a[i] * j);
48                     if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2];
49                 }
50             }
51         }
52         else {
53             for (int j = 1; j <= a[i] / j; j++) {   // a[j]<M^{2/3}的情况,分解约数
54                 if (a[i] % j == 0) {
55                     if (j > 1 && a[i] <= 1000000000ll / j) {    // 约数不能为1,且a[k]没有超过M
56                         int t1 = find(a[i] / j), t2 = find(a[i] * j);
57                         if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2];
58                     }
59                     if (a[i] / j != j && a[i] <= 1000000000ll / a[i] * j) {
60                         int t1 = find(j), t2 = find(a[i] / j * a[i]);
61                         if (h[t1] && h[t2]) ret += 1ll * cnt[t1] * cnt[t2];
62                     }
63                 }
64             }
65         }
66     }
67     printf("%lld\n", ret);
68     reset();
69 }
70 
71 int main() {
72     int t;
73     scanf("%d", &t);
74     while (t--) {
75         solve();
76     }
77     
78     return 0;
79 }

 

参考资料

  Codeforces Round #867 (Div. 3) Editorial:https://codeforces.com/blog/entry/115409