题解:[SCOI2008] 城堡

发布时间 2023-11-01 22:03:19作者: LgxTpre

应该是联赛前最后一次任性了,浪费的时间有点多,不过也揭露了我的基础知识和代码能力都很弱的问题,得加油啊。

先 sto dwt。

给定一棵基环树森林,起初有 \(m\) 个点已被选进 \(S\) 里,你需要再选 \(k\) 个点加入到 \(S\) 中,最小化其余点到 \(S\) 距离的最大值。

这个问题直接做非常困难,考虑二分答案转成判定性问题:选定若干个点能够使得所有点到 \(S\) 的距离不超过目前二分的答案 \(F\)。如果一个节点距离 \(S\) 中的点不超过 \(F\),称其为被覆盖了

基环树的题自然是树和环分开考虑,依然要预处理记好在环上的点和环上点之间的距离。如果放在树上,原题是这个,在基环树的树部分依然采用相同的 DP 方式,这个 DP 也是基于贪心的:能不放就不放。设 \(f_{u}\) 表示以 \(u\) 为根的子树中离 \(u\) 最远的没被覆盖的节点到 \(u\) 的距离;\(g_{u}\)\(u\) 为根的子树中在 \(S\) 中的离 \(u\) 最近的节点到 \(u\) 的距离,那么初值为 \(f_{u} = 0,g_{v} = \infty\),转移方程为 \(f_{u} = \max_{v \in son_{u}} \{ f_{v} \} + 1,g_{u} = \min_{v \in son_{u}} \{ g_{v} \} + 1\)。接下来考虑一些边界情况:

  • 如果当前点 \(u\) 本身就在 \(S\) 中,有 \(g_{u} = 0\),原因显然。
  • 如果 \(f_{u} + g_{u} \leq F\),这说明用 \(u\) 子树中的 \(S\) 中的点就可以覆盖整棵 \(u\) 的子树,\(u\)\(u\) 的祖先们没有决策影响了,那么 \(f_{u} = -\infty\)
  • 如果 \(f_{u} + (u,fa_{u}) > F\),如果 \(u\) 点不选入到 \(S\) 中,那么距离 \(u\) 最远的那个未覆盖节点就要不合法了,所以必须将 \(u\) 选入到 \(S\) 中,那么 \(f_{u} = -\infty,g_{u} = 0\),同时选点数量计数器要加一。

对于环上每个点为根做一遍这个树形 DP,可以得到树中最少需要选多少个点,以及根据根节点 \(i\)\(f_{i}\) 值和环上边长来知道必须要在环上的某一段区间内选点进入 \(S\) 才能使得 \(i\) 子树内的和 \(i\) 距离最大的未被覆盖的节点合法。特别的,如果环上存在一个节点 \(j\),使得 \(f_{i} + g_{j} + dis_{i \to j} \leq F\),即用 \(j\) 子树中的离 \(j\) 最近的 \(S\) 中的点来覆盖 \(i\) 中的点,那么对于 \(i\) 就不会存在环上的这段选择区间。这一部分的形态形如寻找前后缀 \(a \pm b\) 的最小值,可以用前缀 \(\min\) 线性解决。找环上的需要被覆盖的区间可以先算出最深的不合法节点到根节点的距离,然后推出在环上的延伸距离,由于我们预先处理好了环上边权的前缀和,就可以在前缀和上做一个二分查找来找到区间的左右端点。

解决上述步骤之后,现在问题转化成了在环上有若干个区间,每个区间至少选一个点,问最少要选多少个点。由于起初的 \(m\) 个点有一些可能在环上,先把这些点能满足的区间去掉。然后可以运用这个题的套路,加入当前选择了一个位置 \(i\),能够发现下一个选择的节点一定是左端点大于这个节点,且右端点最靠左的那个区间的右端点,可以用单指针线性求出从位置 \(i\) 开始下一个选择的位置在哪里,然后预处理倍增数组 \(st_{i,j}\) 表示选择了位置 \(i\),下 \(2^j\) 个位置选在哪里,枚举哪个节点作为第一个选择的节点,倍增的算接下来总共在环上要选几个点,于是这样就 \(\mathcal O(n \log n)\) 的解决这个部分了。

让我们梳理一下整个算法流程:首先预处理环上节点,断环成链并记录环上边权前缀和;然后开始二分答案,对于一棵基环树,先树形 DP 算出树部分最少需要放置的点以及能够覆盖到的环上的区间,然后倍增搭配贪心完成环上的选取;用所有基环树需要选取的点的数量总和来判断是否满足当前二分出的答案。因为要特殊处理二元环和自环的环上边权,我的代码中选择了使用 map 来解决这个问题。

