[ZJOI2010]基站选址

发布时间 2023-03-22 21:10:44作者: Diamondan

线段树优化\(dp\)的板题?
首先根据题意列出\(dp\)方程
\(dp[i][j]\)表示前\(i\)个村庄中选取第\(i\)个作为第\(j\)个基站的方案数
\(dp[i][j]=min(dp[k][j-1]+cost[k][i])\)
然后滚动掉\(j\)这一维就变成了
\(dp[i]=min(dp[k]+cost[k][i])\)
上面\(cost\)表示\([k,i]\)这一段区间需要的赔偿
这个\(cost\)我们用线段树维护
考虑区间会很大,我们先对村庄离散化
存储village \(i\)对应的可以覆盖到的左右端点村庄\(st[i]\),\(ed[i]\)
每次\(i\)有移动,那么\([1,st[i]−1]\)里面的\(dp\)值肯定要加上\(w_i\)
每次转移查询\([1,i-1]\)\(dp\)最小值就可以直接转移
维护区间之间连锁的影响,我们用邻接表来维护
\(CODE\)

#include<bits/stdc++.h>
#define int long long
#define ls (o<<1)
#define rs (o<<1|1)
using namespace std;

const int N=20010;
const int inf=2e9;

int read() {

    int x=0,f=1;
    char ch=getchar();
    while(ch<48||ch>57) {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>=48&&ch<=57) {
        x=(x<<3)+(x<<1)+(ch^48);
        ch=getchar();
    }
    return x*f;

}

int head[N],to[N*2],tot=0,nxt[N*2];
void add(int u,int v) {
	nxt[++tot]=head[u];
	to[tot]=v;
	head[u]=tot;
}

int d[N],s[N],w[N],c[N];
int sum[N*4],tag[N*4],f[N];

void push_up(int o) {
	sum[o]=min(sum[ls],sum[rs]);
}
void build(int o,int l,int r) {

	tag[o]=0;
	if(l==r) {
		sum[o]=f[l];
		return ;
	}
	int mid=(l+r)>>1;
	build(ls,l,mid);
	build(rs,mid+1,r);
	push_up(o);

}

void push_down(int o,int l,int r) {

	sum[ls]+=tag[o];
	sum[rs]+=tag[o];
	tag[ls]+=tag[o];
	tag[rs]+=tag[o];
	tag[o]=0;

}

void update(int o,int l,int r,int x,int y,int t) {

	if(x<=l&&r<=y) {
		tag[o]+=t;
		sum[o]+=t;
		return ;
	}
	int mid=(l+r)>>1;
	if(tag[o]) push_down(o,l,r);
	if(x<=mid) update(ls,l,mid,x,y,t);
	if(y>mid) update(rs,mid+1,r,x,y,t);
	push_up(o);

}

int query(int o,int l,int r,int x,int y) {

	int ans=inf;
	if(x<=l&&r<=y) {
		return sum[o];
	}
	int mid=(l+r)>>1;
	if(tag[o]) push_down(o,l,r);
	if(x<=mid) ans=min(ans,query(ls,l,mid,x,y));
	if(y>mid) ans=min(ans,query(rs,mid+1,r,x,y));
	return ans;

}

int n,k,ANS,st[N],ed[N],num=0;

signed main() {

	n=read();
	k=read();
	for(int i=2;i<=n;i++) d[i]=read();
	for(int i=1;i<=n;i++) c[i]=read();
	for(int i=1;i<=n;i++) s[i]=read();
	for(int i=1;i<=n;i++) w[i]=read();
 	n++;
 	k++;
 	d[n]=w[n]=inf;
 	s[n]=c[n]=0;

	for(int i=1;i<=n;i++) {

		st[i]=lower_bound(d+1,d+n+1,d[i]-s[i])-d;
		ed[i]=lower_bound(d+1,d+n+1,d[i]+s[i])-d;
		if(d[ed[i]]>d[i]+s[i]) ed[i]--;
		add(ed[i],i);

	}
	//k=1
	for(int i=1;i<=n;i++) {

		f[i]=num+c[i];
		for(int j=head[i];j;j=nxt[j]) {
			int v=to[j];
			num+=w[v];
		}

	}

	ANS=f[n];
	for(int j=2;j<=k;j++) {

		build(1,1,n);
		for(int i=1;i<=n;i++) {
			if(i>=2) {
				f[i]=query(1,1,n,1,i-1)+c[i];
			} else {
				f[i]=inf;
			}
			for(int u=head[i];u;u=nxt[u]) {
				int v=to[u];
				if(st[v]>1) update(1,1,n,1,st[v]-1,w[v]);
			}
		}
		ANS=min(ANS,f[n]);

	}
	printf("%lld\n",ANS);
 
    return 0;
}