线段树入门

前言

本篇博客按照灯神的线段树进行讲解(这个思路太清晰易懂了!!!)这只是最基本的线段树,线段树还有一些进阶的用法,待后续补充。

先例相对基础,比较适合对线段树没有了解的人来入门。如果对线段树有一定了解可以直接跳过先例。

先例

对于一个数组arr
当我们想要求数组中某个区间(L, R)的和,时间复杂度为O(n),我们把这种操作叫做query
当我们想要更新数组中某个位置的元素值,时间复杂度为O(1),我们把这种操作叫做update

但是当我们的query的操作过多,这种方法就会变得很慢。这时我们就可以用前缀和来将query的时间复杂度降为O(1)。

img

前缀和就是数组的每一位都是这一位和之前所有位之和。即Sum[2] = arr[2]+arr[1]+arr[0]。

当我们要用前缀和数组来进行query(L, R)时,我们只需要用Sum[R]-Sum[L]。
可以看出query的时间复杂度为O(1)
但是此时update的时间复杂度又变成了O(n)(因为当我们修改数组某一个位置的值得时候,这个位置之后的所有Sum数组都要修改)

我们发现当我们把其中一个操作时间复杂度变成O(1)之后,另一个操作时间复杂度又会升到O(n)。而我们有没有一个办法将两种操作的时间复杂度平均一下呢?(废话,肯定有啊,不然我写这篇文章干嘛)当然是线段树了!它可以将query和update的时间复杂度都变成O(logn)

线段树讲解

那么线段树长什么样子呢?他比较像对前缀和数组的一种改版。

举个栗子:

建树

我们要对一个长度为6的数组arr构建线段树,数组如图所示,建树的步骤:

  1. 数组的根节点为数组所有元素的总和,即数组0号元素——5号元素之和:36
  2. 则其左孩子和右孩子将数组劈成两半
    • 左孩子为数组0号元素——2号元素之和:9
    • 右孩子为数组3号元素——5号元素之和:27
  3. 依次类推,直到所有元素不可再分割为止,即可得到如图所示的线段树。

query和update

现在我们再来试着看一看query操作和update操作:

  1. query:当我们要查询(L,R)的和,最坏的情况也就是从树的根结点一直查询到叶子结点,我们都知道这种操作的时间复杂度为O(logn)。例如查询(0,1)的和:
    • 从根结点出发,发现1<3,则往根结点的左孩子找
    • 来到根结点的左孩子,发现1<2,则继续往其左孩子找
    • 来到此处,我们发现这个区间正是我们要找的区间,返回结果
  2. update:当我们要更新某个结点的值的时候,就要先从根结点一直查询到叶子结点,之后再逐步向上更新,返回到根结点,这种操作的时间复杂度当然也是O(logn)。例如更新数组0号元素的值:
    • 从根结点出发,发现0<3,前往左结点
    • 来到[0-2]结点处,发现0<2,前往左结点
    • 来到[0-1]结点处,发现0<1,前往左结点
    • 来到0号元素的结点,更新结点值,返回父节点
    • 来到[0-1]结点处,更新结点值,返回父节点
    • 来到[0-2]结点处,更新结点值,返回父节点
    • 来到根结点处,更新结点值

线段树的存储

我们已经明白了线段树的基本原理,在上手实现线段树代码之前,我们要考虑怎么来存储线段树。

我们发现当线段树建好之后除了最后一层之外其余所有层都被填满了,这种结构很像完全二叉树,所以我们可以用数组来存储这颗树。(本来想用流程图画,结果感觉太麻烦了,还不如Notability来的方便)

319290DB6B0AAC4F709D96229DA1C3E6

如图,我们对树上的每个结点依次进行编号(从1开始),在叶子结点那一层空出来的地方我们用空结点给他补上(数组中可以设置一个值当做空值,例如-1)。这样当一个结点的编号为i的时候,其左孩子就是2×i,其右孩子就是2×i+1,其父结点就是i÷2

代码

现在我们了解了线段树的大体思路之后,线段树的代码就很清晰易懂了。

建树

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void build_tree(int arr[], int tree[], int node, int start, int end) 
{
if (start == end)
{
tree[node] = arr[start];
return;
}
else
{
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
build_tree(arr, tree, left, start, mid);
build_tree(arr, tree, right, mid + 1, end);
tree[node] = tree[left] + tree[right];
}
}

参数变量意义

先来讲一下每个参数的意义:

  • arr:要将这个数组建成一颗线段树,这个数组就是原数组。
  • tree:这个数组就是建成树之后的数组,就是线段树本树了。
  • node:就是当前递归层的根结点。例如我从根结点开始递归,那node就是根结点0了,根结点的下一层就是它的左右孩子了,我先从左孩子开始递归,那么这一层的node就是根结点的左孩子1了,其他的以此类推。
  • start:就是当前结点代表的区间的开始位置,例如当前结点代表的是0-5的和,那start就是0,end就是5了。当走到叶子结点的时候start==end。
  • end:同上。

