In the HackerEarth problem "Breaking Edges", you are given a tree (an undirected connected graph with no cycles) consisting of N nodes and N − 1 edges. Each node i has a number v_i associated with it. You may break any edge of the tree; doing so splits it into two trees, each of which is part of the original tree.
Let treeOr denote the bitwise OR of the numbers written on all nodes of a tree.
Find how many edges have the property that, if that edge is removed from the tree, the two resulting trees have equal treeOr.
Solution to the HackerEarth Breaking Edges problem (C++).
#include <bits/stdc++.h>
using namespace std;
// Size of every node-indexed table: room for n <= 200000 plus some slack.
constexpr int MAXN = 200000 + 55;
// Adjacency lists of the tree, indexed by node label (main uses labels 1..n).
vector<int> adj[MAXN];
// cnt[j][v]: main sets it to 1 when node v's value has bit j set; DFS then
// accumulates it into the number of such nodes in v's subtree.
int cnt[22][MAXN];
// POWER[j] == 1 << j; filled in at the start of main.
int POWER[22];
// Parent links of the disjoint-set forest main uses to reject cycles.
int dsu[MAXN + 55];
// Turns the per-node bit indicators in cnt[j][v] into subtree totals.
// Afterwards cnt[j][v] holds, for each bit j, the sum of the initial cnt[j][]
// values over v's subtree. The tree is rooted at `s`; `par` is treated as
// lying outside it and is never entered or updated.
// Iterative on purpose: a recursive walk nests once per tree level, and a
// path-shaped tree of ~2e5 nodes can exhaust the default call stack.
void DFS(int s, int par = -1) {
    // Pre-order list of (node, parent) pairs; each node precedes all of its
    // descendants.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> todo;
    todo.emplace_back(s, par);
    while (!todo.empty()) {
        const std::pair<int, int> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (int next : adj[cur.first]) {
            if (next != cur.second) todo.emplace_back(next, cur.first);
        }
    }
    // Walking the list backwards completes every subtree before its parent
    // is reached, which reproduces post-order accumulation. Entry 0 is `s`
    // itself, whose totals must not be pushed into `par`, so it is skipped.
    for (std::size_t i = order.size(); i-- > 1;) {
        const int v = order[i].first;
        const int p = order[i].second;
        for (int j = 0; j < 22; j++) cnt[j][p] += cnt[j][v];
    }
}
int ans = 0;

// True when cutting the edge between `child` and its parent leaves two trees
// with equal bitwise OR.
// cnt[j][1] is used as the whole-tree count of bit j, which holds when DFS was
// run from node 1. Every bit present in the tree must then occur both inside
// child's subtree (inside > 0) and outside it (inside < total).
static bool cutBalancesOr(int child) {
    for (int j = 0; j < 22; j++) {
        const int total = cnt[j][1];
        const int inside = cnt[j][child];
        if (total != 0 and (inside == total or inside == 0)) return false;
    }
    return true;
}

