各种分块和各种莫队

分块部分是自己练习记录的, 莫队主要学习 tsy 学长的课. tsy yyds!

线性序列上, 把区间分成 $\sqrt n$ 个 $\sqrt n$ 的小区间.

$O(\sqrt n)$ 个整块, 两个 $O(\sqrt n)$ 大小的散块

  • 单点修改 $O(1)$
  • 单点查询 $O(1)$
  • 区间修改 $O(\sqrt n)$
  • 区间查询 $O(\sqrt n)$

数列分块入门 1

出一个长为 $n$ 的数列, 以及 $n$ 个操作, 操作涉及区间加法, 单点查值.

  • 修改: 对整块打标记, 散块暴力, $O(\sqrt n)$, 注意修改散块时的时候不需要下放标记, 因为对该块中的每一个数最后都需要加上标记值
  • 查询: 直接返回序列上的点值加上包括他的块的标记
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
int T, n, q, block_size = 0, id[MAXN];
LL a[MAXN], tag[MAXN];

void update(int l, int r, LL c) {
	for (int i = l; i < min(r+1, (id[l]+1) * block_size); i++)
		a[i] += c;
	for (int i = id[l]+1; i < id[r]; i++)
		tag[i] += c;
	if (id[l] != id[r])
		for (int i = id[r] * block_size; i <= r; i++)
			a[i] += c;
}

int main() {
	scanf("%d", &n);
	q = n;
	block_size = sqrt(n);
	for (int i = 0; i < n; i++) {
		scanf("%lld", a + i);
		id[i] = i / block_size;
	}
	while (q--) {
		int op, l, r;
		LL c;
		scanf("%d%d%d%lld", &op, &l, &r, &c);
		l--, r--;
		if (op)
			printf("%lld\n", a[r] + tag[id[r]]);
		else
			update(l, r, c);
	}
	return 0;
}

由于是入门, 稍微解释一下代码.

首先认为从 $0$ 开始, 这样方便求块的编号, 就是 id = pos / block_size. 我们可以把每个位置对应在哪个块现求出来.

然后处理 $l$ 所在的散块, 注意不要超过块大小以及 $r$ 对应长度, 因为 $l, r$ 可能在同一块里. 再处理 $l$ 之后, $r$ 之前的整块. 最后看看 $l, r$ 是不是在同一块里, 如果不是, 还需要处理 $r$ 所在的散块.

数列分块入门 2

给出一个长为 $n$ 的数列, 以及 $n$ 个操作, 操作涉及区间加法, 询问区间内小于某个值 $x$ 的元素个数.

下面分析复杂, 来确定块大小取多少最优.

设块大小为 $n^x$, 块个数为 $n^{1-x}$.

由于需要查找, 我们可以开块个数个 vector 有序存块中的元素, 预处理排序共 $O(n^{1-x} n^x \log n^x) = O(xn \log n)$.

查询的时候如果是整块, 那么直接二分查找, 一共 $n^{1-x}$ 个块, 每个块中 $n^x$ 个元素, 在其中二分需要 $O(\log n^x) = O(x \log n)$ 次, 总复杂度 $O(qxn^{1-x} \log n)$; 如果是散块, 暴力查找, $O(qn^x)$.

修改的时候如果是整块, 直接打标记 $O(q)$, 如果是散块, 需要全部加上, 然后重新排散块的序, $O(q(n^x + n^x \log n^x)) = O(qn^x + xqn^x \log n) = O(xqn^x \log n)$

题设 $O(q) = O(n)$, 所以可能是最大的复杂度的是:

  1. $O(xn \log n)$
  2. $O(xn^{2-x} \log n)$
  3. $O(n^{1+x})$
  4. $O(n)$
  5. $O(xn^{1+x} \log n)$

因为 $0 < x < 1$, 通过比较去掉一些, 可以知道, 只有 $O(xn^{2-x} \log n)$ 和 $O(xn^{1+x} \log n)$ 可能是最大的复杂度, 令其相等得到最优的 $x = \frac{1}{2}$.