综上,总复杂度我们做到了 \(\mathcal O(n \log n \log V)\)。但是这个东西实际跑起来非常缓慢,瓶颈主要在于倍增数组的初始化,更改倍增数组的类型可以起到很好的优化效果,同时这也说明了卡满这个做法的方法是一整个圈。还有一些卡常技巧也可以相应的使用,我的代码常数并不优秀,但 \(10^5\) 的数据还是跑到了一秒左右。

#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
#define ui unsigned int
#define ld long double
#define power(x) ((x)*(x))
#define eb emplace_back
#define pb pop_back
#define mp make_pair
#define fi first
#define se second
#define TT template<typename T>
#define TA template<typename T,typename ...Args>
#define dbg(x) cerr<<"In Line "<< __LINE__<<" the "<<#x<<" = "<<x<<'\n'
using namespace std;
using pii=pair<int,int>;
using pdi=pair<double,int>;
using piii=pair<pair<int,int>,int>;

bool Mbe;

namespace IO
{
    inline int read()
    {
        int s=0,w=0; char c=getchar();
        while(!isdigit(c)) w|=c=='-',c=getchar();
        while(isdigit(c)) s=(s*10)+(c^48),c=getchar();
        return s*(w?-1:1);
    }
    TT inline void read(T &s)
    {
        s=0; int w=0; char c=getchar();
        while(!isdigit(c)) w|=c=='-',c=getchar();
        while(isdigit(c)) s=(s*10)+(c^48),c=getchar();
        s*=w?-1:1;
    }
    TA inline void read(T &x,Args &...args) {read(x),read(args...);}
    TT inline void write(T x,char ch=' ')
    {
        if(x<0) x=-x,putchar('-');
        static char stk[30]; int top=0;
        do stk[top++]=x%10+'0',x/=10; while(x);
        while(top) putchar(stk[--top]);
        if(ch!='~') putchar(ch);
    }
}
using namespace IO;

namespace MTool
{
    TT inline void Swp(T &a,T &b) {static T t=a;a=b;b=t;}
    TT inline void cmax(T &a,T b) {a=max(a,b);}
    TA inline void cmax(T &a,T b,Args ...args) {a=max({a,b,args...});}
    TT inline void cmin(T &a,T b) {a=min(a,b);}
    TA inline void cmin(T &a,T b,Args ...args) {a=min({a,b,args...});}
}
using namespace MTool;

inline void file()
{
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
}

namespace LgxTpre
{
    constexpr int MAX=200010;
    constexpr int inf=2147483647;
    constexpr int INF=4557430888798830399;

	int n,m,k,w,v[MAX];
	vector<pii> G[MAX];
	unordered_map<int,int> p;
	int tag[MAX],dfn[MAX],ilp[MAX],tmp[MAX],lt[MAX],fa[MAX],tot;
	int f[MAX],g[MAX],all[MAX],mx[MAX];
	signed st[MAX][21]; pii A[MAX];
	vector<int> lop[MAX],dis[MAX],sum[MAX]; int cnt[MAX],num;
	int l,r,mid,ans;
	#define ps(i,j) (min(i,j)*100000+max(i,j))

