sol. ABC246Ex

发布时间 2023-08-18 13:59:55作者: WTR2007

动态 DP 模板题

[ABC246Ex] 01? Queries

题目大意:给定一个长度为 \(N\) 且只包含 ?,1,0 的字符串 \(a\)\(Q\) 次操作,每次操作会修改字符串中的一个字符,并求出此时整个字符串中本质不同的子序列个数。


众所周知,动态 DP 依然是 DP 的一种。

先考虑没有修改操作,那么本题变为一个普通的 DP,我们设 \(dp_{i,0/1}\) 分别表示到第 \(i\) 位,以 \(0/1\) 结尾的本质不同的子序列个数。

考虑转移,以 \(a_i\)0 为例,那么在 \(1\)\(i-1\) 之间的所有本质不同的子序列后都可以加第 \(i\) 位的这个 \(0\),这些串依旧本质不同且长度不小于 \(2\),再加上 0 自己本身这种串,又因为 \(a_i\)\(0\),不可能形成更多的以 \(1\) 结尾的本质不同子序列,所以:

\(a_i\)0 时,

  • \(dp_{i,0}=dp_{i-1,0}+(dp_{i-1,1}+1)\)

  • \(dp_{i,1}=dp_{i-1,1}\)

\(a_i\)1? 的转移同理:

\(a_i\)1 时,

  • \(dp_{i,0}=dp_{i-1,0}\)

  • \(dp_{i,1}=dp_{i-1,1}+(dp_{i-1,0}+1)\)

\(a_i\)? 时,

  • \(dp_{i,0}=dp_{i-1,0}+(dp_{i-1,1}+1)\)

  • \(dp_{i,1}=dp_{i-1,1}+(dp_{i-1,0}+1)\)

最终答案为 \(dp_{n,0}+dp_{n,1}\)

如果对于每次修改,重新做一遍 DP,时间复杂度变为 \(\mathcal{O}(NQ)\),无法通过,注意到方程的转移是一个线性递推式,这启示我们用矩阵优化转移。我们把递推式写成矩阵的形式,可以用以下的式子描述:

\[S_{a_i}\begin{bmatrix} dp_{i-1,0} \\ dp_{i-1,1} \\ 1\end{bmatrix}=\begin{bmatrix} dp_{i,0} \\ dp_{i,1} \\ 1\end{bmatrix} \]

不难写出:

\[S_{0}=\begin{bmatrix} 1&1&1 \\ 0&1&0 \\ 0&0&1\end{bmatrix},S_{1}=\begin{bmatrix} 1&0&0 \\ 1&1&1 \\ 0&0&1\end{bmatrix},S_{?}=\begin{bmatrix} 1&1&1 \\ 1&1&1 \\ 0&0&1\end{bmatrix} \]

那么整个 DP 的转移可以用以下式子表示:

\[\prod_{i=1}^n S_{a_i} \times \begin{bmatrix}0\\0\\1\end{bmatrix} = \begin{bmatrix} dp_{n,0} \\ dp_{n,1} \\ 1\end{bmatrix} \]

本质上,一次修改字符的操作就是修改一个字符的转移矩阵。又由上面的式子可知要维护的信息是整个字符串的 \(S_{a_i}\) 的乘积。由于矩阵具有结合律,所以可以用线段树维护。时间复杂度为 \(O(k^3(N + Q \log N))\),本题中 \(k = 3\)

#include<bits/stdc++.h>
#define int long long
#define L p<<1
#define R p<<1|1
#define mid ((l+r)>>1)
using namespace std;
const int INF=0x3f3f3f3f;
const int MOD=998244353;
const int N=100005;
int tot=0;
char a[N];
inline int read(){
    int w=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        w=(w<<1)+(w<<3)+(ch-48);
        ch=getchar();
    }
    return w*f;
}
struct Matrix{
    int x,y,val[5][5];
    inline void clear(int a,int b){
        x=a; y=b;
        for(int i=1;i<=x;i++)
            for(int j=1;j<=y;j++) val[i][j]=0;
    }
    inline void unit(int a,int b){
        assert(a==b);
        x=a; y=b;
        for(int i=1;i<=x;i++) 
            for(int j=1;j<=y;j++) val[i][j]=(i==j);
    } 
};
inline Matrix mult(Matrix a,Matrix b){
    Matrix ans;
    assert(a.y==b.x);
    ans.clear(a.x,b.y);
    for(int i=1;i<=a.x;i++){
        for(int k=1;k<=a.y;k++){
            if(!a.val[i][k]) continue;
            for(int j=1;j<=b.y;j++){
                ans.val[i][j]+=a.val[i][k]*b.val[k][j]%MOD;
                ans.val[i][j]=(ans.val[i][j]%MOD+MOD)%MOD;
            }
        }
    }
    return ans;
}
Matrix S1,S2,S3,res;
struct Segnode{
    Matrix val;
}tr[N<<2];
inline void Pushup(int p){
    tr[p].val=mult(tr[L].val,tr[R].val);
}
inline void Build(int p,int l,int r){
    if(l==r){
        if(a[l]=='0') tr[p].val=S1;
        else if(a[l]=='1') tr[p].val=S2;
        else tr[p].val=S3;
        return ;
    }
    Build(L,l,mid);
    Build(R,mid+1,r);
    Pushup(p);
}
inline void Modify(int p,int l,int r,int x,Matrix d){
    if(l==r){
        tr[p].val=d;
        return ;
    }
    if(x<=mid) Modify(L,l,mid,x,d);
    else Modify(R,mid+1,r,x,d);
    Pushup(p);
}
inline void Prework(){
    S1.unit(3,3);
    S1.val[1][2]=S1.val[1][3]=1;
    S2.unit(3,3);
    S2.val[2][1]=S2.val[2][3]=1;
    S3.unit(3,3);
    S3.val[1][2]=S3.val[1][3]=S3.val[2][1]=S3.val[2][3]=1;
}
signed main(){
    int n,Q;
    n=read();Q=read();
    scanf("%s",a+1);
    Prework();
    Build(1,1,n);
    while(Q--){
        int x=read();
        char c[2];
        scanf("%s",c+1);
        if(c[1]=='0') Modify(1,1,n,x,S1);
        else if(c[1]=='1') Modify(1,1,n,x,S2);
        else Modify(1,1,n,x,S3);
        res.clear(3,1);
        res.val[3][1]=1;
        res=mult(tr[1].val,res);
        printf("%lld\n",(res.val[1][1]+res.val[2][1])%MOD);
    }
    return 0;
}