# CF890 E PermuTree 题解 [背包 dp][二进制拆分]

# 题目大意

给出一个以11 为根节点大小为NN 的树。
给这个数确定一个长度为NN 的排列pp,分配给各节点作为权值。对于一个点对(u,v)(u,v),若满足au<alca(u,v)<ava_u\lt a_{lca(u,v)} \lt a_v,则产生11 个贡献。
求总贡献的最大值。

# 题解

首先对一个排列计算贡献的过程,实际上就是对于每一个点xx,从它的一个子树中选一个点uu 满足pu<pxp_u\lt p_x,再从另一个子树中选择一个点vv 满足pv>pxp_v\gt p_x。很容易看出来,如果权值比 x 小的节点独立分布在某些子树内,并且权值比 x 大的节点独立分布在其余的子树内,这时候的总贡献更大,为两部分节点个数和的乘积。即对于每对(u,v)(u,v) 满足pu<px<pvp_u\lt p_x \lt p_v,都能做贡献。
考虑每个子树,设根节点为xx,我们假设有s0s_0 个点的权值比pxp_x 要小,剩下的s1s_1 个点权值比pxp_x 要大。要将所有子树进行划分,使得s0s_0s1s_1 的差最小,这时总贡献最大。这是一个经典 01背包 问题。
直接应用 01背包dp ,只能做到O(N2)O(N^2) 复杂度。
注意到这是一个树上问题。我们背包物品的体积是子树的节点大小,而树的节点总和是O(N)O(N) 的,意味着子树大小的种类是O(N)O(\sqrt N) 的。也就是说最多只有O(N)O(\sqrt N) 种物品,但一种物品可能出现很多个,其实是一个多重背包。应用二进制拆分优化,将物品数优化到O(NlogN)O(\sqrt N \log N) 级别。
另外,做背包时重儿子不需加入转移,因为我们只考虑划分两部分中的其中一部分,这部分不包含重儿子,另一部分(包含重儿子)自然是存在的。我们只关注两部分之差最小,兑换顺序是等效的。这样也保证了树上背包的复杂度O(NNlogN)O(N\sqrt N\log N)

另外吐槽官方题解中使用 bitset 优化的做法。实际上作用并不明显,反而有点慢。
上:使用了bitset  下:bool数组

# 代码

#include<cstdio>
#include<vector>
#include<algorithm>
using std::vector;
using std::max;
int read(){
	int out(0),c(getchar());
	for(;c<'0' || c>'9';c=getchar());
	for(;c<='9' && c>='0';c=getchar())
		out=(out<<3)+(out<<1)+(c^48);
	return out;
}
const int MAXN=1e6+10;
int N;
vector<int> to[MAXN];
int sz[MAXN],son[MAXN];
int f[MAXN];
int cnt[MAXN];
long long ans;
vector<int> a;
void dfs(int x){
	sz[x]=1;
	for(int tt:to[x]){
		dfs(tt);
		sz[x]+=sz[tt];
		if(sz[tt]>sz[son[x]])
			son[x]=tt;
	}
	int ma=0;
	for(int tt:to[x]){
		if(tt!=son[x]){
			++cnt[sz[tt]];
			ma=max(ma,sz[tt]);
		}
	}
	int msz=sz[x]-1-sz[son[x]];
	a.clear();
	for(int i=1;i<=ma;++i){
		if(cnt[i]){
			int t=1;
			while(t<=cnt[i]){
				a.push_back(t*i);
				cnt[i]-=t;
				t<<=1;
			}
			if(cnt[i]){
				a.push_back(cnt[i]*i);
				cnt[i]=0;
			}
		}
	}
	f[0]=1;
	for(int i=1;i<=msz;++i)
		f[i]=0;
	std::sort(a.begin(),a.end());
	int s=0;
	for(int v:a){
		for(int i=s+v;i>=v;--i)
			if(f[i-v])
				f[i]=1;
			// f[i]|=f[i-v];
		s+=v;
	}
	long long res=0;
	for(int i=0;i<=s;++i)
		if(f[i])
			res=max(res,(long long)i*(sz[x]-1-i));
	ans+=res;
}
int main(){
	// freopen("E.in","r",stdin);
	// freopen("E.out","w",stdout);
	N=read();
	for(int i=2;i<=N;++i)
		to[read()].push_back(i);
	dfs(1);
	printf("%lld\n",ans);
}
更新于 阅读次数