    inline void lmy_forever()
    {
		read(n,m,k),l=0,r=1e10;
		for(int i=1;i<=n;++i) read(v[i]),++v[i];
		for(int i=1;i<=n;++i)
		{
			read(w),G[i].eb(mp(v[i],w));
			if(p.find(ps(i,v[i]))==p.end()) p[ps(i,v[i])]=w; else cmin(p[ps(i,v[i])],w);
			if(i!=v[i]) G[v[i]].eb(mp(i,w));
		}
		for(int i=1;i<=m;++i) tag[read()+1]=1;
		function<void(int,int)> FindLoop=[&](int now,int father)
		{
			dfn[now]=++tot,fa[now]=father;
			for(auto [to,_]:G[now]) if(to!=father)
			{
				if(dfn[to])
				{
					if(dfn[to]<dfn[now]) continue;
					tmp[++cnt[num]]=to,ilp[to]=1;
					for(;to!=now;to=fa[to]) tmp[++cnt[num]]=fa[to],ilp[fa[to]]=1;
					lop[num].resize(cnt[num]*2+1); for(int i=1;i<=cnt[num];++i) lop[num][i]=lop[num][i+cnt[num]]=tmp[i];
				}
				else FindLoop(to,now);
			}
		};
		for(int i=1;i<=n;++i) if(!dfn[i])
		{
			++num,FindLoop(i,0),dis[num].resize(cnt[num]*2+1),sum[num].resize(cnt[num]*2+1);
			for(int i=1;i<=cnt[num];++i) if(tag[lop[num][i]]) {lt[num]=i; break;}
			for(int i=1;i<=cnt[num]-1;++i) dis[num][i]=dis[num][i+cnt[num]]=p[ps(lop[num][i],lop[num][i+1])];
			dis[num][cnt[num]]=p[ps(lop[num][1],lop[num][cnt[num]])];
			for(int i=1;i<=cnt[num]*2-1;++i) sum[num][i+1]=sum[num][i]+dis[num][i];
			for(int i=1;i<=cnt[num];++i) all[num]+=dis[num][i];
		}
		auto check=[&](int F)->bool
		{
			int res=0;
			function<void(int,int)> dfs=[&](int now,int father)
			{
				f[now]=0,g[now]=INF;
				for(auto [to,val]:G[now]) if(!ilp[to]&&to!=father) dfs(to,now),cmax(f[now],f[to]+val),cmin(g[now],g[to]+val);
				if(tag[now]) g[now]=0;
				if(f[now]+g[now]<=F) f[now]=-INF;
				if(f[now]+p[ps(now,father)]>F) f[now]=-INF,g[now]=0,++res;
			};
			for(int id=1;id<=num;++id)
			{
				int FLG=0,use=0,N=cnt[id]*2;
				auto findL=[&](int p,int k)->int
				{
					int l=1,r=p,cur=1;
					while(l<=r) {int mid=(l+r)>>1; if(sum[id][p]<=sum[id][mid]+k) r=mid-1,cur=mid; else l=mid+1;}
					return cur;
				};
				auto findR=[&](int p,int k)->int
				{
					int l=p,r=N,cur=N;
					while(l<=r) {int mid=(l+r)>>1; if(sum[id][mid]<=sum[id][p]+k) l=mid+1,cur=mid; else r=mid-1;}
					return cur;
				};
				for(int i=1;i<=cnt[id];++i) dfs(lop[id][i],0),mx[i]=INF;
				for(int i=1,mix=INF;i<=cnt[id]*2;++i) cmin(mix,g[lop[id][i]]-sum[id][i]),cmin(mx[i>cnt[id]?i-cnt[id]:i],sum[id][i]+mix);
				for(int i=cnt[id]*2,mix=INF;i>=1;--i) cmin(mix,g[lop[id][i]]+sum[id][i]),cmin(mx[i>cnt[id]?i-cnt[id]:i],mix-sum[id][i]);
				for(int i=1;i<=cnt[id];++i) if(mx[i]+f[lop[id][i]]>F)
				{
					int Len=F-f[lop[id][i]];
					if(Len*2>=all[id]) {FLG|=lt[id]==0; continue;}
					if(sum[id][i]+dis[id][cnt[id]]>Len)
					{
						
						A[++use]=mp(findL(i,Len),findR(i,Len));
						if(A[use].se>cnt[id]) continue;
						++use,A[use]=mp(A[use-1].fi+cnt[id],A[use-1].se+cnt[id]);
					}
					else A[++use]=mp(findL(i+cnt[id],Len),findR(i+cnt[id],Len));
				}
				if(!use) {res+=FLG; continue;}
				int T=__lg(N),con=INF,mn=INF;
				for(int i=1;i<=N;++i) for(int j=0;j<=T;++j) st[i][j]=0;
				for(int i=1;i<=N;++i) mx[i]=INF;
				for(int i=1;i<=use;++i) cmin(mx[A[i].fi],A[i].se);
				for(int i=N;i;--i) st[i][0]=mn==INF?0:mn,cmin(mn,mx[i]);
				for(int j=1;j<=T;++j) for(int i=1;i+(1<<j)-1<=N;++i) st[i][j]=st[st[i][j-1]][j-1];
				if(lt[id])
				{
					int now=lt[id]; con=0;
					for(int i=T;~i;--i) if(st[now][i]&&st[now][i]<lt[id]+cnt[id]) con+=1<<i,now=st[now][i];
				}
				else
				{
					for(int j=1;j<=cnt[id];++j)
					{
						int cot=1,now=j;
						for(int i=T;~i;--i) if(st[now][i]&&st[now][i]<j+cnt[id]) cot+=1<<i,now=st[now][i];
						cmin(con,cot);
					}
				}
				for(int i=1;i<=use;++i) A[i]=mp(0,0);
				res+=max(FLG,con);
			}
			return res<=k;
		};
		while(l<=r) {mid=(l+r)>>1; if(check(mid)) r=mid-1,ans=mid; else l=mid+1;}
		write(ans);
    }
}

bool Med;

signed main()
{
//  file();
    int Tbe=clock();
    LgxTpre::lmy_forever();
    int Ted=clock();
    fprintf(stderr,"\nMemory: %.3lf MB",abs(&Mbe-&Med)/1024.0/1024.0);
    fprintf(stderr,"\nTime: %.3lf ms",1e3*(Ted-Tbe)/CLOCKS_PER_SEC);
    return (0-0);
}