(不知道哪里推错了 hzwer 博客说有更优的? 可能是我太菜了, 算不来就直接丢掉导致误差大了qaq)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
int T, n, q, block_size = 0, id[MAXN];
LL a[MAXN], tag[MAXN];
vector<LL> v[MAXN];

void sort_block(int idx) {
	v[idx].clear();
	for (int i = idx * block_size; i < min(n, (idx+1) * block_size); i++)
		v[idx].push_back(a[i]);
	sort(v[idx].begin(), v[idx].end());
}

void update(int l, int r, LL c) {
	for (int i = l; i < min(r+1, (id[l]+1) * block_size); i++)
		a[i] += c;
	sort_block(id[l]);
	for (int i = id[l]+1; i < id[r]; i++)
		tag[i] += c;
	if (id[l] != id[r]) {
		for (int i = id[r] * block_size; i <= r; i++)
			a[i] += c;
		sort_block(id[r]);
	}
}

int query(int l, int r, int x) {
	int res = 0;
	for (int i = l; i < min(r+1, (id[l]+1) * block_size); i++)
		res += a[i] + tag[id[l]] < x;
	for (int i = id[l]+1; i < id[r]; i++)
		res += lower_bound(v[i].begin(), v[i].end(), x - tag[i]) - v[i].begin();
	if (id[l] != id[r])
		for (int i = id[r] * block_size; i <= r; i++)
			res += a[i] + tag[id[r]] < x;
	return res;
}

int main() {
	scanf("%d", &n);
	q = n;
	block_size = sqrt(n);
	for (int i = 0; i < n; i++) {
		scanf("%lld", a + i);
		id[i] = i / block_size;
	}
	for (int i = 0; i <= id[n-1]; i++)
		sort_block(i);
	while (q--) {
		int op, l, r;
		LL c;
		scanf("%d%d%d%lld", &op, &l, &r, &c);
		l--, r--;
		if (op)
			printf("%d\n", query(l, r, c*c));
		else
			update(l, r, c);
	}
	return 0;
}

需要注意边界, 最后一个块可能不满 block_size, 枚举的不要超过 $n$.

块中不仅可以维护数组, 还可以维护如 set 的数据结构. 反正就各种暴力.

还有一个涉及增加元素的, 这个时候把块用链表链起来, 增加就暴力插入, 如果块大小超过 2block_size, 就把他裂开来.

我直接贴上学长的视频(狗头)

  1. 维护区间答案
  2. 维护区间上的数据结构

将序列 $\sqrt n$ 分块

把查询 $[l, r]$ 离线, 排序

排序按照: 如果 $l$ 不在同一块, 则按照 $l$ 递增排序; 如果在同一块里, 则按照 $r$ 递增排序.

如:

1 2
2 1000000
3 4
5 1000001
5 6

排序后就有:

1 2
3 4
5 6
2 1000000
5 1000001

由我们的排序规则可得, 当这一次查询的 $l$ 和上一次的在同一块中时, $l$ 指针移动 $O(\sqrt n)$. 一共有 $\sqrt n$ 个 $l$ 不同的块, 且每个查询的 $l$ 所在这些块"递增", 即"枚举一次", 故复杂度为 $O(\max\{q \sqrt n, \sqrt n \cdot \sqrt n\}) = \max\{q\sqrt n, n\}$. 当 $O(q) = O(n)$ 时, 复杂度为 $O(n\sqrt n)$.

然后考虑 $r$, 当这一次查询的 $l$ 和上一次的在同一块中时, $r$ 指针单调移动, 一次最多 $O(\sqrt n)$, 一共 $O(\sqrt n)$ 个快, 所以复杂度 $O(n)$. 当这一次查询的 $l$ 和上一次的在不在一个块中时, $r$ 一次移动 $O(n)$, 由于 $l$ 所在块"递增", 即"枚举一次", 故复杂度为 $O(\max\{n, n \sqrt n\})$

