树状数组和线段树

树状数组和线段树是两种常用的数据结构,其可以大大提升数组的区间查询的效率,同时也保证了数据修改的灵活度。

一般数组 前缀和数组 树状数组 线段树
单点查询 \(O(1)\) \(O(1)\) \(O(logn)\) \(O(logn)\)
区间查询 \(O(n)\) \(O(1)\) \(O(logn)\) \(O(logn)\)
单点修改 \(O(1)\) \(O(n)\) \(O(logn)\) \(O(logn)\)
区间修改 \(O(n)\) \(O(n^2)\) \(O(nlogn)\) \(O(logn)\)

树状数组

树状数组的原理讲解可以参考视频:五分钟丝滑动画讲解 | 树状数组

用于区间求和和单点修改的树状数组模版
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class BIT {
private:
int n;
std::vector<int> tree;

inline int lowbit(int x) {
// 数状数组第 i 项的区间长度为 lowbit(i)
return x & -x;
}

int count(int m) {
// 求前 m 项的和
int res = 0;
while (m) {
res += tree[m];
m -= lowbit(m);
}

return res;
}
public:
BIT(int _n): n(_n), tree(_n + 1, 0) {}

void add(int idx, int val) {
// 单点修改
while (idx <= n) {
tree[idx] += val;
idx += lowbit(idx);
}
}

int rangeSum(int left, int right) {
// 区间求和
return count(right) - count(left - 1);
}
};
  • 树状数组的核心就是 lowbit() 函数,数状数组第 i 项代表的区间长度为 lowbit(i)
  • 树状数组的下标只能从 1 开始,不能从 0 开始
用于求前缀最大值的树状数组模版
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class BIT {
private:
int n;
std::vector<int> tree;

public:
BIT(int _n): n(_n), tree(_n + 1) {}

int query(int idx) {
int res = 0;
while (idx) {
res = std::max(res, tree[idx]);
idx &= idx - 1;
}

return res;
}

void update(int idx, int val) {
while (idx <= n) {
tree[idx] = std::max(tree[idx], val);
idx += idx & -idx;
}
}
};

线段树

线段树相对于树状数组而言则更为灵活,其可以实现高效区间修改。

线段树的原理就是将数组的区间储存在二叉树的节点中,[left, right] 区间对应的左右节点分别为 [left, mid][mid + 1, right]mid = (left + right) / 2)。

‼️注:线段树的数组长度要开到原数组长度的 4 倍

用于区间求和的线段树模版
  • 单点更新
  • 区间求和
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class SegmentTree {
private:
int n;
std::vector<int> tree;

void update(int idx, int val, int node, int start, int end) {
// 单点更新
if (start == end) {
tree[node] = val;
return;
}
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;
if (idx <= mid) {
update(idx, val, leftNode, start, mid);
}
else {
update(idx, val, rightNode, mid + 1, end);
}

tree[node] = tree[leftNode] + tree[rightNode];
}

int rangeSum(int left, int right, int node, int start, int end) {
// 区间求和
if (start > right || end < left) return 0;
if (start >= left && end <= right) return tree[node];
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;

return rangeSum(left, right, leftNode, start, mid) + rangeSum(left, right, rightNode, mid + 1, end);
}

public:
SegmentTree(int _n): n(_n), tree(4 * _n, 0) {}

void update(int idx, int val) {
update(idx, val, 0, 0, n - 1);
}

int rangeSum(int left, int right) {
return rangeSum(left, right, 0, 0, n - 1);
}
};
用于求区间极值的线段树模版
  • 单点更新
  • 区间求极值
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class SegmentTree {
private:
int n;
std::vector<int> tree;

void update(int idx, int val, int node, int start, int end) {
// 单点更新
if (start == end) {
tree[node] = val;
return;
}
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;
if (idx <= mid) {
update(idx, val, leftNode, start, mid);
}
else {
update(idx, val, rightNode, mid + 1, end);
}

tree[node] = std::max(tree[leftNode], tree[rightNode]);
}

int maxVal(int left, int right, int node, int start, int end) {
// 求区间极大值
if (start > right || end < left) return INT_MIN;
if (start >= left && end <= right) return tree[node];
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;

return std::max(maxVal(left, right, leftNode, start, mid), maxVal(left, right, rightNode, mid + 1, end));
}

public:
SegmentTree(int _n): n(_n), tree(4 * _n, 0) {}

void update(int idx, int val) {
update(idx, val, 0, 0, n - 1);
}

int maxVal(int left, int right) {
return maxVal(left, right, 0, 0, n - 1);
}
};

