[CF1158F] Density of subarrays

发布时间 2023-08-25 17:55:34作者: 灰鲭鲨

Let $ c $ be some positive integer. Let's call an array $ a_1, a_2, \ldots, a_n $ of positive integers $ c $ -array, if for all $ i $ condition $ 1 \leq a_i \leq c $ is satisfied. Let's call $ c $ -array $ b_1, b_2, \ldots, b_k $ a subarray of $ c $ -array $ a_1, a_2, \ldots, a_n $ , if there exists such set of $ k $ indices $ 1 \leq i_1 < i_2 < \ldots < i_k \leq n $ that $ b_j = a_{i_j} $ for all $ 1 \leq j \leq k $ . Let's define density of $ c $ -array $ a_1, a_2, \ldots, a_n $ as maximal non-negative integer $ p $ , such that any $ c $ -array, that contains $ p $ numbers is a subarray of $ a_1, a_2, \ldots, a_n $ .

You are given a number $ c $ and some $ c $ -array $ a_1, a_2, \ldots, a_n $ . For all $ 0 \leq p \leq n $ find the number of sequences of indices $ 1 \leq i_1 < i_2 < \ldots < i_k \leq n $ for all $ 1 \leq k \leq n $ , such that density of array $ a_{i_1}, a_{i_2}, \ldots, a_{i_k} $ is equal to $ p $ . Find these numbers by modulo $ 998,244,353 $ , because they can be too large.
$ 1 \leq n, c \leq 3,000 $ $ 1 \leq a_i \leq c $

考虑如何算出一个序列的喜爱值。考虑现在到了第 \(i\) 个点,然后往后面每个不同的 美丽值第一次出现的地方去跳,最远能跳到哪就跳到哪,然后我们就得到了答案。发现最大喜爱值为 \(\frac nc\)

有一个显然的 dp,定义 \(dp_{i,j}\) 为现在选的最后一个数为 \(i\),目前的喜爱值为 \(j\),然后枚举喜爱值增加的位置即可。复杂度 \(\frac{n^3}c\),几乎没办法优化。

还有一种 dp 方式,定义 \(dp_{i,j,s}\) 为前 \(i\) 个点,喜爱值为 \(j\),选了的集合为 \(s\),有多少种方案,当 \(s\) 满了之后 \(j+1\)。复杂度 \(\frac {n^2}c2^c\)

考虑将两个平衡一下,当 \(c\le11\) 时跑方法2,否则跑方法1,复杂度为 \(\frac {n^3}{logn}\)

严重卡常,需要特判 \(c=1\)

#include<bits/stdc++.h>
#pragma GCC optimzie(2)
using namespace std;
const int N=3005,P=998244353,iv=499122177;
typedef long long LL;
int n,c,pw[N],ipw[N],s[N],a[N],ans[N],sg[N];
int read()
{
	int s=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9')
		s=s*10+ch-48,ch=getchar();
	return s;
}
int pown(int x,int y)
{
	if(!y)
		return 1;
	int t=pown(x,y>>1);
	if(y&1)
		return 1LL*t*t%P*x%P;
	return 1LL*t*t%P;
}
struct solve2{
	short  ss[N][N],nx[N][N],pr[N];
	int dp[N][N],v[N][N];
	solve2()
	{
		for(int i=1;i<=c;i++)
			nx[n+1][i]=n+1;
		for(int i=n-1;~i;i--)
		{
			for(int j=1;j<=c;j++)
				nx[i][j]=nx[i+1][j];
			nx[i][a[i+1]]=i+1;
		}
		for(int i=1;i<=n;i++)
		{
			memcpy(ss[i],ss[i-1],sizeof(ss[0]));
			ss[i][a[i]]++;
		}
		int p=0;
		LL pp=1;
		memset(s,0,sizeof(s));
		sg[n]=0;
		for(int i=n-1;~i;--i)
		{
			if(!s[a[i+1]])
				++p;
			else
				pp=pp*ipw[s[a[i+1]]]%P*(pw[s[a[i+1]]+1]-1)%P;
			s[a[i+1]]++;
			if(p==c)
				sg[i]=pp;
		}
		for(int i=0;i<=n;i++)
		{
			for(int k=1;k<=c;k++)
				pr[i]=max(pr[i],nx[i][k]);
			if(pr[i]>n)
				break;
			LL p=1;
			for(int k=1;k<=c;k++)
				p=p*(pw[s[k]=ss[pr[i]][k]-ss[i][k]]-1)%P;
			for(int k=pr[i];k<=n;k++)
			{
				v[i][k]=p*ipw[s[a[k]]]%P;
				if(k^n)
					++s[a[k+1]],p=p*ipw[s[a[k+1]]-1]%P*(pw[s[a[k+1]]]-1)%P;
			}
		}
		dp[0][0]=1;
		for(int i=0;i<=n;i++)
		{
			for(int j=0;j<=n/c;j++)
			{
				if(!dp[i][j])
					continue;
				for(int k=pr[i];k<=n;k++)
					(dp[k][j+1]+=dp[i][j]*1LL*v[i][k]%P)%=P;
				(ans[j]+=dp[i][j]*1LL*(pw[n-i]+P-sg[i])%P)%=P;
			}
		}
		for(int i=0;i<=n;i++)
			printf("%d ",ans[i]-(!i));
	}
};
struct solve1{
	int C[N][N];
	solve1()
	{
		for(int i=C[0][0]=1;i<=n;i++)
		{
			C[i][0]=C[i][i]=1;
			for(int j=1;j<i;j++)
				C[i][j]=(C[i-1][j]+C[i-1][j-1])%P;
		}
		printf("0 ");
		for(int i=1;i<=n;i++)
			printf("%d ",C[n][i]);
	}
};
struct solve3{
	int dp[2][N/2][8192];
	solve3(){
		dp[0][0][0]=1;
		for(int i=1;i<=n;i++)
		{
			for(int j=0;j<=i/c;j++)
			{
				for(int k=0;k<(1<<c);k++)
				{
					dp[i&1][j][k]=dp[i&1^1][j][k];
					if(!k&&j)
					{
						(dp[i&1][j][0]+=dp[i&1][j-1][(1<<c)-1])%=P;
						dp[i&1][j-1][(1<<c)-1]=0;
					}
					if(k>>(a[i]-1)&1)
						(dp[i&1][j][k]+=(dp[i&1^1][j][k]+dp[i&1^1][j][k^1<<a[i]-1])%P)%=P;
				}
				if(i==n)
				{
					for(int k=0;k+1<(1<<c);k++)
						(ans[j]+=dp[i&1][j][k])%=P;
				}
			}
		}
		for(int j=0;j<=n;j++)
			printf("%d ",ans[j]-(!j));
	} 
};
int main()
{
//	freopen("crixis4.in","r",stdin);
//	freopen("crixis4.out","w",stdout); 
	n=read(),c=read();
	for(int i=ipw[0]=pw[0]=1;i<=n;i++)
	{
		pw[i]=(pw[i-1]<<1)%P;
		ipw[i]=pown(pw[i]-1,P-2);
	}
	for(int i=1;i<=n;i++)
		a[i]=read();
	if(c==1)
	{
		solve1();
		return 0;
	}
	if(c>11)
	{
		solve2();
		return 0;
	}
	solve3();
}