// Visits every edge below `s` (never stepping back into `par`). For each edge
// whose removal yields two trees with equal treeOr, it increments `ans`.
// Iterative on purpose: a recursive walk nests once per tree level, and a
// path-shaped tree of ~2e5 nodes can exhaust the default call stack.
void dfs(int s, int par = -1) {
    std::vector<std::pair<int, int>> todo;  // (node, parent) still to expand
    todo.emplace_back(s, par);
    while (!todo.empty()) {
        const int v = todo.back().first;
        const int from = todo.back().second;
        todo.pop_back();
        for (int child : adj[v]) {
            if (child == from) continue;
            if (cutBalancesOr(child)) ans++;
            todo.emplace_back(child, v);
        }
    }
}
// Returns the representative of x's set in the disjoint-set forest `dsu`.
// It also compresses the path, so every node visited now points directly at
// the representative.
// Iterative on purpose: main links roots without union-by-rank, so a parent
// chain can reach length n (e.g. for a path input). A recursive find would
// nest that deep.
int findRoot(int x) {
    int root = x;
    while (dsu[root] != root) root = dsu[root];
    // Second pass: re-point every node on the path at the root.
    while (x != root) {
        const int next = dsu[x];
        dsu[x] = root;
        x = next;
    }
    return root;
}
// Reads the tree and validates it. Prints how many edges, when removed, leave
// two trees with equal bitwise OR.
// Input: n, then n node values, then n-1 edges "x y" with labels in 1..n.
int main() {
    POWER[0] = 1;
    for (int i = 1; i < 22; i++) POWER[i] = POWER[i - 1] << 1;

    int n;
    cin >> n;
    assert(n >= 2 and n <= 200000);

    for (int i = 1; i <= n; i++) {
        dsu[i] = i;
        int val;
        scanf("%d", &val);
        // Validate the value before decomposing it into per-bit indicators.
        assert(val >= 0 and val <= 1048575);
        for (int j = 0; j < 22; j++) if (POWER[j] & val) cnt[j][i]++;
    }

    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        // Range-check the endpoints before they index adj/dsu, so a bad label
        // trips the assert instead of writing out of bounds.
        assert(x >= 1 and x <= n and y >= 1 and y <= n);
        assert(x != y);
        adj[x].push_back(y);
        adj[y].push_back(x);
        // Endpoints already in the same set would mean this edge closes a cycle.
        const int r1 = findRoot(x);
        const int r2 = findRoot(y);
        assert(r1 != r2);
        dsu[r1] = r2;
    }

    // All nodes must share one representative, i.e. the graph is connected.
    set<int> components;
    for (int i = 1; i <= n; i++) components.insert(findRoot(i));
    assert(static_cast<int>(components.size()) == 1);

    DFS(1);  // cnt[j][v] becomes the count of bit j in v's subtree (root 1)
    dfs(1);  // counts qualifying edges into ans
    cout << ans << '\n';
    return 0;
}
Second solution (Python)
import sys
from collections import defaultdict

maxb = 20


def count_balanced_edges(n, v, edges, maxb=maxb):
    """Count tree edges whose removal leaves two trees with equal bitwise OR.

    n     -- number of nodes, labelled 1..n (the tree is rooted at node 1)
    v     -- sequence of n node values; v[i-1] belongs to node i
    edges -- iterable of (a, b) pairs, the n-1 edges of the tree
    maxb  -- number of low bits examined; values must fit in maxb bits

    Returns the number of qualifying edges.
    """
    # acounts[j]: how many nodes in the whole tree have bit j set.
    acounts = [0] * maxb
    for x in v:
        for j in range(maxb):
            acounts[j] += (x >> j) & 1

    graph = defaultdict(list)
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # BFS from node 1: q is the visiting order, p[x] the parent of node x
    # (0 for the root, -1 while still unvisited).
    p = [-1] * (n + 1)
    q = [1]
    p[1] = 0
    for front in range(n):
        cur = q[front]
        for a in graph[cur]:
            if p[a] == -1:
                q.append(a)
                p[a] = cur

    # counts[i][j]: number of nodes in the subtree of i with bit j set.
    counts = [[]] + [[(v[i - 1] >> j) & 1 for j in range(maxb)]
                     for i in range(1, n + 1)]

    ans = 0
    # Reverse BFS order finishes every child before its parent. Index 0 is
    # the root, which has no parent edge; it is excluded because counting it
    # would add a phantom edge whenever every value is 0.
    for nid in range(n - 1, 0, -1):
        node = q[nid]
        for a in graph[node]:
            if a == p[node]:
                continue
            for b in range(maxb):
                counts[node][b] += counts[a][b]
        # Edge (node, p[node]) qualifies iff every bit present in the tree
        # occurs on both sides: 0 < subtree count < total count.
        if all(y == 0 or 0 < x < y for x, y in zip(counts[node], acounts)):
            ans += 1
    return ans


def main():
    """Read n, the n values and the n-1 edges from stdin; print the answer."""
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    assert 2 <= n <= 200000
    v = [int(t) for t in tokens[1:n + 1]]
    assert all(0 <= x <= 1048575 for x in v)
    rest = [int(t) for t in tokens[n + 1:n + 1 + 2 * (n - 1)]]
    edges = list(zip(rest[0::2], rest[1::2]))
    print(count_balanced_edges(n, v, edges))


if __name__ == "__main__":
    main()