Just a Blog

dsu on tree

这是什么?

dsu on tree 指的是“树上启发式合并”。(和并查集没有关系

解决的是一类不带修改的子树统计问题。

引例

一棵有根树,每一个节点上都有一个颜色。

求每一个节点i的子树中出现颜色为k的节点。

怎么做?

可以很简单的得到一个( $n^2$ )的做法,从每个点开始dfs一遍就完了。

当然你也可以dfs序+主席树水过,但是此做法不在本博客讨论范围内。

考虑一下暴力的过程。

我们对于每个节点,暴力地通过dfs统计出其子树中的答案,再暴力地删除贡献,计算下一个节点。

这个过程中很明显进行了一堆重复的操作。

如何优化呢?

有一个很妙的想法(即dsu on tree),参照了重链剖分的思路。

那就是求出每个点的重儿子,然后先递归进入每一个点的轻儿子统计答案,随后消除其影响。

再针对重儿子统计答案,但是不消除其影响。

最后再得出该点的贡献。

算法流程

  • 剖出重儿子与轻儿子

  • 递归进入轻儿子统计答案并消除影响

  • 进入重儿子统计答案,不消除影响

  • 统计该点贡献

  • 如果该点是轻儿子,那么消除其影响

代码长这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void dfs(int now, int opt){//①
fre(i, now){
if(nex == fa[now] || nex == son[now])continue;
dfs(nex, 0);
}
if(son[now])dfs(son[now], 1);
pos = son[now];
calc(son[now], 1);//②
pos = 0;
ans[now] = sum;
if(!opt){
clac(now, -1);//③
sum = maxx = 0;
}
}

①:opt == 1 表示不清除当前点的影响, opt == 0 表示清除影响

②:表示计算重儿子贡献并不清除dfs

③:表示去除轻儿子贡献

复杂度计算

例题

CF666E Lomsat gelral

题目传送门

一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。

其实就是例题,直接dfs暴力统计就可以了。
代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>
using namespace std;
#define fr(i, j, k) for(register int i = j; i <= k; ++i)
#define rf(i, j, k) for(register int i = j; i >= k; --i)
#define fre(i, now) for(register int i = h[now], nex = to[i]; i; i = nt[i], nex = to[i])
#define gc ch = getchar()
int read(){
int ret = 0, f = 1; char gc;
while(!isdigit(ch) && (ch ^ '-'))gc;
if(!(ch ^ '-'))f = -1, gc;
while(isdigit(ch)){ret = (ret << 3) + (ret << 1) + (ch ^ '0'), gc;}
return ret * f;
}
#undef gc
const int maxn = 1e5 + 10;
int n;
int to[maxn << 1], nt[maxn << 1];
int h[maxn], tot = 1;
void adde(int a, int b){
to[++tot] = b, nt[tot] = h[a], h[a] = tot;
to[++tot] = a, nt[tot] = h[b], h[b] = tot;
}
int son[maxn], fa[maxn], siz[maxn];
void dfs1(int now, int f){
// cout<<now<< " "<<f<<endl;
fa[now] = f;
siz[now] = 1;
fre(i, now){
// cout<<"hhhh"<<to[i]<<" "<<nex<<endl;
if(nex == f)continue;
// cout<<"!!"<<f<<" "<<nex<< " "<<now<<endl;
dfs1(nex, now);
siz[now] += siz[nex];
if(!son[now] || (siz[son[now]] < siz[nex]))son[now] = nex;
}
return;
}
int cnt[maxn], cor[maxn];
int now_son, maxx;
long long sum;
long long ans[maxn];
void add(int now, int val){
cnt[cor[now]] += val;
if(cnt[cor[now]] > maxx){maxx = cnt[cor[now]];sum = cor[now];}
else if(cnt[cor[now]] == maxx){sum += (long long)cor[now];}
fre(i, now){
if(nex == fa[now] || nex == now_son)continue;
add(nex, val);
}
}
void dfs2(int now, int opt){
fre(i, now){
if(nex == fa[now])continue;
if(nex != son[now])dfs2(nex, 0);
}
if(son[now]) dfs2(son[now], 1), now_son = son[now];
add(now, 1);
now_son = 0;
ans[now] = sum;
if(!opt){
add(now, -1);
sum = 0, maxx = 0;
}
}
int main(){
n = read();
int x, y;
fr(i, 1, n){
cor[i] = read();
}
fr(i, 1, n - 1){
x = read(), y = read();
adde(x, y);
}
// fr(i, 1, n){
// fre(j, i){
// cout<<i<< " "<<nt[j]<< " "<<to[j]<< " "<<endl;
// }
// cout<<endl;
// }
dfs1(1, 0);
// fr(i, 1, n){
// cout<<siz[i]<< " "<<son[i]<<endl;
// }
dfs2(1, 0);
fr(i, 1, n){
printf("%I64d ", ans[i]);
}
return 0;
}

CF741D(题目太长不贴了)

题目传送门

一棵根为1 的树,每条边上有一个字符(a-v共22种)。

一条简单路径被称为Dokhtar-kosh当且仅当路径上的字符经过重新排序后可以变成一个回文串。

求每个子树中最长的Dokhtar-kosh路径的长度。

字符范围只有[0, 22],考虑状压。

进一步分析,当且仅当所有字符都出现偶数次或者只有一个字符出现奇数次时才合法。

也就是说状压之后为2的幂次的路径才是合法的。

我们的目标就是对于每个点x,在其子树中找到点a,b。

使得d[a] ^ d[b] 为2的幂次且dep[a] + dep[b] - dep[lca] * 2 最大。

(d[i]指的是状压之后从1到i的路径)

使用dsu on tree 解决即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
using namespace std;
#define fr(i, j, k) for(register int i = j; i <= k; ++i)
#define rf(i, j, k) for(register int i = j; i >= k; --i)
#define fre(i, now) for(register int i = h[now]; i; i = nt[i])
#define gc ch = getchar()
int read(){
int ret = 0, f = 1; char gc;
while(!isdigit(ch) && (ch ^ '-'))gc;
if(ch == '-') f = -1, gc;
while(isdigit(ch)){ret = (ret << 3) + (ret << 1) + (ch ^ '0'), gc;}
return ret * f;
}
#undef gc
const int maxn = 5e5 + 10;
int n;
int to[maxn << 1], nt[maxn << 1], val[maxn << 1];
int h[maxn], tot = 1;
void adde(int a, int b, int c){
to[++tot] = b, nt[tot] = h[a], val[tot] = c, h[a] = tot;
to[++tot] = a, nt[tot] = h[b], val[tot] = c, h[b] = tot;
}
int d[maxn], fa[maxn], siz[maxn], son[maxn], dep[maxn];
void dfs1(int now, int ff){
fa[now] = ff, siz[now] = 1;
// cout<<now<< " "<<ff<<endl;
fre(i, now){
int nex = to[i];
if(nex == ff)continue;
d[nex] = d[now] ^ (1 << val[i]);
dep[nex] = dep[now] + 1;
dfs1(nex, now);
siz[now] += siz[nex];
if(!son[now] || siz[son[now]] < siz[nex])son[now] = nex;
}
}
int ccf[1 << 22], maxx, last, ans[maxn];
void update(int u){
ccf[d[u]] = max(dep[u], ccf[d[u]]);
}
void cal(int now){
if(ccf[d[now]]){maxx = max(maxx, dep[now] + ccf[d[now]] - last);}
fr(i, 0, 21)if(ccf[(1 << i) ^ d[now]]){maxx = max(maxx, dep[now] + ccf[(1 << i) ^ d[now]] - last);}
}
void getup(int now){
update(now);
fre(i, now){
int nex = to[i];
if(nex == fa[now])continue;
getup(nex);
}
}
void getcal(int now){
cal(now);
fre(i, now){
int nex = to[i];
if(nex == fa[now])continue;
getcal(nex);
}
}
void clear(int now){
ccf[d[now]] = 0;
fre(i, now){
int nex = to[i];
if(nex == fa[now])continue;
clear(nex);
}
}
void dfs2(int now, int opt){
// cout<<now<< " "<<fa[now]<<endl;
fre(i, now){
int nex = to[i];
if(nex == fa[now] || nex == son[now])continue;
dfs2(nex, 0);
}
if(son[now])dfs2(son[now], 1);
last = dep[now] << 1;
fre(i, now){
int nex = to[i];
if(nex == fa[now])continue;
maxx = max(ans[nex], maxx);
}
fre(i, now){
int nex = to[i];
if(nex == fa[now] || nex == son[now])continue;
getcal(nex);
getup(nex);
}
cal(now); update(now);
ans[now] = maxx;
if(!opt){
maxx = 0;
clear(now);
}
}
int main(){
// freopen("test.in", "r", stdin);
// freopen("test.out", "w", stdout);
n = read();
int x;
char z[15];
fr(i, 2, n){
x = read();
cin>>z;
adde(i, x, (int)z[0] - 'a');
}
dfs1(1, 0);
dfs2(1, 1);
fr(i, 1, n){
printf("%d ", ans[i]);
}
return 0;
}

参考资料

https://www.cnblogs.com/zcysky/p/6822395.html
https://www.cnblogs.com/zwfymqz/p/9683124.html

感谢。

 评论