线段树

定义

线段树(segment tree),顾名思义,是用来存放给定区间(segment,or interval)内对应信息的一种数据结构。与树状数组(binary indexed tree)相似, 线段树也用来处理数组相应的区间查询(range query)元素更新(update)操作。与树状数组不同的是,线段树不止可以适用于区间求和的查询,也可以进行区间最大值, 区间最小值(Range Minimum/Maximum Query problem)或者区间异或值的查询

对应于树状数组,线段树进行更新(update)的操作为 \(O(log_n)\),进行区间查询(range query)的操作也为 \(O(long_n)\)

实现原理

从数据结构的角度来说,线段树是用一个完全二叉树来存储对应于每一个区间(segment)的数据。该二叉树的每一个结点中保存着对应于这一个区间的信息。 同时,线段树所使用的这个二叉树是用一个数组保存的,与堆(Heap)的实现方式相同。

例如,给定一个长度为 \(N\) 的数组 \(arr\),其所对应的线段树 \(T\) 各个结点的含义如下:

  1. \(T\) 的根节点代表整个数组所在的区间对应的信息,及 \(arr[0:N]\)不含N)所对应的信息。
  2. \(T\) 的每一个叶结点存储对应于输入数组的每一个单个元素构成的区间 \(arr[i]\) 所对应的信息,此处 \(0\leq i\lt N\)
  3. \(T\) 的每一个中间结点存储对应于输入数组某一区间 \(arr[i:j]\) 对应的信息,此处 \(0 \leq i \lt j \lt N\)

以根结点为例,根结点代表 \(arr[0:N]\) 区间所对应的信息,接着根节点被分为两个子树,分别存储 \(arr[0:(N - 1)/2]\)\(arr[(N - 1)/2:]\) 两个子区间对应的信息。也就是说,对于每一个结点,其左右子结点分别存储母结点区间拆分为两半之后各自区间的信息。也就是说对于长度为 \(N\) 的输入数组, 线段树的高度为 \(log_N\)

对于一个线段树来说,其应该支持的两种操作为:

  1. Update: 更新输入数组中的某一个元素并对线段树做相应的改变。
  2. Query: 用来查询某一区间对应的信息(如最大值,最小值,区间和等)。

线段树的初始化

线段树的初始化是自底向上进行的。从每一个叶子结点开始(也就是原数组中的每一个元素),沿从叶子结点到根结点的路径向上按层构建。 在构建的每一步中,对应两个子结点的数据将被用来构建应当存储于它们母结点中的值。每一个中间结点代表它的左右两个子结点对应区间 融合过后的大区间所对应的值。这个融合信息的过程可能依所需要处理的问题不同而不同(例如对于保存区间最小值得线段树来说,merge 的过程应为 min 函数, 用以取得两个子区间中的最小区间最小值作为当前融合过后的区间最小值)。 但从叶子结点(长度为1的区间)到根结点(代表输入的整个区间)更新的这一过程是统一的。

注意此处我们对于 segmentTree 数组的索引从 1 开始算起。则对于数组中的任意结点 \(i\),其左子结点为 \(2*i\),右子结点为 \(2*i + 1\),其母结点为 \(i/2\)

构建线段树的算法描述如下:

1
2
3
4
5
6
7
8
9
10
11
construct(arr):
// n 是子结点所在那一层的节点个数
n = length(arr)
// 开 2 * n 的数组,这样就可以一半用来存储子结点(叶子结点存储的是一个数组成的区间),一半用来存储非叶子结点(非叶子结点存储的是区间汇总信息)
segmentTree = new int[2 * n]
// 先初始化的是叶子结点,叶子结点的范围是 [n, 2 * n - 1]
for i from n to 2 * n - 1:
segmentTree[i] = arr[i - n]
// 初始化父结点,就是合并区间的操作,通过左右叶子结点得到区间汇总信息
for i from n - 1 to 1:
segmentTree[i] = merge(segmentTree[2 * i], segmentTree[2 * i + 1])

这里有几个点要说明(对于上面的遍历范围):

  1. 完全二叉树每一层的结点个数为 \(2 ^ {n - 1}\),(\(n\) 为自顶向下的层数,从 \(1\) 算起)
  2. 假设完全二叉树的总节点数量为 \(2 * n\),那么完全二叉树的叶子结点的下标范围为 \([n, 2 * n - 1]\)

例如给定一个输入数组 [1, 5, 3, 7, 3, 2, 5, 7],其所对应的最小值线段树应如下图所示:

上面所示线段树每一个节点代表的区间则如下图所示:

1
1 2
1 3 2 5
1 5 3 7 3 2 5 7

