P5161 WD与数列

发布时间:2024-01-09 15:10:41 作者:zltzlt

洛谷传送门:https://www.luogu.com.cn/problem/P5161

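长度为 \(1\) 的子串两两匹配且互不相交,贡献 \(\binom{n}{2}\)。对于长度为 \(k + 1 \ge 2\) 的子串,两段匹配当且仅当它们对应的差分数组子串(长度为 \(k\))相等。考虑差分数组中 \(\text{lcs}\)(最长公共后缀)为 \(t\) 的两个前缀 \([1, i]\) 和 \([1, j]\):以 \(i\)、\(j\) 结尾的公共后缀长度 \(k\) 需满足 \(1 \le k \le t\),而原序列中两段不相交又要求 \(k \le |i - j| - 1\)。因此可选的左端点(等价地,可选的 \(k\))共有 \(\min(|i - j| - 1, t)\) 个。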

对差分数组建出 SAM,则两个前缀所在结点在 parent 树上的 \(\text{lca}\) 的 \(\text{len}\) 就是它们的 \(\text{lcs}\)。枚举这个 \(\text{lca}\),依次处理它的每棵子树:先计算该子树中所有 \(\text{endpos}\) 与之前已合并进来的 \(\text{endpos}\) 之间产生的贡献,再把这棵子树的 \(\text{endpos}\) 加入。

可以使用线段树合并 + 启发式合并,或 DSU on tree + BIT。不难发现两个 \(\text{endpos}\) 的贡献,即上文的 \(\min(|i - j| - 1, t)\),在固定一端后是关于另一端的分段一次函数。因此线段树只需维护区间内 \(\text{endpos}\) 的和与个数即可。

时间复杂度 \(O(n \log^2 n)\),空间复杂度 \(O(n \log n)\)

code
// Problem: P5161 WD与数列
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P5161
// Memory Limit: 500 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
// Contest shorthands (kept verbatim: pb is used by solve()).
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;

// Numeric aliases used throughout the file.
using ll = long long;
using db = double;
using ull = unsigned long long;
using ldb = long double;
// Note: despite the name, this is a pair of 64-bit integers.
using pii = pair<ll, ll>;

// Start marker for the static data: main() prints the address distance between this and Med
// (in MB) as a rough memory estimate.
bool Mst;

// Shared array bound: SAM state arrays, parent-tree edge arrays, per-state roots/lists, and
// (times 20) the segment-tree node pool.
// NOTE(review): presumably about twice the maximum sequence length (SAM state bound) — confirm against the problem limits.
const int maxn = 600100;

// n: sequence length (solve() decrements it to the difference-array length);
// a: the input values, overwritten in place by adjacent differences a[i] - a[i + 1].
ll n, a[maxn];
// rt[u]: segment-tree root holding the endpos positions gathered for SAM state u;
// head/len/to/nxt: forward-star adjacency of the SAM parent tree (len = edges used so far);
// p: SAM states, sorted by decreasing len in solve().
int rt[maxn], head[maxn], len, to[maxn], nxt[maxn], p[maxn];
// pos[u]: the same endpos positions as rt[u], kept as a list so the smaller side can be iterated.
vector<int> pos[maxn];

// Appends the directed edge u -> v to the forward-star lists (head / nxt / to).
// Edges of u are later walked from head[u] along nxt, most recently added first.
inline void add_edge(int u, int v) {
	const int id = ++len;
	to[id] = v;
	nxt[id] = head[u];
	head[u] = id;
}

struct node {
	ll x, y;
	node(ll a = 0, ll b = 0) : x(a), y(b) {}
};

inline node operator + (const node &a, const node &b) {
	return node(a.x + b.x, a.y + b.y);
}

namespace SGT {
	// Node pool: a[id] = (sum of inserted positions, number of inserted positions) within the node's
	// range; ls/rs are child ids (0 means "no child"); nt counts the nodes handed out so far.
	node a[maxn * 20];
	int nt, ls[maxn * 20], rs[maxn * 20];
	
	// Inserts position x into the tree whose root is stored in rt (covering [l, r]), allocating
	// nodes on demand. Every node on the root-to-leaf path gains x in .x and 1 in .y.
	void update(int &rt, int l, int r, int x) {
		int *slot = &rt;
		for (;;) {
			if (*slot == 0) {
				*slot = ++nt;
			}
			const int cur = *slot;
			a[cur].x += x;
			a[cur].y += 1;
			if (l == r) {
				break;
			}
			const int mid = (l + r) >> 1;
			if (x <= mid) {
				slot = &ls[cur];
				r = mid;
			} else {
				slot = &rs[cur];
				l = mid + 1;
			}
		}
	}
	
	// Returns (sum, count) over the inserted positions of tree rt (covering [l, r]) that lie in
	// [ql, qr]. A missing tree or an empty query range yields (0, 0).
	node query(int rt, int l, int r, int ql, int qr) {
		if (rt == 0 || ql > qr) {
			return node(0, 0);
		}
		if (ql <= l && r <= qr) {
			return a[rt];
		}
		const int mid = (l + r) >> 1;
		const node left = (ql <= mid) ? query(ls[rt], l, mid, ql, qr) : node(0, 0);
		const node right = (qr > mid) ? query(rs[rt], mid + 1, r, ql, qr) : node(0, 0);
		return left + right;
	}
	
	// Merges tree v into tree u, summing overlapping nodes, and returns the merged root.
	// Nodes of u are reused in place; a subtree present in only one tree is linked in unchanged.
	int merge(int u, int v) {
		if (u == 0) {
			return v;
		}
		if (v == 0) {
			return u;
		}
		a[u] = a[u] + a[v];
		ls[u] = merge(ls[u], ls[v]);
		rs[u] = merge(rs[u], rs[v]);
		return u;
	}
}

struct SAM {
	int lst, tot, fa[maxn], len[maxn];
	map<int, int> ch[maxn];
	
	inline void init() {
		for (int i = 1; i <= tot; ++i) {
			fa[i] = len[i] = 0;
			ch[i].clear();
		}
		lst = tot = 1;
	}
	
	inline void insert(int c) {
		int u = ++tot, p = lst;
		len[u] = len[p] + 1;
		lst = u;
		for (; p && ch[p].find(c) == ch[p].end(); p = fa[p]) {
			ch[p][c] = u;
		}
		if (!p) {
			fa[u] = 1;
			return;
		}
		int q = ch[p][c];
		if (len[q] == len[p] + 1) {
			fa[u] = q;
			return;
		}
		int nq = ++tot;
		fa[nq] = fa[q];
		ch[nq] = ch[q];
		len[nq] = len[p] + 1;
		fa[u] = fa[q] = nq;
		for (; p && ch[p][c] == q; p = fa[p]) {
			ch[p][c] = nq;
		}
	}
} sam;

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
	}
	ll ans = n * (n - 1) / 2;
	--n;
	sam.init();
	for (int i = 1; i <= n; ++i) {
		a[i] -= a[i + 1];
		sam.insert(a[i]);
		SGT::update(rt[sam.lst], 1, n, i);
		pos[sam.lst].pb(i);
	}
	for (int i = 2; i <= sam.tot; ++i) {
		add_edge(sam.fa[i], i);
	}
	for (int i = 1; i <= sam.tot; ++i) {
		p[i] = i;
	}
	sort(p + 1, p + sam.tot + 1, [&](const int &x, const int &y) {
		return sam.len[x] > sam.len[y];
	});
	for (int _ = 1; _ <= sam.tot; ++_) {
		int u = p[_];
		for (int i = head[u]; i; i = nxt[i]) {
			int v = to[i];
			if (pos[u].size() < pos[v].size()) {
				swap(pos[u], pos[v]);
				swap(rt[u], rt[v]);
			}
			// printf("u, v: %d %d\n", u, v);
			for (int x : pos[v]) {
				// printf("x: %d\n", x);
				node res = SGT::query(rt[u], 1, n, 1, x - sam.len[u] - 1);
				ans += sam.len[u] * res.y;
				res = SGT::query(rt[u], 1, n, x - sam.len[u], x - 1);
				ans += (x - 1) * res.y - res.x;
				res = SGT::query(rt[u], 1, n, x + 1, x + sam.len[u]);
				ans += res.x - (x + 1) * res.y;
				res = SGT::query(rt[u], 1, n, x + sam.len[u] + 1, n);
				ans += sam.len[u] * res.y;
				pos[u].pb(x);
			}
			// printf("ans: %lld\n", ans);
			vector<int>().swap(pos[v]);
			rt[u] = SGT::merge(rt[u], rt[v]);
		}
	}
	printf("%lld\n", ans);
}

// End marker for the static data; paired with Mst in main().
bool Med;

int main() {
	// Debug aid: distance in MB between the Mst and Med markers. Subtracting pointers to two
	// unrelated objects is undefined behaviour, so compare their integer addresses instead. The
	// absolute value is taken because the relative placement of the two globals is unspecified.
	// NOTE(review): this only approximates the static footprint if the globals between the two
	// markers are laid out in declaration order.
	const uintptr_t start = reinterpret_cast<uintptr_t>(&Mst);
	const uintptr_t end = reinterpret_cast<uintptr_t>(&Med);
	const uintptr_t span = start > end ? start - end : end - start;
	fprintf(stderr, "%.2lf MB\n", span / 1048576.);
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}