杂题与思维题口胡

发布时间 2023-08-24 00:55:26作者: cqbzlzh

Flip Machines

\(tag\)根号分治,dp,期望,贪心

比较神仙,没想到这还能数据分治。

考虑每次交换的卡片上的数,若\(b_i \leq a_i\) 那么这次操作一定是不优的,反之一定能够更优,能使答案变大。

\(a_i \geq b_i\)构成集合为 \(P\),\(a_i < b_i\)构成集合为\(Q\)。若一个操作\(x\)\(y\)都在\(P\)内显然不可能使期望变大,若都在\(Q\)内一定最优,能使答案变大。

考虑\(x\)\(y\)\(P\),\(Q\)中都有的是否选取,可以设\(dp_{i,mask}\)表示其中一个集合选取前\(i\)个中的若干元素和另一个中选取方案为\(mask\)的使答案改变的最大值,显然复杂度可以\(O(n2^S)\)

然而\(|P|+|Q|=n\),因此选择二者最小的一个集合状压即可,总复杂度\(O(m+n2^{\frac{n}{2}})\)

点击查看代码
#include<bits/stdc++.h>

using namespace std;

template <class T>
void read(T &x){
    x=0;char c=getchar();bool f=0;
    while(!isdigit(c)) f=c=='-',c=getchar();
    while(isdigit(c)) x=x*10+c-'0',c=getchar();
    x=f? (-x):x;
}

const int MAXN=45;
const int MAXM=5e5+5;
const int MAXS=1050000;

int n,m;

int a[MAXN],b[MAXN];
int x[MAXM],y[MAXM];
int bel[MAXN];

vector <int> G[MAXN];

bool vis[MAXN];
bool linked[MAXN][MAXN];
double ans;
double dp[MAXN][MAXS];

void add(int u,int v){
    if (linked[u][v]) return;
    linked[u][v]=1;
    G[u].push_back(v);
}

int a0[MAXN],a1[MAXN],id0[MAXN],id1[MAXN],cnt0,cnt1;

int main(){
    read(n);read(m);
    for (int i=1;i<=n;i++){
        read(a[i]);read(b[i]);
        if (a[i]>=b[i]) bel[i]=0;
        else bel[i]=1;
    }
    for (int i=1;i<=m;i++){
        read(x[i]);read(y[i]);
        if (x[i]==y[i]){
            if (b[x[i]]>a[x[i]]) swap(b[x[i]],a[x[i]]);
            bel[x[i]]=0;
        }
    }
    for (int i=1;i<=m;i++){
        int _x=x[i],_y=y[i];
        if (_x==_y) continue;
        if (bel[_x]&&bel[_y]){
            vis[_x]=vis[_y]=1;
        }
        else if (bel[_x]||bel[_y]){
            add(_x,_y);add(_y,_x);
        }
    }
    for (int i=1;i<=n;i++){
        if (bel[i]==0){
            a0[++cnt0]=i;
            id0[i]=cnt0;                
        }
        else{
            if (!vis[i]){
                a1[++cnt1]=i;
                id1[i]=cnt1;
            }
        }
        if (vis[i]) ans+=((double)a[i]+b[i])*1.0/2.0;
        else ans+=a[i];
    }
    for (int i=0;i<=n;i++){
        for (int mask=0;mask<MAXS;mask++) dp[i][mask]=-0x3f3f3f3f;
    }
    dp[0][0]=0;
    if (cnt0<=cnt1){
        for (int i=1;i<=n;i++){
            if (bel[i]&&!vis[i]){
                for (int mask=0;mask<(1<<cnt0);mask++){
                    if (dp[i-1][mask]==-0x3f3f3f3f) continue;
                    for(const auto &j:G[i]){
                        int v=id0[j];
                        if (!v) continue;
                        dp[i][mask|(1<<(v-1))]=max(dp[i][mask|(1<<(v-1))],dp[i-1][mask]+((double)b[i]-a[i])*1.0/2.0);
                    }
                }                
            }
            for (int mask=0;mask<(1<<cnt0);mask++) dp[i][mask]=max(dp[i][mask],dp[i-1][mask]);
        }
        double res=0;
        for (int mask=0;mask<(1<<cnt0);mask++){
            double cur=dp[n][mask];
            for (int j=1;j<=cnt0;j++){
                if (mask&(1<<(j-1))){
                    int x=a0[j];
                    cur+=((double)b[x]-a[x])*1.0/2.0;
                }
            }
            res=max(res,cur);
        }
        printf("%.10lf\n",ans+res);
    }
    else{
        for (int i=1;i<=n;i++){
            if (!bel[i]){
                for (int mask=0;mask<(1<<cnt1);mask++){
                    if (dp[i-1][mask]==-0x3f3f3f3f) continue;
                    int tmp=mask;
                    for(const auto &j:G[i]){
                        int v=id1[j];
                        if (!v) continue;   
                        tmp|=(1<<(v-1));
                    }
                    dp[i][tmp]=max(dp[i][tmp],dp[i-1][mask]+((double)b[i]-a[i])*1.0/2.0);
                }
            }
            for (int mask=0;mask<(1<<cnt1);mask++) dp[i][mask]=max(dp[i][mask],dp[i-1][mask]);
        }
        double res=0;
        for (int mask=0;mask<(1<<cnt1);mask++){
            double cur=dp[n][mask];
            for (int j=1;j<=cnt1;j++){
                if (mask&(1<<(j-1))){
                    int x=a1[j];
                    cur+=((double)b[x]-a[x])*1.0/2.0;
                }
            }
            res=max(res,cur);
        }
        printf("%.10lf\n",ans+res);
    }
}