Luogu P5127 子异和

题目描述

https://www.luogu.com.cn/problem/P5127

简要题意:给定一个棵有 $n$ 个点的无根树,现在有 $m$ 次操作,操作有两种,给定 $x,y$ 求 $x$ 到 $y$ 的简单路径上所有点组成的可重集合的子异和;给定 $x,y,z$ 将 $x$ 到 $y$ 上每个点异或上 $z$,其中集合 $S$ 子异和的定义为子集 $S$ 的所有子集的异或和的和

$n,m\le 2\times 10^5$

Solution

我们考虑集合 $S=\lbrace a_1,a_2\cdots,a_n\rbrace$ 的子异和,我们按位考虑对于第 $i$ 个二进制位,我们不妨假设有 $x$ 个 $1$ 和 $y$ 个 $0$,我们有 $x+y=n$,容易得到异或和为 $1$ 的子集个数为 $2^y\sum_{i=0}^x\binom{x}{i}[i\equiv 1(\bmod 2)]=2^{x+y-1}=2^{n-1}$,另外需要注意如果 $x$ 为 $0$,则答案为 $0$,这样容易得到答案就是 $2^{n-1}$ 乘上 $S$ 的所有数的或

我们首先树剖,那么现在只需要支持求区间或和区间异或,我们用线段树维护区间或、区间与和区间异或标记,能推出如下的转移

1
2
3
4
5
6
inline void Update(int i, int v) {
int a = T[i].orsum, b = T[i].andsum;
T[i].orsum = a & ~v | ~b & v;
T[i].andsum = b & ~v | ~a & v;
T[i].tag ^= v;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <iostream>
#define maxn 200010
#define ll long long
using namespace std;

const int p = 1000000007;

ll pow_mod(ll x, ll n) {
ll s = 1;
for (; n; n >>= 1, x = x * x % p)
if (n & 1) s = s * x % p;
return s;
}

int n, m, a[maxn];

struct Edge {
int to, next;
} e[maxn * 2]; int c1, head[maxn];
inline void add_edge(int u, int v) {
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
}

int dep[maxn], son[maxn], sz[maxn], f[maxn];
void dfs(int u, int fa) {
int Max = 0; sz[u] = 1;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa) continue;
f[v] = u; dep[v] = dep[u] + 1;
dfs(v, u); sz[u] += sz[v];
if (sz[v] > Max) Max = sz[v], son[u] = v;
}
}

int top[maxn], id[maxn], c2, bl[maxn];
void dfs(int u, int fa, int topf) {
top[u] = topf; id[u] = ++c2; bl[c2] = u;
if (son[u]) dfs(son[u], u, topf);
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to; if (v == fa || v == son[u]) continue;
dfs(v, u, v);
}
}

#define lc i << 1
#define rc i << 1 | 1
struct Seg {
int orsum, andsum, tag;
} T[maxn * 4];
inline void maintain(int i) {
T[i].orsum = T[lc].orsum | T[rc].orsum;
T[i].andsum = T[lc].andsum & T[rc].andsum;
}

void build(int i, int l, int r) {
if (l == r) return T[i] = { a[bl[l]], a[bl[l]] }, void();
int m = l + r >> 1;
build(lc, l, m); build(rc, m + 1, r);
maintain(i);
}

inline void Update(int i, int v) {
int a = T[i].orsum, b = T[i].andsum;
T[i].orsum = a & ~v | ~b & v;
T[i].andsum = b & ~v | ~a & v;
T[i].tag ^= v;
}

inline void pushdown(int i) {
int &tag = T[i].tag; if (!tag) return ;
Update(lc, tag); Update(rc, tag);
tag = 0;
}

void update(int i, int l, int r, int L, int R, int v) {
if (l > R || r < L) return ;
if (L <= l && r <= R) return Update(i, v);
int m = l + r >> 1; pushdown(i);
update(lc, l, m, L, R, v); update(rc, m + 1, r, L, R, v);
maintain(i);
}

int query(int i, int l, int r, int L, int R) {
if (l > R || r < L) return 0;
if (L <= l && r <= R) return T[i].orsum;
int m = l + r >> 1; pushdown(i);
return query(lc, l, m, L, R) | query(rc, m + 1, r, L, R);
}

int query(int x, int y) {
int orsum = 0, len = 0;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
orsum |= query(1, 1, n, id[top[x]], id[x]);
len += dep[x] - dep[top[x]] + 1; x = f[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
orsum |= query(1, 1, n, id[x], id[y]);
len += dep[y] - dep[x] + 1;
return orsum * pow_mod(2, len - 1) % p;
}

void update(int x, int y, int z) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, 1, n, id[top[x]], id[x], z);
x = f[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
update(1, 1, n, id[x], id[y], z);
}

inline void solve_1() {
int x, y; cin >> x >> y;
cout << query(x, y) << "\n";
}

inline void solve_2() {
int x, y, z; cin >> x >> y >> z;
update(x, y, z);
}

int main() { fill(head, head + maxn, -1);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);

cin >> n >> m;
for (int i = 1; i < n; ++i) {
int x, y; cin >> x >> y;
add_edge(x, y); add_edge(y, x);
}
for (int i = 1; i <= n; ++i) cin >> a[i];
dep[1] = 1; dfs(1, 0); dfs(1, 0, 1); build(1, 1, n);
for (int i = 1; i <= m; ++i) {
int opt; cin >> opt;
if (opt == 1) solve_1();
else solve_2();
}
return 0;
}