F - Delete 1, 4, 7, ...
设\(f(i)\)表示第一次操作后,第\(i\)个位置的数,那么\(f(i)=\lfloor \frac{3i+1}2\rfloor\)
那么\(k\)次操作后,第\(i\)个位置上的数就是:
设\(cnt_k\)表示\(k\)次操作后剩下的数的个数,那么显然有:\(cnt_i=\lfloor \frac {cnt_{i-1}\times2}3\rfloor\)
最后的答案就是\(\sum_{i=1}^{cnt_k}f^k(i)\)
若采用最暴力的方法递推出所有的\(f^k(i)\),那么复杂度为\(O(k\times cnt_k)\)
由因为\(cnt_k\leq n\times(\frac 23)^k\),所以复杂度为\(O(nk(\frac 23)^k)\)
当\(k\)很大时,这个复杂度就很优秀,但当\(k\)较小时,这个复杂度就过大了,所以考虑用另一种做法来解决较小的\(k\)
考虑\(f^k(i)\)与\(f^k(j)\)的关系,有:
当\(k=1\)时,由题意得证
当\(k>1\)时:
\[\begin{aligned} f^k(n+2^k)&=f^{k-1}(f(n+2^k))\\ &=f^{k-1}(\lfloor\frac{3n+3\times 2^k+1}2\rfloor)\\ &=f^{k-1}(3\times 2^{k-1}+\lfloor\frac{3n+1}2\rfloor)\\ &=f^{k-1}(3\times 2^{k-1}+f(n))\\ &=f^k(n)+3^k \end{aligned} \]
所以我们只需要求出前\(2^k\)个\(f^k(i)\)就可以得出所有的\(f^k(i)\)
这样复杂度就降到了\(O(k2^k)\),但还是差点,考虑继续优化
将\(f^k(i)\)拆成\(f^x(f^y(i))\),其中\(x+y=k\),那么有:
那么我们的答案就从\(\sum_{i=1}^{cnt_k}f^k(i)\)变成了:枚举\(i\in[1,2^y]\),\(\sum_jf^x(f^y(i)+3^yj)\)
然后又有\(i+2^yj\leq cnt_k\),那么求得\(j\)的上界为\(\lfloor\frac{cnt_k-i}{2^y}\rfloor+1\)
多算没关系,因为是0
设\(g(a,i)\)表示\(\sum_{j=0}^{2^i-1}f^x(a+3^yj)\),那么有转移:
还有一个优化,当\(a>2^x\)时,有:
这一部分的复杂度为\(O(2^y\log n+2^xk)\),当\(x=y=\frac k2\)时,复杂度为\(O((k+\log n)2^{\frac k2})\)
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=105,MOD=998244353;
ll n,k,cnt[N],ans,X,Y,p3[N],dp[1100000][45];
ll f(ll x,ll k){
for(int i=1;i<=k;++i) x=(x*3+1)/2;
return x;
}
ll g(ll a,ll i){
if(a>(1ll<<X)){
ll t=((a-1)/(1ll<<X))%MOD;
return (g((a-1)%(1ll<<X)+1,i)+t*(p3[X]%MOD)%MOD*((1ll<<i)%MOD)%MOD)%MOD;
}
if(!i) return f(a,X)%MOD;
if(dp[a][i]) return dp[a][i];
return dp[a][i]=(g(a,i-1)+g(a+p3[Y]*(1ll<<i-1),i-1))%MOD;
}
void add(ll &x,ll y){
x+=y;
if(x>=MOD) x-=MOD;
}
void work1(){
X=k/2,Y=k-X,p3[0]=1;
for(int i=1;i<=30;++i) p3[i]=p3[i-1]*3;
for(int i=1;i<=min(cnt[k],1ll<<Y);++i){
ll lim=(cnt[k]-i)/(1ll<<Y)+1,a=f(i,Y);
for(int j=0;j<=50;++j) if(lim&(1ll<<j)) add(ans,g(a,j)),a+=p3[Y]*(1ll<<j);
}
}
void work2(){
for(int i=1;i<=cnt[k];++i) add(ans,f(i,k)%MOD);
}
int main(){
scanf("%lld%lld",&n,&k);
cnt[0]=n;
for(int i=1;i<=k;++i) cnt[i]=cnt[i-1]*2/3;
if(k<=40) work1();
else work2();
printf("%lld",ans);
return 0;
}