CF1585F Non-equal Neighbours - 容斥 - dp - 单调栈

发布时间 2023-07-10 19:13:17作者: SkyRainWind

题目链接:https://codeforces.com/problemset/problem/1585/F

题解:
难难难
考虑容斥:设 \(A_i\) 表示 \(b_i \neq b_{i+1}\) (\(i=1,2,\cdots,n-1\)) 时对应的 \(\{b_i\}\) 方案的答案
那么答案就是 $$\bigcap_{{i=1}}^{n}A_{i} = |U|-\left|\bigcup_{i=1}^n\overline{A_i}\right|$$
后者可以用容斥原理化简。也就是这个式子:

\[\begin{aligned}\left|\bigcup _{i=1}^{n}A_{i}\right|={};\sum _{i=1}^{n}|A_{i}|-\sum _{1\leq i;j\leq n}|A_{i}\cap A_{j}|+\sum _{1\leq i;j;k\leq n}|A_{i}\cap A_{j}\cap A_{k}|-\cdots +(-1)^{n-1}\left|A_{1}\cap \cdots \cap A_{n}\right|.\end{aligned} \]

考虑这个过程的意义:\(|\overline{A_i}|\) 的含义就是我钦定了 \(b_i=b_{i+1}\),然后剩下的随便选(也就是没有限制了),\(|\overline{A_i}\cup\overline{A_j}|\) 的含义就是钦定了 \(b_i=b_{i+1}\and b_j=b_{j+1}\),然后剩下的没有限制。
那这个贡献如何算呢?我们发现,如果有 \(b_i=b_{i+1}=b_{i+2}\) 的话,那么这段的贡献就是三个数的 min,再和别的部分相乘。
可以将每一个相同的连续部分看成一“段”,如果已经钦定了 \(k\) 个连续的部分,那相当于有 \(n-k\) 个“段”,考虑对段的贡献进行 dp。
\(dp_{i,j}\) 表示考虑到第 \(i\) 个位置,当前划分了 \(j\) 个段的贡献,那么显然 \(dp_{i,j} \leftarrow dp_{k,j-1}\times min(a_{k+1}..a_i)\)
注意这里“段”的意义是我“钦定”的段。一个例子:我可以将 \(1..5\) 划分为 \([1,3], [4,4], [5,5]\),段内的元素一定是相同的,这是我钦定的,但是段间的元素也可能相同,这是我钦定之后没有其它限制条件的结果。例如 \(a_4=a_5\) 是合法的。
然后由于容斥系数的正负只和 \(j\) 的奇偶性有关,可以变成 \(O(n^2)\),再利用单调栈记录一下上一个比当前位置小的位置,写出来转移发现可以化简。后面的部分和这篇博客差别不大。

代码:

// by SkyRainWind
#include <bits/stdc++.h>
#define mpr make_pair
#define debug() cerr<<"Yoshino\n"
#define pii pair<int,int>
#define pb push_back

using namespace std;

typedef long long ll;
typedef long long LL;

const int inf = 1e9, INF = 0x3f3f3f3f, maxn = 2e5+5, mod =998244353;

int n,a[maxn];
int dp[maxn][2], pre[maxn][2];
stack<int>stk;

signed main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	dp[0][0] = pre[0][0] = 1;
	for(int i=1;i<=n;i++){
		while(!stk.empty() && a[stk.top()] >= a[i])stk.pop();
		if(stk.empty()){
			dp[i][0] = 1ll * pre[i-1][1] * a[i] % mod;
			dp[i][1] = 1ll * pre[i-1][0] * a[i] % mod;
		}else{
			int p = stk.top();
			for(int j=0;j<=1;j++){
				(dp[i][j] = 1ll*dp[p][j] + 1ll*(pre[i-1][j^1]-pre[p-1][j^1]+mod)*a[i]%mod)%=mod;
			}
		}
		for(int j=0;j<=1;j++)
			(pre[i][j] = pre[i-1][j] + dp[i][j])%=mod;
		stk.push(i);
	}
	int sgn = (n&1) ? -1 : 1;
	printf("%d\n",((dp[n][0]-dp[n][1])*sgn+mod)%mod);

	return 0;
}