有时候我们可能会遇到这样的需求:要多次将数组中一个区间内的每个元素都添加一个固定的值,如果逐一修改,则会消耗大量的时间,这个时候我们就可以使用带延迟标记的线段树。

什么是延迟标记?

—— 即对线段树的某个节点的数据更新完后不急于对其子节点进行更新,而是将更新信息存储下来,而当必须更新的时候再将信息传递给子节点。

使用延迟标记进行区间修改的线段树模版
  • 区间修改
  • 区间求和
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class SegmentTree {
private:
int n;
std::vector<int> tree;
std::vector<int> lazy;

inline void maintain(int node, int start, int end) {
// 传递 lazy 标签
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;
tree[leftNode] += (mid - start + 1) * lazy[node];
lazy[leftNode] += lazy[node];
tree[rightNode] += (end - mid) * lazy[node];
lazy[rightNode] += lazy[node];
lazy[node] = 0;
}

void add(int left, int right, int val, int node, int start, int end) {
// 给区间 [left, right] 的所有数添加 val
if (start > right || end < left) return;
if (start >= left && end <= right) {
tree[node] += (end - start + 1) * val;
lazy[node] += val;
return;
}
if (lazy[node]) maintain(node, start, end);
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;
add(left, right, val, leftNode, start, mid);
add(left, right, val, rightNode, mid + 1, end);

tree[node] = tree[leftNode] + tree[rightNode];
}

int rangeSum(int left, int right, int node, int start, int end) {
// 区间求和
if (start > right || end < left) return 0;
if (start >= left && end <= right) return tree[node];
if (lazy[node]) maintain(node, start, end);
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;

return rangeSum(left, right, leftNode, start, mid) + rangeSum(left, right, rightNode, mid + 1, end);
}


public:
SegmentTree(int _n): n(_n), tree(4 * _n), lazy(4 * _n) {}

void add(int left, int right, int val) {
add(left, right, val, 0, 0, n - 1);
}

int rangeSum(int left, int right) {
return rangeSum(left, right, 0, 0, n - 1);
}
};

案例

逆序对记数

题目链接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <iostream>
#include <algorithm>

const int MAX_N = 5e5 + 5;

struct {
int val;
int id;
} arr[MAX_N];

int rank[MAX_N];

namespace __BIT {
int bit[MAX_N];
int n;

inline int lowbit(int x) {
return x & -x;
}

int query(int idx) {
int res = 0;
while (idx) {
res += bit[idx];
idx -= lowbit(idx);
}

return res;
}

void add(int idx, int val) {
while (idx <= n) {
bit[idx] += val;
idx += lowbit(idx);
}
}
}

using namespace __BIT;
using i64 = long long;

int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);

std::cin >> n;
for (int i = 1; i <= n; i++) {
std::cin >> arr[i].val;
arr[i].id = i;
}

std::sort(arr + 1, arr + n + 1, [](const auto& a1, const auto& a2) {
return a1.val != a2.val ? a1.val < a2.val : a1.id < a2.id;
});

for (int i = 1; i <= n; i++) rank[arr[i].id] = i;
i64 res = 0;
for (int i = 1; i <= n; i++) {
res += i - 1 - query(rank[i]);
add(rank[i], 1);
}

std::cout << res << '\n';

return 0;
}
二维数点

题目链接

二维数点是树状数组的一个典型应用,本题的 AC 代码如下:

注:需要开启 O2 优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <iostream>
#include <algorithm>

const int MAX_N = 5e5 + 5, MAX_M = 5e5 + 5;

int n, m;
struct Tree {
int x;
int y;
} tree[MAX_N];

struct Query {
int x;
int y;
int id;
int res;
} q[4 * MAX_M];

const int MAX_Y = 1e7 + 7;
namespace __BIT {
int bit[MAX_Y];
int len;

inline int lowbit(int x) {
return x & -x;
}

int query(int idx) {
idx = idx <= len ? idx : len;
int res = 0;
for (int i = idx; i; i -= lowbit(i)) res += bit[i];

return res;
}

void add(int idx, int val) {
for (int i = idx; i <= len; i += lowbit(i)) bit[i] += val;
}
}

using namespace __BIT;

int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);

