【垫底模拟】CSP-46

发布时间 2023-09-27 21:17:04作者: Sonnety

T2 序列(array):思维题

题干概括:

有两个长度为 \(m\) 的序列 \(a,b\),在 \(0\leq b_i\leq n\)\(\sum_{i=1}^{m}\limits a_ib_i\leq D\) 的情况下,求:

\[\sum_{i=1}^{m}b_i+k\cdot \min_{i=1}^{m}b_i \]

的最大值。

思路:

容易证明的是,在 \(a_i\leq a_{i+1}\) 的条件下,\(b_{i}\geq b_{i+1}\) 一定更优。

因此我们对整个 \(a\) 进行排序,枚举 \(b\) 序列前 \(i\) 个元素大小为 \(n\),第 \(i+1\) 个元素大小为 \(x\),第 \(i+2\) 到第 \(m\) 个元素大小为 \(y\)。(\(n\geq x\geq y\)

对于每一个 \(i\),答案就是 \(f(x)\)

\[\begin{aligned} &f(x)=i\cdot n+x+y\cdot\left(k+n-\left(i+2\right)+1\right)\\ &y=\left\lfloor{\frac{D-n\sum_{j=1}^{n}\limits a_i-a_{i+1}\cdot x}{\sum_{j=i+2}^{n}\limits a_j}}\right\rfloor\\ &f(x)=i\cdot n+x+\left\lfloor{\frac{D-n\sum_{j=1}^{n}\limits a_i-a_{i+1}\cdot x}{\sum_{j=i+2}^{n}\limits a_j}}\right\rfloor\times (k+n-i-1) \end{aligned} \]

于是我们把这个函数化简成:

\[\begin{aligned} &f(x)=x+Ay\\ &y=\left\lfloor{\frac{C-Bx}{S}}\right\rfloor\\ &f(x)=x+A\left\lfloor{\frac{C-Bx}{S}}\right\rfloor\\ &(A\geq 0,B\geq 0,C\geq 0,S\geq 0) \end{aligned} \]

然后画出他们的函数图像,发现其为锯齿状:

上行锯齿

下行锯齿

为什么是锯齿状的?因为:

\[y=\left\lfloor{\frac{C-Bx}{S}}\right\rfloor \]

\(y\) 是一个向下取整的函数,在 \(x\) 变化的一定区间内,\(y\) 不变,在 \(y\) 不变的区间内,\(x\) 增大,\(f(x)\) 一定增大。

\(y\geq 0\),可以得到:

\[Sy+Bx\leq C \]

\(x\) 有一个定义域,由 \(n\)\(D\) 决定,在定义域内,答案有三种情况:

  • 在第一个峰:

因为 \(y\leq x\),所以我们可以先假设 \(x=y\),求出最小的 \(y\),然后将剩下的值 \(C-Sy\) 全部赋给 \(x\)

// 1:保证x尽可能小,因为x>=y,假设x等于y之后填x
ll y=Min(C/(B+S),n),x=Min(C/(B+S)+C%(B+S)/B,n);
ans=Max(ans,n*i+x+A*y);
  • \(x\) 的最大值上:

直接求 \(x\) 的最大值,假设 \(y=0\) 即可。

// 2:保证x尽可能大,假设y=0
 x=Min(C/B,n),y=0;
if(x>=y)    ans=Max(ans,n*i+x+A*y);
  • 在最后一个峰上:

由上不等式得到:\(x\) 越大,\(y\) 越小。

设在答案点上的 \(x=xx,y=yy\)\(ans=xx+Ayy\)

\[\begin{aligned} &y\leq \frac{C-Bxx}{S}\\ &yy=\left\lfloor{\frac{C-B xx}{S}}\right\rfloor\\ &Syy\leq C-Bxx\\ &Syy\geq C-Bn\\ &Syy> C-Bn-1\\ &yy>\left\lfloor{\frac{C-Bn-1}{S}}\right\rfloor\\ &yy=\left\lfloor{\frac{C-Bn-1}{S}}\right\rfloor+1\\ &xx\leq \frac{C-Syy}{B}\\ &xx=\left\lfloor{\frac{C-Syy}{B}}\right\rfloor \end{aligned} \]

// 3:最后的峰点
 if(C-B*n>1){
     int y=(C-B*n)/S,x=(C-S*y)/B;
     if(x>=y)    ans=Max(ans,n*i+x+A*y);
}

时间复杂度:\(O(Tm)\)

Miku's Code
#include<bits/stdc++.h>
using namespace std;
#define il inline
#define rg register int
#define cout std::cout
#define endl '\n'
#define int long long
typedef long long ll;
typedef unsigned long long ull;
typedef double ff;
typedef long double llf;
const ff eps=1e-8;
typedef pair<int,int> PII;
typedef vector<int> VI;
typedef set<int> SI;
int Max(int x,int y)    <% return x<y?y:x; %>
int Min(int x,int y)    <% return x<y?x:y; %>
int Abs(int x)  <% return x>0?x:-x; %>
#if ONLINE_JUDGE
char INPUT[1<<20],*p1=INPUT,*p2=INPUT;
#define getchar() (p1==p2 && (p2=(p1=INPUT)+fread(INPUT,1,1<<20,stdin),p1==p2)?EOF:*p1++)
#endif
il int read(){
    char c=getchar();
    int x=0,f=1;
    while(c<48) <% if(c=='-')f=-1;c=getchar(); %>
    while(c>47) x=(x*10)+(c^48),c=getchar();
    return x*f;
}const int maxm=2e5+5;

int T;
int n,m,k,D,a[maxm],sum[maxm],ans;

il void clear(){
    ans=0;
    for(rg i=0;i<=m;++i)    sum[i]=0;
}

il void input(){
    n=read(),m=read(),k=read(),D=read();
    for(rg i=1;i<=m;++i)    a[i]=read();
}

signed main(){
freopen("array.in","r",stdin);
#if ONLINE_JUDGE
freopen("array.out","w",stdout);
#endif
    T=read();
    while(T--){
        input();
        sort(a+1,a+m+1);
        for(rg i=1;i<=m;++i)    sum[i]=sum[i-1]+a[i];
        if(n*sum[m]<=D) { printf("%lld\n",n*(m+k));continue; }
        for(rg i=0;i<m;++i){
            if(n*sum[i]>D)  break;
            ll A=m-i-1+k,B=a[i+1],C=D-sum[i]*n,S=sum[m]-sum[i+1];
            if(!S){ ans=Max(ans,i*n+Min(C/B,n)*(1+k));continue; }
            else{
                // 1:保证x尽可能小,因为x>=y,假设x等于y之后填x
                ll y=Min(C/(B+S),n),x=Min(C/(B+S)+C%(B+S)/B,n);
                ans=Max(ans,n*i+x+A*y);
                // 2:保证x尽可能大,假设y=0
                x=Min(C/B,n),y=0;
                if(x>=y)    ans=Max(ans,n*i+x+A*y);
                // 3:最后的峰点
                if(C-B*n>1){
                    int y=(C-B*n)/S,x=(C-S*y)/B;
                    if(x>=y)    ans=Max(ans,n*i+x+A*y);
                }
            }
        }
        printf("%lld\n",ans);
        clear();
    }
    return 0;
}