和分块一样可以现把每个点属于哪个块处理出来. 由于我们不需要改块中的信息, 所以从 $1$ 其实并不会特别麻烦.

稍微推一下可以知道, 扩张区间的时候, 分子增加 cnt[c], 收缩区间的时候, 分子减少 cnt[c]-1 (以上两个cnt均为没修改的)

每次移动指针, 先扩张后收缩, 先左端后右端. 否则可能出现非法区间, 如: 从 $[1,5]$ 移动到 $[11,15]$, 不能出现 $[11,5]$ 这样的非法过渡态. (笑死jyz全校的莫队都是假的)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
int T, n, q, a[MAXN], id[MAXN], cnt[MAXN], block_size;
LL cur = 0;
struct Query {
	int idx, l, r;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? r < Q.r : l < Q.l;
	}
} query[MAXN];
pair<LL, LL> ans[MAXN];

void add(int c) {
	cur += cnt[c]++;
}

void del(int c) {
	cur -= --cnt[c];
}

int main() {
	scanf("%d%d", &n, &q);
	block_size = sqrt(n);
	for (int i = 1; i <= n; i++) {
		scanf("%d", a + i);
		id[i] = (i-1) / block_size + 1;
	}
	for (int i = 1; i <= q; i++) {
		scanf("%d%d", &query[i].l, &query[i].r);
		query[i].idx = i;
	}
	sort(query + 1, query + q + 1);
	int l = query[1].l, r = l-1;
	for (int i = 1; i <= q; i++) {
		while (l > query[i].l)
			add(a[--l]);
		while (r < query[i].r)
			add(a[++r]);
		while (l < query[i].l)
			del(a[l++]);
		while (r > query[i].r)
			del(a[r--]);
		LL len = r - l + 1, p = len * (len - 1) / 2;
		LL d = __gcd(cur, p);
		ans[query[i].idx] = make_pair(cur / d, p / d);
	}

	for (int i = 1; i <= q; i++)
		printf("%lld/%lld\n", ans[i].first, ans[i].second);

	return 0;
}

练手

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

int T, n, q, block_size, a[MAXN], ans[MAXQ], id[MAXN], cnt[MAXA], cur = 0;
struct Query {
	int idx, l, r;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? r < Q.r : l < Q.l;
	}
} query[MAXQ];

void add(int p) {
	cnt[a[p]]++;
	if (cnt[a[p]] == 1)
		cur++;
}

void del(int p) {
	cnt[a[p]]--;
	if (cnt[a[p]] == 0)
		cur--;
}

int main() {
	scanf("%d", &n);
	block_size = sqrt(n);
	for (int i = 1; i <= n; i++) {
		scanf("%d", a + i);
		id[i] = (i-1) / block_size + 1;
	}
	scanf("%d", &q);
	for (int i = 1; i <= q; i++) {
		scanf("%d%d", &query[i].l, &query[i].r);
		query[i].idx = i;
	}
	sort(query+1, query+q+1);
	int l = query[1].l, r = l-1;
	for (int i = 1; i <= q; i++) {
		while (l > query[i].l)
			add(--l);
		while (r < query[i].r)
			add(++r);
		while (l < query[i].l)
			del(l++);
		while (r > query[i].r)
			del(r--);
		ans[query[i].idx] = cur;
	}
	for (int i = 1; i <= q; i++)
		printf("%d\n", ans[i]);
	return 0;
}

记录某个颜色出现次数的次数. 假设众数为 $x$, 众数的那么 cnt[x] 是最大的并且 num[cnt[x]] > 0

增加一个数 $x$ 的时候, 首先增加前 $x$ 出现的次数 cnt[x]$

维护数据结构的例题, 本题很显然可以维护一个BIT, 但是这样修改一个点复杂度带一个 log, 尤其询问还多, 过不去. 所以可以维护分块数据结构, 这样修改是 $O(1)$ 的, 只有最后查询的时候是 $O(\log n)$.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
int T, n, q, s[MAXN], block_size, id[MAXN], ans[MAXQ], cnt[MAXN], tag[MAXN];