std::cin >> n >> m;
for (int i = 0; i < n; i++) {
std::cin >> tree[i].x >> tree[i].y;
tree[i].x++, tree[i].y++;
len = std::max(len, tree[i].y);
}
std::sort(tree, tree + n, [](const Tree& t1, const Tree& t2) {
return t1.x < t2.x;
});
int all = 4 * m;
for (int i = 0, a, b, c, d; i < all; i += 4) {
std::cin >> a >> b >> c >> d;
a++, b++, c++, d++;
q[i] = {c, d, i};
q[i + 1] = {a - 1, d, i + 1};
q[i + 2] = {c, b - 1, i + 2};
q[i + 3] = {a - 1, b - 1, i + 3};
}
std::sort(q, q + all, [](const Query& q1, const Query& q2) {
return q1.x < q2.x;
});
int j = 0;
for (int i = 0; i < n; i++) {
while (j < all && tree[i].x > q[j].x) {
// 当前范围已经统计结束
q[j].res = query(q[j].y);
j++;
}
add(tree[i].y, 1);
}

while (j < all) {
q[j].res = query(q[j].y);
j++;
}

std::sort(q, q + all, [](const Query& q1, const Query& q2) {
return q1.id < q2.id;
});

for (int i = 0; i < all; i += 4) {
std::cout << q[i].res - q[i + 1].res - q[i + 2].res + q[i + 3].res << '\n';
}

return 0;
}
统计最长递增子序列

题目链接

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Solution {
private:
int n;
std::vector<int> tree;

int query(int idx) {
int res = 0;
while (idx) {
res = std::max(res, tree[idx]);
idx &= idx - 1;
}

return res;
}

void update(int idx, int val) {
while (idx <= n) {
tree[idx] = std::max(tree[idx], val);
idx += idx & -idx;
}
}

public:
int lengthOfLIS(vector<int>& nums) {
int n = nums.size();
this->n = n;
tree.resize(n + 1);
std::vector<std::pair<int, int>> p(n);
for (int i = 0; i < n; i++) {
p[i].first = nums[i];
p[i].second = i + 1;
}
std::sort(p.begin(), p.end(), [](const std::pair<int, int>& p1, const std::pair<int, int>& p2) {
return p1.first != p2.first ? p1.first < p2.first : p1.second > p2.second;
});

for (const auto& [_, idx] : p) {
update(idx, query(idx) + 1);
}

return query(n);
}
};
其他

网格图中最少访问的格子数(线段树 单点修改 + 区间查询)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class SegTree {
// 单点修改 + 区间查询
private:
int n;
vector<int> tree;

void update(int idx, int val, int node, int start, int end) {
if (start == end) {
tree[node] = val;
return;
}
int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;
if (idx <= mid) {
update(idx, val, leftNode, start, mid);
}
else {
update(idx, val, rightNode, mid + 1, end);
}

tree[node] = min(tree[leftNode], tree[rightNode]);
}

int query(int left, int right, int node, int start, int end) {
if (start > right || end < left) return INT_MAX;
if (start >= left && end <= right) return tree[node];

int leftNode = 2 * node + 1, rightNode = 2 * node + 2;
int mid = (start + end) / 2;

return min(query(left, right, leftNode, start, mid), query(left, right, rightNode, mid + 1, end));
}

public:
SegTree(int _n): n(_n), tree(4 * _n, INT_MAX) {}

void update(int idx, int val) {
update(idx, val, 0, 0, n - 1);
}

int query(int left, int right) {
return query(left, right, 0, 0, n - 1);
}
};

class Solution {
public:
int minimumVisitedCells(vector<vector<int>>& grid) {
int m = grid.size(), n = grid[0].size();
SegTree trRow(m * n);
SegTree trCol(n * m);
trRow.update(m * n - 1, 0);
trCol.update(n * m - 1, 0);
for (int i = m - 1; i >= 0; i--) {
for (int j = n - 1; j >= 0; j--) {
int resRow = trRow.query(n * i + j, n * i + min(j + grid[i][j], n - 1));
// (i, j) ~ (i, j + grid[i][j]) 中的最小值
int resCol = trCol.query(m * j + i, m * j + min(i + grid[i][j], m - 1));
// (i, j) ~ (i + grid[i][j], j) 中的最小值
int res = min(resRow, resCol);
if (res != INT_MAX) {
trRow.update(n * i + j, res + 1);
trCol.update(m * j + i, res + 1);
}
}
}

int ans = trRow.query(0, 0);
return ans != INT_MAX ? ans : -1;
}
};

树状数组和线段树
https://goer17.github.io/2023/03/30/树状数组和线段树/
作者
Captain_Lee
发布于
2023年3月30日
许可协议