[ARC111F] Do you like query problems?

发布时间 2023-12-17 17:28:02作者: 小超手123

题意:

给出三个数 \(n,m,q\)

你有一个长度为 \(n\) 的序列 \(a\),初始全为为 \(0\),你有三种操作:
操作 \(1\):给出 \(l,r,v\),让区间 \([l,r]\)\(v\)\(\min\)
操作 \(2\):给出 \(l,r,v\),让区间 \([l,r]\)\(v\)\(\max\)
操作 \(3\),给出 \(l,r\),求区间和,将其累加进一个叫 \(sum\) 的变量里。

你并不需要维护这个数据结构,而是统计一共有 \(q\) 个操作的情况下,所有不同的操作序列中 \(3\) 操作得到的 \(sum\) 的总和,对 \(998244353\) 取模。你需要保证 \(v\in[0,m-1]\)

分析:

img

显然的,输入的数据最多只可能有 \(((2m+1) \times \frac{n \times (n+1)}{2})^q\) 种。

对于这种问题,可以通过计算期望使得计算更简便。

不妨考虑计算每个位置对答案的贡献。

记一个 \(P(t,i)\) 表示经过 \(t\)修改操作最后变成 \(i\) 的概率。

考虑 dp 计算,如果要修改 \(i\),能改变 \(i\) 的只会是 \(\min(i,x)(x<i)\) 以及 \(\max(i,x)(x>i)\)

易得:

\[\begin{aligned} P(t,i) &=\frac{P(t-1,i) \times (n+1)+\sum_{j=0(j \ne i)}^{m} P(t-1,j)}{2 \times m} \\&=\frac{1}{2m}+\frac{P(t-1,i)}{2} \\&=\frac{1}{m}-\frac{1}{m \times 2^t}(i \ne 0) \end{aligned} \]

然后再记一个 \(E(t)\) 表示一个数经过 \(t\) 次修改得期望值。

\[\begin{aligned} E(t) &=\sum_{i=1}^{m-1}i \times P(t,i) \\&=(\frac{1}{m}-\frac{1}{m \times 2^t}) \times \frac{m \times (m-1)}{2} \\&=(1-\frac{1}{2^t})\frac{m-1}{2} \end{aligned} \]

但不可能每次操作 \([l,r]\) 都包含这个数吧(滑稽

因此需要记一个 \(z_i\) 表示 \(i\) 被包含得概率。显然 \(z(i)=\frac{i \times (n-i+1)}{\frac{n \times (n+1)}{2}}\)

于是我们就把 \(P(t,i)\) 升级到 \(w(t,i)\) 表示 \(i\) 经过 \(t\) 次全局操作得期望值:

\[\begin{aligned} w(t,i) &= \sum_{j=0}^{t}C_{t}^{j} \times z_{i}^{j} \times (1-z_{i}^{t-j})E(t) \\&=\frac{m-1}{2}(1-(1-\frac{z_i}{2})^t) \end{aligned} \]

\(T_i=1-\frac{w_i}{2}\)

\(j\) 表示操作到 \(i\) 头上得次数,最后一步用了二项式定理推导。

再加入一波修改操作 \(g(t,i)\) 表示 \(i\) 经过 \(t\) 次操作(包含查询)的期望值:

\[g(t,i)=\frac{m-1}{2}(1-(\frac{2mT_i+1}{2m+1})^t) \]

那么最后 \(x\) 的期望值为

\[\begin{aligned} E(x) &= \sum_{i=1}^{n} \frac{z_i}{2m+1}\sum_{j=1}^{q}g(j-1,i) \\&= \frac{z_i}{2m+1}\frac{m-1}{2}\sum_{j=1}^{q}(1-(\frac{2mT_i+1}{2m+1})^{j-1}) \end{aligned} \]

显然可以用等比数列求和快速计算。

最后答案就是

\[E(x) \times ((2m+1) \times \frac{n \times (n+1)}{2})^q \]

完结撒花!

img

代码:

#include<bits/stdc++.h>
#define int long long
#define mod 998244353
using namespace std;
int Pow(int a, int n) {
	if(n == 0) return 1;
	if(n == 1) return a % mod;
	int x = Pow(a, n / 2);
	if(n % 2 == 0) return x * x % mod;
	else return x * x % mod * a % mod;
}
int read() {
	char ch = getchar(); int x = 0, f = 1;
	while(ch < '0' || ch > '9') { if(ch == '-') f = -1; ch = getchar(); }
	while(ch >= '0' && ch <= '9') {
		x = x * 10 + ch - '0';
		ch = getchar();
	}
	return x * f;
}
void write(int x) {
	if(x < 0) putchar('-'), x = -x;
	if(x > 9) write(x / 10);
	putchar('0' + x % 10);
}
int inv(int x) {
	return Pow(x, mod - 2);
}
int n, m, q, ans;
signed main() {
	cin >> n >> m >> q;
	for(int i = 1; i <= n; i++) {
		int p = i * (n - i + 1) % mod * inv((n + 1) * n / 2 % mod) % mod;
		int P = (1 - p * inv(2) % mod + mod) % mod;
		int z = (2 * m % mod * P % mod + 1) % mod * inv(2 * m + 1) % mod;
		int S = (Pow(z, q) - 1 + mod) % mod * inv(z - 1) % mod;	
		ans = (ans + p * inv(2 * m + 1) % mod * (m - 1) % mod * inv(2) % mod * ((q - S + mod) % mod) % mod) % mod;
	}
	cout << ans * Pow((2 * m + 1) % mod * n % mod * (n + 1) % mod * inv(2) % mod, q) % mod;
	return 0;
}