int query(int l, int r) {
	int res = 0;
	for (int i = l; i <= min(id[l]*block_size, r); i++)
		res += cnt[i] > 0;
	for (int i = id[l]+1; i < id[r]; i++)
		res += tag[i];
	if (id[l] != id[r])
		for (int i = (id[r]-1)*block_size + 1; i <= r; i++)
			res += cnt[i] > 0;
	return res;
}

struct Query {
	int idx, l, r, a, b;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? r < Q.r : l < Q.l;
	}
} qry[MAXQ];

void add(int c) {
	if (++cnt[c] == 1)
		tag[id[c]]++;
}

void del(int c) {
	if (--cnt[c] == 0)
		tag[id[c]]--;
}

int main() {
	scanf("%d%d", &n, &q);
	block_size = sqrt(n);
	for (int i = 1; i <= n; i++) {
		scanf("%d", s + i);
		id[i] = (i-1) / block_size + 1;
	}
	for (int i = 1; i <= q; i++) {
		scanf("%d%d%d%d", &qry[i].l, &qry[i].r, &qry[i].a, &qry[i].b);
		qry[i].idx = i;
	}
	sort(qry+1, qry+q+1);
	int l = qry[1].l, r = l-1;
	for (int i = 1; i <= q; i++) {
		while (l > qry[i].l)
			add(s[--l]);
		while (r < qry[i].r)
			add(s[++r]);
		while (l < qry[i].l)
			del(s[l++]);
		while (r > qry[i].r)
			del(s[r--]);
		ans[qry[i].idx] = query(qry[i].a, qry[i].b);
	}
	for (int i = 1; i <= q; i++)
		printf("%d\n", ans[i]);
	return 0;
}

加入一维时间 $t$. 三个关键字排序, 先左端点所在块号递增, 再右端点所在块号递增, 最后时间 $t$ 递增. 然后三维指针乱跳. 带修改的话就不是分成 $\sqrt n$ 大小, 用复杂度算. 设块大小为 $n^x$.

我们把时间考虑进来, 本来是 $n^{1-x}$ 个块, 现在有 $tn^{1-x}$, 其中每个左右端点相同的块有 $t$ 个, 分别对应的是时间. 下面分析 $l, r, t$ 指针的复杂度.

$l$ 所在块编号相同时, $l$ 移动 $O(n^x)$, 复杂度 $O(qn^x)$; 不同时, 移动 $O(n^x)$, 一共 $O(n^{1-x})$ 个块, 复杂度 $O(n)$. 总复杂度 $O(qn^x)$, 当 $O(q) = O(n)$ 时为 $O(n^{1+x})$.

$l$ 所在块编号相同时, $r$ 单调移动, 最多 $O(n^{1-x})$ 个块, 也就是说, 最多能把询问 $r$ 排成 $n^{1-x}$ 个单调段, 复杂度 $O(n^{2-x})$.

按 $t$ 排序时, 首先要满足这一次的 $l, r$ 和上一次的 $l’, r’$ 都在同一块里, 这样 $t$ 是才递增的. 那么最坏情况下能划分成多少个单调段呢? 我们让 $l, r$ 取遍所有块, 这样就根据鸽子原理, 之后的询问一定会和前面的某些询问的时间满足单调递增. 一个端点 $O(n^{1-x})$ 种取法, 所以共 $O((n^{1-x})^2)$ 个单调段, 所以总复杂度是 $O(tn^{2-2x})$. 当 $O(t) = O(n)$ 时, 复杂度 $O(n^{3-2x})$.

让这三个指针的复杂度之和最小, 即 $\max \{ O(n^{1+x}), O(n^{2-x}), O(n^{3-2x}) \}$ 最小, 稍加计算可知, $x = \frac{2}{3}$ 时最优, 为 $O(n^{\frac{5}{3}})$. 如果还是取 $x = \frac{1}{2}$, 那么复杂是 $O(n^2)$ 的布鲁特福斯惊呼内行