再来讲一下函数中每个变量的意义:

  • mid:我们建树的时候每个结点都是它两个子结点的和,那他的子结点就是把当前结点的区间劈成了两半,mid就是中间劈开的那个位置了。例如根结点代表的是0-5的和,那mid就是2了。这样根结点的左孩子就是0-2的和,右孩子就是3-5的和了。
  • left:就是当前结点的左孩子在tree数组中的位置了,学过数据结构就很容易明白了(或者仔细看一看上面线段树的存储那一点)。
  • right:同上。

讲解

明白了参数和变量的意义,我们开始对代码的讲解。

建树是从根结点开始建立,递归地求出左子树和右子树的值。递归的出口就是叶子结点(即start==end时),这时将arr的值赋给tree。

更新

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
void update_tree(int arr[], int tree[], int node, int start, int end, int index, int value) 
{
if (start==end)
{
tree[node] = value;
arr[index] = value;
return;
}
else
{
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
if (index <= mid)
{
update_tree(arr, tree, left, start, mid, index, value);
}
else
{
update_tree(arr, tree, right, mid + 1, end, index, value);
}
tree[node] = tree[left] + tree[right];
}
}

参数变量意义

重复参数或变量就略过了。

每个参数的意义:

  • index:即要更新的下标,这个下标是在arr数组中的下标。
  • value:要更新的值。

讲解

更新也是从根结点开始,递归地找到index代表的叶子结点,当在左子树就向左递归,在右子树就向右递归。递归的出口就是求到叶子结点时(即start==end时),将tree和arr的值修改。返回的过程中更新结点值。

查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
int query_tree(int tree[], int node, int start, int end, int L, int R)
{
if (L > end || R < start)
{
return 0;
}
else if (L <= start && end <= R)
{
return tree[node];
}

int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;

int sum_left = query_tree(arr, tree, left, start, mid, L, R);
int sum_right = query_tree(arr, tree, right, mid + 1, end, L, R);
return sum_left + sum_right;
}

参数变量意义

先来讲一下每个参数的意义:

  • L:查询区间的左边界
  • R:查询区间的右边界
  • return(返回值):即查询到的值

讲解

如果L和R分别在左子树和右子树,则需要递归去寻找左右子树中在L-R内的值。这个递归有两个出口,一个是当前区间不在L-R内,则返回0;两一个是当前区间在L-R内,则返回当前结点的值;如果当前区间部分在L-R内则需继续递归。返回的过程中要将左右子树的值和在一起。

Java模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
public static int []arr = new int[100010];
public static int []tree = new int[400010];

public static void build(int []arr, int []tree, int node, int start, int end) {
if (start==end) {
tree[node] = arr[start];
return;
}

int mid = (start+end)/2;
int left = node*2+1;
int right = node*2+2;
build(arr, tree, left, start, mid);
build(arr, tree, right, mid+1, end);
tree[node] = tree[left] + tree[right];
}

public static void update(int []arr, int []tree, int node, int start, int end, int index, int value) {
if (start==end) {
arr[index] = value;
tree[node] = arr[index];
return;
}

int mid = (start+end)/2;
int left = node*2+1;
int right = node*2+2;
if (index<=mid) {
update(arr, tree, left, start, mid, index, value);
}else {
update(arr, tree, right, mid+1, end, index, value);
}
tree[node] = tree[left] + tree[right];
}

public static int query(int []tree, int node, int start, int end, int L, int R) {
if (start>R || end<L) {
return 0;
}
if (L<=start && R>=end) {
return tree[node];
}

int mid = (start+end)/2;
int left = node*2+1;
int right = node*2+2;
int sum_left = query(tree, left, start, mid, L, R);
int sum_right = query(tree, right, mid+1, end, L, R);
return sum_left+sum_right;
}

C++模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
void build(int arr[], int tree[], int node, int start, int end) 
{
if (start == end)
{
tree[node] = arr[start];
return;
}
else
{
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
build_tree(arr, tree, left, start, mid);
build_tree(arr, tree, right, mid + 1, end);
tree[node] = tree[left] + tree[right];
}
}

void update(int arr[], int tree[], int node, int start, int end, int index, int value)
{
if (start==end)
{
tree[node] = value;
arr[index] = value;
return;
}
else
{
int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;
if (index <= mid)
{
update_tree(arr, tree, left, start, mid, index, value);
}
else
{
update_tree(arr, tree, right, mid + 1, end, index, value);
}
tree[node] = tree[left] + tree[right];
}
}

int query(int tree[], int node, int start, int end, int L, int R)
{
if (L > end || R < start)
{
return 0;
}
else if (L <= start && end <= R)
{
return tree[node];
}

int mid = (start + end) / 2;
int left = 2 * node + 1;
int right = 2 * node + 2;

int sum_left = query_tree(arr, tree, left, start, mid, L, R);
int sum_right = query_tree(arr, tree, right, mid + 1, end, L, R);
return sum_left + sum_right;
}
Donate comment here