# CF890 E PermuTree 题解 [背包 dp][二进制拆分]
# 题目大意
给出一个以 为根节点大小为 的树。
给这个数确定一个长度为 的排列,分配给各节点作为权值。对于一个点对,若满足,则产生 个贡献。
求总贡献的最大值。
# 题解
首先对一个排列计算贡献的过程,实际上就是对于每一个点,从它的一个子树中选一个点 满足,再从另一个子树中选择一个点 满足。很容易看出来,如果权值比 x 小的节点独立分布在某些子树内,并且权值比 x 大的节点独立分布在其余的子树内,这时候的总贡献更大,为两部分节点个数和的乘积。即对于每对 满足,都能做贡献。
考虑每个子树,设根节点为,我们假设有 个点的权值比 要小,剩下的 个点权值比 要大。要将所有子树进行划分,使得 与 的差最小,这时总贡献最大。这是一个经典 01背包
问题。
直接应用 01背包dp
,只能做到 复杂度。
注意到这是一个树上问题。我们背包物品的体积是子树的节点大小,而树的节点总和是 的,意味着子树大小的种类是 的。也就是说最多只有 种物品,但一种物品可能出现很多个,其实是一个多重背包。应用二进制拆分优化,将物品数优化到 级别。
另外,做背包时重儿子不需加入转移,因为我们只考虑划分两部分中的其中一部分,这部分不包含重儿子,另一部分(包含重儿子)自然是存在的。我们只关注两部分之差最小,兑换顺序是等效的。这样也保证了树上背包的复杂度。
另外吐槽官方题解中使用 bitset
优化的做法。实际上作用并不明显,反而有点慢。
# 代码
#include<cstdio> | |
#include<vector> | |
#include<algorithm> | |
using std::vector; | |
using std::max; | |
int read(){ | |
int out(0),c(getchar()); | |
for(;c<'0' || c>'9';c=getchar()); | |
for(;c<='9' && c>='0';c=getchar()) | |
out=(out<<3)+(out<<1)+(c^48); | |
return out; | |
} | |
const int MAXN=1e6+10; | |
int N; | |
vector<int> to[MAXN]; | |
int sz[MAXN],son[MAXN]; | |
int f[MAXN]; | |
int cnt[MAXN]; | |
long long ans; | |
vector<int> a; | |
void dfs(int x){ | |
sz[x]=1; | |
for(int tt:to[x]){ | |
dfs(tt); | |
sz[x]+=sz[tt]; | |
if(sz[tt]>sz[son[x]]) | |
son[x]=tt; | |
} | |
int ma=0; | |
for(int tt:to[x]){ | |
if(tt!=son[x]){ | |
++cnt[sz[tt]]; | |
ma=max(ma,sz[tt]); | |
} | |
} | |
int msz=sz[x]-1-sz[son[x]]; | |
a.clear(); | |
for(int i=1;i<=ma;++i){ | |
if(cnt[i]){ | |
int t=1; | |
while(t<=cnt[i]){ | |
a.push_back(t*i); | |
cnt[i]-=t; | |
t<<=1; | |
} | |
if(cnt[i]){ | |
a.push_back(cnt[i]*i); | |
cnt[i]=0; | |
} | |
} | |
} | |
f[0]=1; | |
for(int i=1;i<=msz;++i) | |
f[i]=0; | |
std::sort(a.begin(),a.end()); | |
int s=0; | |
for(int v:a){ | |
for(int i=s+v;i>=v;--i) | |
if(f[i-v]) | |
f[i]=1; | |
// f[i]|=f[i-v]; | |
s+=v; | |
} | |
long long res=0; | |
for(int i=0;i<=s;++i) | |
if(f[i]) | |
res=max(res,(long long)i*(sz[x]-1-i)); | |
ans+=res; | |
} | |
int main(){ | |
// freopen("E.in","r",stdin); | |
// freopen("E.out","w",stdout); | |
N=read(); | |
for(int i=2;i<=N;++i) | |
to[read()].push_back(i); | |
dfs(1); | |
printf("%lld\n",ans); | |
} |