上面所示线段树每一个结点代表的区间则如下图所示:

1: \([0, 8)\)
2: \([0, 4)\) 3: \([4, 8)\)
4: \([0, 2)\) 5: \([2, 4)\) 6: \([4, 6)\) 7: \([6, 8)\)
8: \(0\) 9: \(1\) 10: \(2\) 11: \(3\) 12: \(4\) 13: \(5\) 14: \(6\) 15: \(7\)

如果用其数组表示来说,则数组 segmentTree 中的每一个位置代表的区间如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
segmentTree[1] = arr[0:8)
segmentTree[2] = arr[0:4)
segmentTree[3] = arr[4:8)
segmentTree[4] = arr[0:2)
segmentTree[5] = arr[2:4)
segmentTree[6] = arr[4:6)
segmentTree[7] = arr[6:8)
segmentTree[8] = arr[0]
segmentTree[9] = arr[1]
segmentTree[10] = arr[2]
segmentTree[11] = arr[3]
segmentTree[12] = arr[4]
segmentTree[13] = arr[5]
segmentTree[14] = arr[6]
segmentTree[15] = arr[7]

更新

更新一个线段树的过程与上述构造线段树的过程相同。当输入数组中位于 \(i\) 位置的元素被更新时,我们只需要从这一元素对应的叶子结点开始, 沿二叉树的路径向上更新至根结点即可。显然,这一过程是一个 \(O(log_n)\) 的操作。其算法如下:

1
2
3
4
5
6
7
8
9
10
update(i, value):
// 原数组在线段树里面的存储位置是从下标 n 开始的
i = i + n;
// 更新原数组的值
segmentTree[i] = value;
while (i > 1):
// 向上更新关联的父结点,直到根结点。i / 2 是 i 的父结点在线段树中的下标。
i = i / 2;
// 如果父结点在线段树的下标为 i,那么左孩子在线段树的下标为 2 * i,右孩子在线段树的下标为 2 * i + 1
segmentTree[i] = merge(segmentTree[2 * i], segmentTree[2 * i + 1])

例如对于上图中的线段树,如果我们调用 update(5, 6),则其更新过程如下所示:

  1. 修改原数组下标为 5 的元素为 6.
1
1 2
1 3 2 5
1 5 3 7 3 6 5 7
  1. 修改原数组下标 5 在线段树里面的父结点,由 2 改为 3
1
1 2
1 3 3 5
1 5 3 7 3 6 5 7
  1. 修改上一个修改结点的父结点,由 2 改为 3
1
1 3
1 3 3 5
1 5 3 7 3 6 5 7
  1. 修改上一个修改结点的父结点,改为 1(跟原来一样)
1
1 3
1 3 3 5
1 5 3 7 3 6 5 7

区间查询

区间查询大体上可以分为 3 种情况讨论:

  1. 当前结点所代表的区间完全位于给定需要被查询的区间之外,则不应该考虑当前结点。
  2. 当前结点所代表的区间完全位于给定需要被查询的区间之内,则可以直接查看当前结点的母结点
  3. 当前结点所代表的区间部分位于需要被查询的区间之内,部分位于其外,则我们优先考虑位于区间外的部分,后考虑区间内的 (注意总有可能找到完全位于区间内的结点,因为叶子结点的区间长度为 1,因此我们总能组合出合适的区间)

以求最小值为例,其算法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
minimum(left, right):
// left 在线段树中的下标位置
left = left + n
// right 在线段树中的下标位置
right = right + n
// 初始化最小值
minimum = Integer.MAX_VALUE
while left < right:
if left is odd:
// left is out of range of parent interval, check value of left node first, then shift it right in the same level
minimum = min(minimum, segmentTree[left])
left = left + 1
if right is odd:
// right is out of range of current interval, shift it left in the same level and then check the value
right = right - 1
minimum = min(minimum, segmentTree[right])

// 向根结点遍历
left = left / 2
right = right / 2

为什么是奇数的时候做处理呢?具体看下图:

假设要求 \([1, 7)\) 范围内的最小值:

1
1 2
1 3 2 5
1 5 3 7 3 2 5 7

我们来看看只有三个节点的情况:

1
1 2
  1. 如果 \(left\) 是奇数,那就是上图里面 2 的位置,那么在这三个节点里面,唯一可能的值就是右下角的值 2
  2. 如果 \(right\) 是奇数,那就是上图里面 2 的位置,在这三个节点里面,因为我们取值范围不含右边界,那么在这三个节点里面,唯一可能的值就是左下角的值

因此对于下标是奇数的情况,其父节点不需要考虑,只需要考虑当前结点。