因为时间需要"回溯", 所以得记录一个"修改前的值", 等回到这个时间再放回来. 可以对时间为 $t$ 的这个询问和原数值进行交换. 当然在这之前如果这个值在询问的区间中, 还得对答案进行更新. 这样等下一次再回到 $t$, 再交换一次就变回来了. 由于 $t$ 不会乱跳, 而是增加上去然后减小回来, 这样就能够保证"撤销"复原. 具体看例题代码.

注意一点细节, 查询和修改是分开来存储的. 不要忘记设置块大小和块id.

CF - 940F

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
int T = 0, n, q = 0, qc, block_size, a[MAXN], ans[MAXQ], id[MAXN], cur = 0, cnt[MAXN], ID = 0;
unordered_map<int, int> mp;
int app[MAXN<<1];

struct Query {
	int idx, t, l, r;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? id[r] == id[Q.r] ? t < Q.t : id[r] < id[Q.r] : id[l] < id[Q.l];
	}
} query[MAXQ];

struct Modify {
	int pos, a;
} modify[MAXQ];

void add(int c) {
	int now = ++app[c];
	int org = now-1;
	cnt[org]--;
	cnt[now]++;
}

void del(int c) {
	int now = --app[c];
	int org = now+1;
	cnt[org]--;
	cnt[now]++;
}

void change(int l, int r, int t) {
	int &pos = modify[t].pos, &c = modify[t].a;
	if (l <= pos && pos <= r) {
		del(a[pos]);
		add(c);
	}
	swap(c, a[pos]);
}

int main() {
	scanf("%d%d", &n, &qc);
	block_size = pow(n, 2./3.);
	for (int i = 1; i <= n; i++) {
		int A;
		scanf("%d", &A);
		if (!mp[A])
			mp[A] = ++ID;
		a[i] = mp[A];
		id[i] = (i-1) / block_size + 1;
	}
	for (int i = 1; i <= qc; i++) {
		int type;
		scanf("%d", &type);
		if (type == 1) {
			int l, r;
			scanf("%d%d", &l, &r);
			query[++q] = {q, T, l, r};
		}
		else {
			int p, A;
			scanf("%d%d", &p, &A);
			if (!mp[A])
				mp[A] = ++ID;
			modify[++T] = {p, mp[A]};
		}
	}
	sort(query+1, query+q+1);
	int l = query[1].l, r = l-1, t = 0;
	for (int i = 1; i <= q; i++) {
		while (l > query[i].l)
			add(a[--l]);
		while (r < query[i].r)
			add(a[++r]);
		while (l < query[i].l)
			del(a[l++]);
		while (r > query[i].r)
			del(a[r--]);
		while (t < query[i].t)
			change(l, r, ++t);
		while (t > query[i].t)
			change(l, r, t--);
		int res = 0;
		while (cnt[++res]);
		ans[query[i].idx] = res;
	}
	for (int i = 1; i <= q; i++)
		printf("%d\n", ans[i]);
	return 0;
}

需要注意这一题里求 mex 要放到区间指针全部更新完以后, 否则复杂度会爆炸.

还要注意即使是开了 O2 的 unordered_map 也不要尝试对他直接搞, 否则会慢 10 倍, 最好的方法是离散化掉.

如果在维护区间的时候, 增加的操作很好实现, 而删除的操作比较困难, 那么我们可以用回滚莫队, 回滚莫队也称作不删除莫队.

大概思路是这样的(默认块大小为 $\sqrt n$):

因为我们对询问进行了这样的排序: $l$ 在同一块的放在一起, 并且这些询问的 $r$ 递增. 考虑解决这样的一组询问.

