D. Cyclic MEX

发布时间 2023-12-18 20:00:17作者: onlyblues

D. Cyclic MEX

For an array $a$, define its cost as $\sum_{i=1}^{n} \operatorname{mex} ^\dagger ([a_1,a_2,\ldots,a_i])$.

You are given a permutation$^\ddagger$ $p$ of the set $\{0,1,2,\ldots,n-1\}$. Find the maximum cost across all cyclic shifts of $p$.

$^\dagger\operatorname{mex}([b_1,b_2,\ldots,b_m])$ is the smallest non-negative integer $x$ such that $x$ does not occur among $b_1,b_2,\ldots,b_m$.

$^\ddagger$A permutation of the set $\{0,1,2,...,n-1\}$ is an array consisting of $n$ distinct integers from $0$ to $n-1$ in arbitrary order. For example, $[1,2,0,4,3]$ is a permutation, but $[0,1,1]$ is not a permutation ($1$ appears twice in the array), and $[0,2,3]$ is also not a permutation ($n=3$ but there is $3$ in the array).

Input

Each test consists of multiple test cases. The first line contains a single integer $t$ ($1 \le t \le 10^5$) — the number of test cases. The description of the test cases follows.

The first line of each test case contains a single integer $n$ ($1 \le n \le 10^6$) — the length of the permutation $p$.

The second line of each test case contain $n$ distinct integers $p_1, p_2, \ldots, p_n$ ($0 \le p_i < n$) — the elements of the permutation $p$.

It is guaranteed that sum of $n$ over all test cases does not exceed $10^6$.

Output

For each test case, output a single integer — the maximum cost across all cyclic shifts of $p$.

Example

input

4
6
5 4 3 2 1 0
3
2 1 0
8
2 3 6 7 0 1 4 5
1
0

output

15
5
31
1

Note

In the first test case, the cyclic shift that yields the maximum cost is $[2,1,0,5,4,3]$ with cost $0+0+3+3+3+6=15$.

In the second test case, the cyclic shift that yields the maximum cost is $[0,2,1]$ with cost $1+1+3=5$.

 

解题思路

  定义 $f_i$ 表示前缀 $[1, i]$ 的 $\operatorname{mex}$ 值,容易知道 $f_i$ 是单调递增的。现在考虑循环左移一次,即把原序列的 $p_1$ 移到最后,看看每个原本的 $f_i$ 会有什么变化。

  首先原本的 $f_1$ 会被删除。再考虑 $f_2 \sim f_n$,如果 $f_i < p_1$,那么不会改变;如果 $f_i > p_1$,那么就会变成 $p_1$。将变化后的 $f_i$ 左移一个单位,并在最后添加 $n$,那么就会得到 $p$ 序列循环左移一个单位后的前缀 $\operatorname{mex}$ 值。

  下表是以 $p = [3 \; 1 \; 0 \; 2]$ 按照上述过程模拟的表格:

\begin{array}{|c|c|c|c|}
\hline 
p & p' & f & f'  \\
\hline
3 \; 1 \; 0 \; 2 & 1 \; 0 \; 2 \; 3 & 0 \; {\color{Red} {0 \; 2 \; 4}} & {\color{Red} {0 \; 2 \; 3}} \; 4 \\
\hline
1 \; 0 \; 2 \; 3 & 0 \; 2 \; 3 \; 1 & 0 \; {\color{Red} {2 \; 3 \; 4}} & {\color{Red} {1 \; 1 \; 1}} \; 4 \\
\hline
0 \; 2 \; 3 \; 1 & 2 \; 3 \; 1 \; 0 & 1 \; {\color{Red} {1 \; 1 \; 4}} & {\color{Red} {0 \; 0 \; 0}} \; 4 \\
\hline
2 \; 3 \; 1 \; 0 & 3 \; 1 \; 0 \; 2 & 0 \; {\color{Red} {0 \; 0 \; 4}} & {\color{Red} {0 \; 0 \; 2}} \; 4 \\
\hline
\end{array}

  代码实现只需用依次枚举 $p_i$ 然后用 std::deque 去模拟这个过程,对每种情况取 $\sum\limits_{i=1}^{n}{f_i}$ 的最大值即可。不过如果每次都从队尾开始枚举,把大于 $p_i$ 的值都改成 $p_i$,那么整个模拟的时间复杂度就会达到 $O(n^2)$。

  改进的方法是用队列存储每个值以及对应的个数,由于 $f_i$ 递增因此相同的值必然是连续的一段。每次从队尾开始枚举时,只需统计比 $p_i$ 大的值的总数 $c$,并将这些值从队列中删除,最后再把 $(p_i, c)$ 压入队尾,同时还要把 $(n, 1)$ 压入队尾。另外还要用一个变量来维护队列中的 $f_i$ 的总和。

  由于在模拟的过程中一共往队列中插入 $O(n)$ 个元素,因此删除操作执行的次数也是 $O(n)$ 级别的。

  AC 代码如下,时间复杂度为 $O(n)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef pair<int, int> PII;

const int N = 1e6 + 10;

int a[N];
bool vis[N];

void solve() {
    int n;
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
    }
    memset(vis, 0, n + 10);
    LL s = 0;
    deque<PII> q;
    for (int i = 0, j = 0; i < n; i++) {
        vis[a[i]] = true;
        while (vis[j]) {
            j++;
        }
        q.push_back({j, 1});
        s += j;
    }
    LL ret = 0;
    for (int i = 0; i < n; i++) {
        s -= q.front().first;    // 把第一个元素删除
        if (--q.front().second == 0) q.pop_front();    // 这个值没有了
        int cnt = 0;    // 统计比a[i]大的数的个数
        while (!q.empty() && q.back().first > a[i]) {
            s -= 1ll * q.back().first * q.back().second;    // 将这些数从队列中删掉
            cnt += q.back().second;
            q.pop_back();
        }
        s += 1ll * a[i] * cnt + n;    // 这些数全部变成a[i]
        q.push_back({a[i], cnt});    // 并把a[i]及其个数插到队尾
        q.push_back({n, 1});    // 最后n插到队尾
        ret = max(ret, s);
    }
    printf("%lld\n", ret);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        solve();
    }
    
    return 0;
}

 

参考资料

  Codeforces Round 915 (Div. 2) Editorial:https://codeforces.com/blog/entry/123384