树链剖分原理和实现

转自大佬: banananana
另一位大佬:ivanovcraft

树链剖分原理和实现

树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。

首先就是一些必须知道的概念:
重结点:子树结点数目最多的结点;
轻节点:父亲节点中除了重结点以外的结点;
重边:父亲结点和重结点连成的边;
轻边:父亲节点和轻节点连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
树链剖分

比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,2-11就是重链,2-5就是轻链,用红点标记的就是该结点所在链的起点,也就是我们👇提到的top结点,还有每条边的值其实是进行dfs时的执行序号。

算法中定义了以下的数组用来存储上边提到的概念:

1
2
3
4
5
6
7
8
名称	解释
siz[u] 保存以u为根的子树节点个数
top[u] 保存当前节点所在链的顶端节点
son[u] 保存重儿子
dep[u] 保存结点u的深度值
faz[u] 保存结点u的父亲节点
tid[u] 保存树中每个节点剖分以后的新编号(DFS的执行顺序)
rnk[u] 保存当前节点在树中的位置

除此之外,还包括两种性质:
如果(u, v)是一条轻边,那么size(v) < size(u)/2;
从根结点到任意结点的路所经过的轻重链的个数必定都小与O(logn);
首先定义以下数组:

1
2
3
4
5
6
7
8
9
const int MAXN = (100000 << 2) + 10;
//Heavy-light Decomposition STARTS FORM HERE
int siz[MAXN];//number of son
int top[MAXN];//top of the heavy link
int son[MAXN];//heavy son of the node
int dep[MAXN];//depth of the node
int faz[MAXN];//father of the node
int tid[MAXN];//ID -> DFSID
int rnk[MAXN];//DFSID -> ID

算法大致需要进行两次的DFS,第一次DFS可以得到当前节点的父亲结点(faz数组)、当前结点的深度值(dep数组)、当前结点的子结点数量(size数组)、当前结点的重结点(son数组)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void dfs1(int u, int father, int depth) {
/*
* u: 当前结点
* father: 父亲结点
* depth: 深度
*/
// 更新dep、faz、siz数组
dep[u] = depth;
faz[u] = father;
siz[u] = 1;
// 遍历所有和当前结点连接的结点
for (int i = head[u]; i; i = edg[i].next) {
int v = edg[i].to;
// 如果连接的结点是当前结点的父亲结点,则不处理
if (v != faz[u]) {
dfs1(v, u, depth + 1);
// 收敛的时候将当前结点的siz加上子结点的siz
siz[u] += siz[v];
// 如果没有设置过重结点son或者子结点v的siz大于之前记录的重结点son,则进行更新
if (son[u] == -1 || siz[v] > siz[son[u]]) {
son[u] = v;
}
}
}
}

第二次DFS的时候则可以将各个重结点连接成重链,轻节点连接成轻链,并且将重链(其实就是一段区间)用数据结构(一般是树状数组或线段树)来进行维护,并且为每个节点进行编号,其实就是DFS在执行时的顺序(tid数组),以及当前节点所在链的起点(top数组),还有当前节点在树中的位置(rank数组)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
void dfs2(int u, int t) {
/**
* u:当前结点
* t:起始的重结点
*/
top[u] = t; // 设置当前结点的起点为t
tid[u] = cnt; // 设置当前结点的dfs执行序号
rnk[cnt] = u; // 设置dfs序号对应成当前结点
cnt++;
// 如果当前结点没有处在重链上,则不处理
if (son[u] == -1) {
return;
}
// 将这条重链上的所有的结点都设置成起始的重结点
dfs2(son[u], t);
// 遍历所有和当前结点连接的结点
for (int i = head[u]; i; i = edg[i].next) {
int v = edg[i].to;
// 如果连接结点不是当前结点的重子结点并且也不是u的父亲结点,则将其的top设置成自己,进一步递归
if (v != son[u] && v != faz[u]){
dfs2(v, v);
}
}
}