设 $l’$ 为当前组 $l$ 所在块的下一块的左端点, 然后每次询问都 $O(\sqrt n)$ 暴力搞 $[l, l’-1]$, 记录这里的数据称为临时信息, 每次询问做完后清空临时信息; 因为 $r$ 递增, 所以从 $[l’, r]$ 这里的数据就可以用总共 $O(n)$ 的时间维护, 记录这里的数据称为永久信息. 做完以后, 对临时信息和永久信息做一个"合并", 得到答案. 这一组数据做完以后, 清空所有信息, 包括临时信息和永久信息, 下一块和上面一样做.

清空的操作就是回滚, 即一步一步退回, 详见例题的代码.

需要注意的是, 临时信息除了在合并时需要用到永久信息里的内容外, 其余一律不能用永久信息或者与永久信息有关的全局变量. 例如答案要分开来统计 tmpcur 和 cur, 最后对两个 cur 合并为当前询问的回答. 还有临时信息在统计的时候不要用到 l 和 r, 因为他们是永久信息里的内容. 用的是 query[i] 的信息, 如 query[i].l, query[i].r.

复杂度也很好算, 每次维护临时信息都是 $O(\sqrt n)$ 的, 所以所有查询就是 $O(q\sqrt n)$. 然后永久信息的话, 每一块是 $O(n)$ 的, 一共 $\sqrt n$ 块. 所以总复杂为 $O(q\sqrt n + n \sqrt n)$.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
int n, q, block_size, id[MAXN], ID = 0, a[MAXN], cnt[MAXN], tmpcnt[MAXN];
LL num[MAXN], ans[MAXN];
unordered_map<int, int> mp;

int get_block_r(int block_id) {
	return min(block_id * block_size, n);
}
int get_block_l(int block_id) {
	return get_block_r(block_id-1) + 1;
}

struct Query {
	int idx, l, r;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? r < Q.r : id[l] < id[Q.l];
	}
} query[MAXN];

int main() {
	scanf("%d%d", &n, &q);
	block_size = sqrt(n);
	for (int i = 1; i <= n; i++) {
		int x;
		scanf("%d", &x);
		if (!mp[x])
			mp[x] = ++ID;
		num[a[i] = mp[x]] = x;
		id[i] = (i-1) / block_size + 1;
	}
	for (int i = 1; i <= q; i++) {
		scanf("%d%d", &query[i].l, &query[i].r);
		query[i].idx = i;
	}
	sort(query+1, query+1+q);
	int l = 1, r = l-1;
	LL mx = -1;
	for (int i = 1; i <= q; i++) {
		if (id[query[i].l] != id[query[i-1].l]) {
			mx = -1;
			while (r >= l)
				--cnt[a[r--]];
			l = get_block_l(id[query[i].l]+1);
			r = l-1;
		}
		while (r < query[i].r) {
			cnt[a[++r]]++;
			mx = max(mx, cnt[a[r]] * num[a[r]]);
		}
		LL tmpmx = -1;
		for (int j = query[i].l; j <= min(query[i].r, get_block_r(id[query[i].l])); j++)
			tmpmx = max(tmpmx, (++tmpcnt[a[j]] + cnt[a[j]]) * num[a[j]]);
		for (int j = query[i].l; j <= min(query[i].r, get_block_r(id[query[i].l])); j++)
			tmpcnt[a[j]]--;
		ans[query[i].idx] = max(tmpmx, mx);
	}
	for (int i = 1; i <= q; i++)
		printf("%lld\n", ans[i]);
	return 0;
}

2021 牛客多校第五场 I.Interval Queries

用括号序, 欧拉序搞成序列就行了.

可以维护链或者子树的信息.

详见: 树的序列化

