[Codeforces1011] Round #499 (Div. 2)(树转RMQ+莫队 or 启发式合并)

题目链接:http://codeforces.com/contest/600/problem/E
给一棵树,问从根节点开始每个节点的子树(包含这个节点)中的众数的和。

一看到子树上的某种计数我就会非常自然地想到树转RMQ再做一些区间操作,当然这个题也不例外。。。

于是考虑把这棵树的查询转成区间查询,要查每一个节点的子树,那么就是有$n​$次查询。线段树是解决不了区间众数问题的,于是就想到去分块。但是这题的意思是众数如果出现多次那就要把他们都加起来,所以我们在维护每个数出现的次数的同时,还要维护每个次数对应的数字之和。更新的时候注意加加减减就可以了,还有就是每次更新当前最大次数$ret​$的时候遇到remove要判断一下当前这个出现次数是否已经没有数字了,遇到这种情况的时候,因为每次出现次数都-1,于是我们直接把$ret-1​$就行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <bits/stdc++.h>
using namespace std;

using LL = long long;
using PLL = pair<LL, LL>;
using Query = struct { LL l, r, ret, id; };

inline bool scan_d(LL &num) {
char in;bool IsN=false;
in=getchar();
if(in==EOF) return false;
while(in!='-'&&(in<'0'||in>'9')) in=getchar();
if(in=='-'){ IsN=true;num=0;}
else num=in-'0';
while(in=getchar(),in>='0'&&in<='9'){
num*=10,num+=in-'0';
}
if(IsN) num=-num;
return true;
}

signed main() {
// freopen("in", "r", stdin);
LL u, v, n, sid, sz;
vector<vector<LL>> G;
vector<LL> c, be, w;
vector<Query> q;
unordered_map<LL, LL> vis, tot;
while(scan_d(n)) {
G = vector<vector<LL>>(n+1, vector<LL>(0));
c = vector<LL>(n+1, 0);
w = vector<LL>(n+1, 0);
q = vector<Query>(n+1);
sz = LL(sqrt(n)); sid = 0;
vis.clear(); tot.clear();
for(LL i = 1; i <= n; i++) {
scan_d(c[i]);
tot[0] += c[i];
be.emplace_back(i / sz);
}
for(LL i = 0; i < n - 1; i++) {
scan_d(u);
scan_d(v);
G[u].emplace_back(v);
G[v].emplace_back(u);
}
function<void(LL, LL)> dfs = [&](LL u, LL p) {
q[u].l = ++sid; w[sid] = c[u]; q[u].id = u;
for(LL i = 0; i < G[u].size(); i++) {
LL& v = G[u][i];
if(v == p) continue;
dfs(v, u);
}
q[u].r = sid;
};
dfs(1, -1);
sort(q.begin(), q.end(), [&](Query a, Query b) {
return be[a.l] == be[b.l] ? a.r < b.r : a.l < b.l;
});
function<void(LL, LL&)> add = [&](LL x, LL& ret) {
tot[vis[x]] -= x;
vis[x]++;
tot[vis[x]] += x;
ret = max(ret, vis[x]);
};
function<void(LL, LL&)> remove = [&](LL x, LL& ret) {
if(ret == vis[x]) {
if(tot[vis[x]] - x == 0) {
ret = vis[x] - 1;
}
}
tot[vis[x]] -= x;
vis[x]--;
tot[vis[x]] += x;
};
LL L = 1, R = 0;
LL ret = 0;
for(LL i = 1; i <= n; i++) {
while(L < q[i].l) { remove(w[L], ret); L++; }
while(L > q[i].l) { L--; add(w[L], ret); }
while(R < q[i].r) { R++; add(w[R], ret); }
while(R > q[i].r) { remove(w[R], ret); R--; }
// LL p = 0;
// printf("(%lld, %lld) : ", q[i].l,q[i].r);
// for(auto x : vis) {
// printf("%lld %lld, ", x.first, x.second);
// p = max(p, x.second);
// }
// printf("\n");
// for(auto x : vis) {
// if(x.second == p) q[i].ret += x.first;
// }
// cout << p << " " << ret << endl;
q[i].ret = tot[ret];
}
sort(q.begin(), q.end(), [=](Query a, Query b){
return a.id < b.id;
});
for(LL i = 1; i <= n; i++) {
printf("%lld%c", q[i].ret, " \n"[i==n]);
}
}
return 0;
}

这题其实就是想教大家在维护每个数出现的次数的同时,还要维护每个次数对应的数字之和。我们冷静分析以后,发现这个题可以直接DFS。

我们希望在每次DFS的时候把子树中的数字出现情况merge到当前这个父亲上,考虑对每个点维护上述两个数组,每次把小的树更新到大的树上,merge的具体操作跟上述莫队的更新一样。

如何保证小集合merge到大集合上的复杂度?

Tutorial是这么解释的:“every time when vertex v will be moved from one map to another the size of the new map will be at least two times larger. So each vertex can be moved not over than $\log n$ times.”

发现merge的时候大集合至少是小集合的2倍,实际上就是说每次每个集合被遍历的次数都会减半,相当于最坏情况下每个点被遍历$\log n$次,于是merge的复杂度就是$n\log n $. 整体复杂度就是$O(n\log^2 n)$.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <bits/stdc++.h>
using namespace std;

using LL = long long;
const LL maxn = 100100;
LL n, c[maxn], id[maxn];
vector<LL> G[maxn];
map<LL, LL> cnt[maxn], tot[maxn];
LL ret[maxn];

void gao(LL u, LL v) {
if(cnt[id[u]].size() < cnt[id[v]].size()) swap(id[u], id[v]);
for(auto x : cnt[id[v]]) {
tot[id[u]][cnt[id[u]][x.first]] -= x.first;
cnt[id[u]][x.first] += x.second;
tot[id[u]][cnt[id[u]][x.first]] += x.first;
}
}

void dfs(LL u, LL p) {
for(LL i = 0; i < G[u].size(); i++) {
LL& v = G[u][i];
if(v == p) continue;
dfs(v, u);
gao(u, v);
}
ret[u] = tot[id[u]].rbegin()->second;
}

signed main() {
// freopen("in", "r", stdin);
LL u, v;
while(~scanf("%I64d", &n)) {
memset(ret, 0, sizeof(ret));
for(LL i = 1; i <= n; i++) {
cnt[i].clear();
tot[i].clear();
G[i].clear();
}
for(LL i = 1; i <= n; i++) {
scanf("%I64d", &c[i]);
id[i] = i;
cnt[i][c[i]] = 1;
tot[i][1] = c[i];
}
for(LL i = 0; i < n - 1; i++) {
scanf("%I64d%I64d",&u,&v);
G[u].emplace_back(v);
G[v].emplace_back(u);
}
dfs(1, -1);
for(LL i = 1; i <= n; i++) {
printf("%I64d%c", ret[i], " \n"[i==n]);
}
}
return 0;
}