\(n\) 不是 2 的次方怎么办?

注意上面的讨论中我们由于需要不断二分区间,给定的输入数组的长度 \(n\) 为 2 的次方。那么当 \(n\) 不是 2 的次方,或者说, 当 \(n\) 无法被完全二分为一些长度为 1 的区间时,该如何处理呢?

一个简单的方法就是在原数组的结尾补 0,直到其长度正好为 2 的次方位置。但事实上这个方法比较低效。 最坏情况下,我们需要 \(O(4n)\) 的空间来存储相应的线段树。例如,如果输入数组的长度刚好为 \(2^x + 1\),则我们首先需要补 0 直到数组长度为 \(2^(x + 1) = 2 * 2^x\) 为止。 那么对于这个补 0 过后的数组,我们需要的线段树数组的长度为 \(2 * 2 * 2^x = 4 * 2 ^x = O(4n)\)

其实上面所说的算法对于 \(n\) 不是 2 的次方的情况同样适用。这也是为什么我在上文中说线段树是一棵 完全二叉树 而非 满二叉树 的原因。

例如对于输入数组 \([4, 3, 9, 1, 6, 7]\),其构造出的线段树应当如下图所示:

1
1 3
1 6 4 3
8 1 6 7

可以看出,在构造过程中我们事实上把一些长度为 1 的区间直接放在了树的倒数第二层来实现这个线段树。

Java 实现

Range Minimum Query

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package com.baiguiren.leetcode;

import java.util.ArrayList;

public class MinSegmentTree {
public static void main(String[] args) {
MinSegmentTree tree = new MinSegmentTree(new int[]{1, 5, 3, 7, 3, 6, 5, 7});
int min = tree.minimum(1, 2);
System.out.println("min=" + min);
}

private ArrayList<Integer> minSegmentTree;
private int n;

public MinSegmentTree(int[] arr) {
n = arr.length;
minSegmentTree = new ArrayList<>(2 * n);

for (int i = 0; i < n; i++) {
minSegmentTree.add(0);
}

for (int i = n; i < 2 * n; i++) {
minSegmentTree.add(arr[i - n]);
}

for (int i = n - 1; i > 0; i--) {
minSegmentTree.set(i, Math.min(minSegmentTree.get(2 * i), minSegmentTree.get(2 * i + 1)));
}
}

public void update(int i, int value) {
i = i + n;
minSegmentTree.set(i, value);

while (i > 1) {
i = i / 2;
minSegmentTree.set(i, Math.min(minSegmentTree.get(2 * i), minSegmentTree.get(2 * i + 1)));
}
}

public int minimum(int left, int right) {
left = left + n;
right = right + n;
int min = Integer.MAX_VALUE;

// 如果 left == right 说明,最小值已经产生,只可能是下一层的数
// left == right 说明缩减到只有一个数字的区间了。
while (left < right) {
if ((left & 1) == 1) {
min = Math.min(min, minSegmentTree.get(left));
left = left + 1;
}

if ((right & 1) == 1) {
right--;
min = Math.min(min, minSegmentTree.get(right));
}

left >>= 1;
right >>= 1;
}

return min;
}
}

Range Sum Query

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
public class SumSegmentTree {
public static void main(String[] args) {
SumSegmentTree tree = new SumSegmentTree(new int[]{1, 5, 3, 7, 3, 6, 5, 7});
int sum = tree.sum(1, 3);
System.out.println("sum=" + sum);
}

private final ArrayList<Integer> sumSegmentTree;
private final int n;

public SumSegmentTree(int[] arr) {
n = arr.length;
sumSegmentTree = new ArrayList<>(2 * n);

for (int i = 0; i < n; i++) {
sumSegmentTree.add(0);
}

for (int i = n; i < 2 * n; i++) {
sumSegmentTree.add(arr[i - n]);
}

for (int i = n - 1; i > 0; i--) {
sumSegmentTree.set(i, sumSegmentTree.get(2 * i) + sumSegmentTree.get(2 * i + 1));
}
}

public void update(int i, int value) {
i = i + n;
sumSegmentTree.set(i, value);

while (i > 1) {
i = i / 2;
sumSegmentTree.set(i, sumSegmentTree.get(2 * i) +sumSegmentTree.get(2 * i + 1));
}
}

public int sum(int left, int right) {
left = left + n;
right = right + n;
int sum = 0;

while (left < right) {
if ((left & 1) == 1) {
sum += sumSegmentTree.get(left);
left = left + 1;
}

if ((right & 1) == 1) {
right--;
sum += sumSegmentTree.get(right);
}

left >>= 1;
right >>= 1;
}

return sum;
}
}