对于路径的信息, 这里需要做一点点变化. 由于莫队维护的东西如不同的颜色个数, 不具有加和性, 所以不能用u到根的权值加上v到根的权值减去lca到根的权值. 要稍微做一点点改变. 还是用括号序, 只不过出的时候权值还是计正值. 括号序有个性质, 一个点进出构成的区间, 要么包含另一个区间, 要么和另一个区间相离. 利用这个性质加上一进一出的性质可知:

  1. 如果lca=x, 那么 $[in_x, in_y]$ 就包含恰好一个在这条链的信息, 以及两个不在这条链上的信息. 两个的统计一下让他们抵消即可.
  2. 如果x,y没有祖先关系, 假设in[x] < in[y]那么 $[out_x, in_y]$ 就恰好包含除lca外, 一个路径上的信息, 两次不在这条路径上的信息. 统计的时候抵消两个的贡献, 再加上lca的贡献就是答案.
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
int n, q, a[MAXN], m, org[MAXN<<1];
int sz[MAXN], son[MAXN], fa[MAXN], dep[MAXN];
int idx = 0, in[MAXN], out[MAXN], top[MAXN];
int id[MAXN<<1], block_size;
int cnt[MAXN], app[MAXN], cur, ans[MAXQ];
int ID = 0;
unordered_map<int, int> mp;
VI G[MAXN];

void dfs1(int u, int f) {
	sz[u] = 1;
	fa[u] = f;
	dep[u] = dep[f] + 1;
	int mx = 0;
	for (int v : G[u]) if (f != v) {
		dfs1(v, u);
		if (mx < sz[v]) {
			mx = sz[v];
			son[u] = v;
		}
		sz[u] += sz[v];
	}
}

void dfs2(int u, int f, int t) {
	in[u] = ++idx;
	org[idx] = u;
	top[u] = t;
	if (son[u])
		dfs2(son[u], u, t);
	for (int v : G[u]) if (v != f && v != son[u])
		dfs2(v, u, v);
	out[u] = ++idx;
	org[idx] = u;
}

int get_lca(int u, int v) {
	while(top[u] != top[v]) {
		if (dep[top[u]] > dep[top[v]])
			u = fa[top[u]];
		else
			v = fa[top[v]];
	}
	return dep[u] < dep[v] ? u : v;
}

struct Query {
	int idx, l, r, lca;
	bool operator < (const Query &Q) const {
		return id[l] == id[Q.l] ? r < Q.r : id[l] < id[Q.l];
	}
} query[MAXQ];

void sub(int clr) {
	if (--cnt[clr] == 0)
		cur--;
}

void pls(int clr) {
	if (++cnt[clr] == 1)
		cur++;
}

void add(int pos) {
	int u = org[pos];
	int clr = a[u];
	if (app[u])
		sub(clr);
	else
		pls(clr);
	app[u] ^= 1;
}

void del(int pos) {
	add(pos);
}

int main() {
	scanf("%d%d", &n, &q);
	for (int i = 1; i <= n; i++) {
		int x;
		scanf("%d", &x);
		if (!mp[x])
			mp[x] = ++ID;
		a[i] = mp[x];
	}
	for (int i = 1; i < n; i++) {
		int u, v;
		scanf("%d%d", &u, &v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs1(1, 0);
	dfs2(1, 0, 1);
	m = 2 * n;
	block_size = sqrt(m);
	for (int i = 1; i <= m; i++)
		id[i] = (i-1) / block_size;
	for (int i = 1; i <= q; i++) {
		int x, y;
		scanf("%d%d", &x, &y);
		int lca = get_lca(x, y);
		if (lca == y)
			swap(x, y);
		if (lca == x)
			query[i] = Query{i, in[x], in[y], 0};
		else {
			if (in[x] > in[y])
				swap(x, y);
			query[i] = Query{i, out[x], in[y], lca};
		}
	}

	sort(query+1, query+1+q);
	int l = query[1].l, r = l - 1;
	for (int i = 1; i <= q; i++) {
		while (l > query[i].l)
			add(--l);
		while (r < query[i].r)
			add(++r);
		while (l < query[i].l)
			del(l++);
		while (r > query[i].r)
			del(r--);
		ans[query[i].idx] = cur + (query[i].lca && cnt[a[query[i].lca]] == 0);
	}
	for (int i = 1; i <= q; i++)
		printf("%d\n", ans[i]);
	return 0;
}