[国家集训队\] 等差子序列 题解

发布时间 2023-08-22 16:36:34作者: MoyouSayuki

P2757 [国家集训队] 等差子序列 题解

首先简化题目之后,发现如果序列之中存在一个长度 \(\ge 3\) 的子序列,一定存在一个子序列的子序列长度为 \(3\),所以只需要统计有没有长度为 \(3\) 的子序列即可。

对于这种长度为 \(3\) 的统计问题,第一想法就是从中间的那个点入手,由于排列的性质,每一个元素都不重不漏的出现,所以如果一个元素左边没有 \(a\in[1,N]\),那么他的右边一定会出现 \(a\),且仅有一个。

所以考虑枚举每一个中间点 \(a_i\),题目变为了判断是否左边仅存在 \(a_i + k\)\(a_i - k\) 的其中之一,这个问题可以这么解决:

维护一个初始全为 \(0\) 的序列 \(b\),每当我们扫到了元素 \(x\)\(b_x\gets 1\), 经过观察发现,上述问题就转化为了左边的区间是否与右边的区间的翻转有至少一个不相同的元素,相当于是判一个回文串,但是这题显然不能用离线的 \(\text{manacher}\),所以考虑用万能的字符串哈希,把 \(b\) 序列正反的哈希算出来,只需要判断左区间的正哈希是否与右区间的反哈希相等即可,解决了询问,如果带修怎么办。

可以考虑用一棵线段树维护区间的哈希值,单点修改之后用 up() 函数重新维护线段树即可。查询的时候需要注意,因为查询的区间不一定是儿子代表的区间,所以可以在返回的时候直接乘上对应的位数即可。

时间复杂度:\(O(n\log n)\)

// Problem: P2757 [国家集训队] 等差子序列
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P2757
// Author: Moyou
// Copyright (c) 2023 Moyou All rights reserved.
// Date: 2023-08-22 11:29:52

#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <set>
#include <map>
#define int unsigned long long
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
typedef long long ll;
const int N = 5e5 + 10, base = 13331, yes = 1, no = 2;

int n, a[N], b[N], pl[N];
struct qwq {
    int l, r, hl, hr;
} tr[N << 2];
void up(int u) {
    tr[u].hl = tr[u << 1].hl + tr[u << 1 | 1].hl * pl[(tr[u << 1].r - tr[u << 1].l + 1)];
    tr[u].hr = tr[u << 1].hr * pl[tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1] + tr[u << 1 | 1].hr;
}
void build(int u, int l, int r) {
    tr[u] = {l, r, no, no};
    if(l == r) return (void)(b[l] = no);
    build(u << 1, l, l + r >> 1), build(u << 1 | 1, (l + r >> 1) + 1, r);
    up(u);
}
void update(int u, int x) {
    int l = tr[u].l, r = tr[u].r, mid = l + r >> 1;
    if(l == r) {
        b[l] = yes;
        tr[u].hl = tr[u].hr = yes;
        return ;
    }
    if(x <= mid) update(u << 1, x);
    else update(u << 1 | 1, x);
    up(u);
}

int query(int u, int ql, int qr, int op) {
    int l = tr[u].l, r = tr[u].r, mid = l + r >> 1;
    if(ql <= l && qr >= r) return (op ? (tr[u].hr * pl[qr - r]) : tr[u].hl * pl[l - ql]);
    int t = 0;
    if(ql <= mid) t = query(u << 1, ql, qr, op);
    if(qr > mid)  t += query(u << 1 | 1, ql, qr, op);
    return t;
}

void work() {
    cin >> n;
    for(int i = 1; i <= n; i ++) cin >> a[i];
    build(1, 1, n);
    for(int i = 1; i <= n; i ++) {
        update(1, a[i]);
        int d = min(a[i] - 1, n - a[i]);
        if(query(1, a[i] - d, a[i] - 1, 0) != query(1, a[i] + 1, a[i] + d, 1))
            return (void)(cout << "Y\n");
    }
    cout << "N\n";
}

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int T;
    cin >> T;
    pl[0] = 1;
    for(int i = 1; i <= N - 10; i ++) pl[i] = pl[i - 1] * base;
    while(T --) work();
    
    return 0;
}