而修改和查询操作原理是类似的,以查询操作为例,其实就是个LCA,不过这里使用了top来进行加速,因为top可以直接跳转到该重链的起始结点,轻链没有起始结点之说,他们的top就是自己。需要注意的是,每次循环只能跳一次,并且让结点深的那个来跳到top的位置,避免两个一起跳从而插肩而过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
INT64 query_path(int x, int y) {
/**
* x:结点x
* y:结点y
* 查询结点x到结点y的路径和
*/
INT64 ans = 0;
int fx = top[x], fy = top[y];
// 直到x和y两个结点所在链的起始结点相等才表明找到了LCA
while (fx != fy) {
if (dep[fx] >= dep[fy]) {
// 已经计算了从x到其链中起始结点的路径和
ans += query(1, tid[fx], tid[x]);
// 将x设置成起始结点的父亲结点,走轻边,继续循环
x = faz[fx];
} else {
ans += query(1, tid[fy], tid[y]);
y = faz[fy];
}
fx = top[x], fy = top[y];
}

// 即便找到了LCA,但是前面也只是分别计算了从一开始到最终停止的位置和路径和
// 如果两个结点不一样,表明仍然需要计算两个结点到LCA的路径和
if (x != y) {
if (tid[x] < tid[y]) {
ans += query(1, tid[x], tid[y]);
} else {
ans += query(1, tid[y], tid[x]);
}
} else ans += query(1, tid[x], tid[y]);
return ans;
}

void update_path(int x, int y, int z) {
/**
* x:结点x
* y:结点y
* z:需要加上的值
* 更新结点x到结点y的值
*/
int fx = top[x], fy = top[y];
while(fx != fy) {
if (dep[fx] > dep[fy]) {
update(1, tid[fx],tid[x], z);
x = faz[fx];
} else {
update(1, tid[fy], tid[y], z);
y = faz[fy];
}
fx = top[x], fy = top[y];
}
if (x != y)
if (tid[x] < tid[y]) update(1, tid[x], tid[y], z);
else update(1, tid[y], tid[x], z);
else update(1, tid[x], tid[y], z);
}


个人写法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
int n, m, r, rt, mod, v[maxn], head[maxn], cnt, fa[maxn], size[maxn], d[maxn], son[maxn], top[maxn], id[maxn], rk[maxn];
//fa 父亲节点 size 大小 d 深度 son 重儿子 , top 链顶端 id dfs序 rk dfs序对应的节点

vector<int> G[maxn];

void add(int x, int y) {
G[x].emplace_back(y);
}

void dfs1(int x) {
size[x] = 1;
d[x] = d[fa[x]] + 1;
for (auto u:G[x]) {
if (u != fa[x]) {
fa[u] = x;
dfs1(u);
size[x] += size[u];
if (size[son[x]] < size[u]) {
son[x] = u;
}
}
}
}

void dfs2(int x, int tp) {
top[x] = tp;
id[x] = ++cnt;
rk[cnt] = x;
if (son[x]) {
dfs2(son[x], tp);
}
for (auto u:G[x]) {
if (u != fa[x] && u != son[x]) {
dfs2(u, u);
}
}
}


inline int sum(int x, int y) {
int res = 0;
while (top[x] != top[y]) {
if (d[top[x]] < d[top[y]]) {
swap(x, y);
}
//TODO
res = (res + query(id[top[x]], id[x], rt)) % mod;
x = fa[top[x]];
}
if (id[x] > id[y]) {
swap(x, y);
}
// TODO
res = (res + query(id[x], id[y], rt)) % mod;
return res;
}

inline void updates(int x, int y, int c) {
while (top[x] != top[y]) {
if (d[top[x]] < d[top[y]]) {
swap(x, y);
}
//TODO
update(id[top[x]], id[x], c, rt);
x = fa[top[x]];
}
if (id[x] > id[y]) {
swap(x, y);
}
// TODO
update(id[x], id[y], c, rt);
}