Just a Blog

增强复用型憨憨线段树合并

线段树合并

前置芝士:动态开点线段树

线段树节省空间的一种方法,就是不一开始就建好整棵树,而是在访问到这个节点的同时再进行更新。

普通的线段树要在最前面加入一个build函数

1
2
3
4
5
6
void build(int rt, int l, int r){
if(l == r){sum[rt] = a[l]; return;}
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(rt);
}

但是在动态开点线段树中却被newnode函数取代

1
2
3
4
5
void newnode(int &rt, int l, int r){
if(l > r)return;
rt = ++cnt;
sum[rt] = r - l + 1;
}

并且此时由于建点没有必然顺序,所以左右儿子用ls[rt], rs[rt]表示。

然后以单点修改为例

1
2
3
4
5
6
7
8
9
10
11
12
13
int change(int rt, int l, int r, int pos, int k){
if(!rt)rt = newnode(rt, l, r);
if(!ls[rt])ls[rt] = newnode(ls[rt], l, mid);
if(!rs[rt])rs[rt] = newnode(rs[rt], mid + 1, r);
if(l == r){
sum[rt] = k;
return;
}
push_down(rt, l, r);
if(pos <= mid)ls[rt] = change(ls[rt], l, mid, pos, k);
else rs[rt] = change(rs[rt], mid + 1, r, pos, k);
return rt;
}

就是这样。

线段树合并方式

你现在要维护一个子树中的权值最大值。

有两棵动态开点开出的线段树,你要把他们摁到一棵线段树里。

你直接递归进入两个根节点的左右儿子,重复此操作直到成为叶节点。

因为线段树是动态建的,所以在没有该节点时可以直接返回另一个节点的值

如果递归到 $ l == r $,那么就可以直接权值相加了。

代码:

1
2
3
4
5
6
7
8
9
int merge(int a, int b, int l, int r){
if(!a)return b; if(!b)return a;
if(l == r){val[a] += val[b];pos[a] = l; return a;}
int mid = (l + r) >> 1;
ls[a] = merge(ls[a], ls[b], l, mid);
rs[a] = merge(rs[a], rs[b], mid + 1, r);
push_up(a);
return a;
}

值得注意的一点是,这样直接覆盖节点,由于b树的信息被保留在a树上,在更新时可能会导致错误。
因此只适用于离线做法

但是如果在线段树合并的时候参考主席树的做法新建节点,就不会有问题了,缺点是非常炸空间…

线段树合并解决的问题

你有一棵树,你要解决一种子树问题,(比如子树中出现次数最多的权值),对于每一个点都可以非常方便的写一棵线段树维护,这时候使用线段树合并进行统计就非常方便了。

一道例题:

luoguP4556 [Vani有约会]雨天的尾巴

题目传送门

针对每一个节点建立权值线段树存救济粮种类,差分后转化成单点修改,最后使用线段树合并统计答案。

由于本题只有128MB,所以采用覆盖节点的离线做法。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <bits/stdc++.h>
using namespace std;
#define fr(i, j, k) for(register int i = j; i <= k; ++i)
#define rf(i, j, k) for(register int i = j; i >= k; --i)
#define fre(i, now) for(int i = h[now]; i; i = nt[i])
#define gc ch = getchar()
int read(){
int ret = 0, f = 1; char gc;
while(!isdigit(ch) && (ch ^ '-'))gc;
if(!(ch ^ '-'))f = -1, gc;
while(isdigit(ch)){ret = (ret << 3) + (ret << 1) + (ch ^ '0'), gc;}
return ret * f;
}
#undef gc
const int maxn = 100010;
const int maxt = 6000010;
int nt[maxn << 1], to[maxn << 1];
int h[maxn], tot = 1;
void adde(int a, int b){
nt[++tot] = h[a], to[tot] = b, h[a] = tot;
nt[++tot] = h[b], to[tot] = a, h[b] = tot;
}

int siz[maxn], son[maxn], dep[maxn], fa[maxn], top[maxn];
void dfs1(int now, int ff){
fa[now] = ff, siz[now] = 1, son[now] = 0, dep[now] = dep[ff] + 1;
fre(i, now){
int nex = to[i];
if(nex == ff)continue;
dfs1(nex, now);
siz[now] += siz[nex];
if(!son[now] || siz[nex] > siz[son[now]])son[now] = nex;
}
}
void dfs2(int now, int t){
top[now] = t;
if(!son[now])return;
dfs2(son[now], t);
fre(i, now){
int nex = to[i];
if(nex == son[now] || nex == fa[now])continue;
dfs2(nex, nex);
}
}
inline int lca(int x, int y){
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]])swap(x, y);
x = fa[top[x]];
}
if(dep[x] > dep[y])return y;
return x;
}
int rt[maxn], l[maxt], r[maxt], val[maxt], pos[maxt];
int ans[maxn], x[maxn], y[maxn], z[maxn];
int cnt = 0, R, n, m;
void push_up(int a){
if(val[l[a]] >= val[r[a]])val[a] = val[l[a]], pos[a] = pos[l[a]];
else val[a] = val[r[a]], pos[a] = pos[r[a]];
}
int change(int &a, int x, int y, int p, int v){
if(!a) a = ++cnt;
if(x == y){val[a] += v, pos[a] = x; return a;}
int mid = (x + y) >> 1;
if(p <= mid){l[a] = change(l[a], x, mid, p, v);}
else r[a] = change(r[a], mid + 1, y, p, v);
push_up(a);
return a;
}
int merge(int a, int b, int x, int y){
if(!a)return b; if(!b)return a;
if(x == y){val[a] += val[b];pos[a] = x; return a;}
int mid = (x + y) >> 1;
l[a] = merge(l[a], l[b], x, mid);
r[a] = merge(r[a], r[b], mid + 1, y);
push_up(a);
return a;
}
void get_ans(int now){
fre(i, now){
int nex = to[i];
if(nex == fa[now])continue;
get_ans(nex);
rt[now] = merge(rt[now], rt[nex], 1, R);
}
if(val[rt[now]])ans[now] = pos[rt[now]];
}
int main(){
n = read(), m = read();
int a, b;
fr(i, 1, n - 1){
a = read(), b = read();
adde(a, b);
}
dfs1(1, 0);
dfs2(1, 1);
fr(i, 1, m){
x[i] = read(), y[i] = read(), z[i] = read();
R = max(R, z[i]);
}
fr(i, 1, m){
int L = lca(x[i], y[i]);
rt[x[i]] = change(rt[x[i]], 1, R, z[i], 1);
rt[y[i]] = change(rt[y[i]], 1, R, z[i], 1);
rt[L] = change(rt[L], 1, R, z[i], -1);
if(fa[L])rt[fa[L]] = change(rt[fa[L]], 1, R, z[i], -1);
}
get_ans(1);
fr(i, 1, n){
cout<<ans[i]<<endl;
}
return 0;
}

 评论