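In this HackerRank Tree Splitting problem, we are given a tree whose vertices are numbered from 1 to n and must process m queries. Each query names a vertex (XOR-encoded with the previous answer, which starts at 0); we print the size of the connected component containing that vertex and then remove the vertex from the tree.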
Problem solution in Java.
import java.io.*;
import java.util.*;

/**
 * HackerRank "Tree Splitting": repeatedly reports the size of the component containing a vertex
 * and then removes that vertex from the forest.
 *
 * <p>Every component is kept as an implicit treap (ordered by position, not by key) over its Euler
 * tour. Vertex {@code v} owns an entry node {@code pre[v]} and an exit node {@code post[v]}, so the
 * subtree of {@code v} is the contiguous tour range from {@code pre[v]} to {@code post[v]} and a
 * component of k vertices is a treap holding 2k nodes. Cutting the edge above {@code v} splits that
 * range out of its treap and joins the two remaining pieces back together.
 */
public class Solution {
    /** Xorshift generator state for treap priorities; xorshift requires it to be non-zero. */
    static long x = 1;

    /**
     * Returns the next pseudo-random treap priority (Marsaglia 64-bit xorshift, triple 13/7/17).
     *
     * <p>The right shift must be unsigned: a signed {@code x ^= x >> 17} always clears bit 63, so
     * each call would discard a bit of state and the step would no longer be a permutation of the
     * state space. The method keeps its historical name for compatibility with existing callers.
     *
     * @return the new generator state, used directly as a priority
     */
    static long marsagliaXor32() {
        x ^= x << 13;
        x ^= x >>> 7;
        x ^= x << 17;
        return x;
    }

    /** One Euler-tour entry (the entry or exit visit of a vertex) stored as a treap node. */
    static class Node {
        /** Number of nodes in the subtree rooted at this node. */
        int size = 1;
        /** Heap priority: {@link Solution#join} keeps the smaller priority nearer the root. */
        long pri = marsagliaXor32();
        Node l = null;
        Node r = null;
        /** Parent link, used to walk up to the root and to compute a node's position. */
        Node p = null;

        /**
         * Recomputes {@code size} from the children and points both children's parent links at
         * this node. This node's own parent link is left untouched.
         *
         * @return this node
         */
        Node mconcat() {
            this.size = size(l) + 1 + size(r);
            if (l != null) {
                l.p = this;
            }
            if (r != null) {
                r.p = this;
            }
            return this;
        }
    }

    /** Returns the number of nodes in treap {@code x}, or 0 when {@code x} is {@code null}. */
    static int size(Node x) {
        return x != null ? x.size : 0;
    }

    /** Follows parent links from {@code x} up to the root of its treap. */
    static Node root(Node x) {
        while (x.p != null) {
            x = x.p;
        }
        return x;
    }

    /** Returns the 0-based position of {@code x} in the in-order sequence of its treap. */
    static long orderOf(Node x) {
        long r = size(x.l);
        while (x.p != null) {
            // Coming up from a right child: the parent and its whole left subtree precede x.
            if (x.p.r == x) {
                r += size(x.p.l) + 1;
            }
            x = x.p;
        }
        return r;
    }

    /**
     * Concatenates two treaps so that every node of {@code x} precedes every node of {@code y}.
     *
     * @param x left treap, may be {@code null}
     * @param y right treap, may be {@code null}
     * @return root of the concatenated treap
     */
    static Node join(Node x, Node y) {
        if (x == null) {
            return y;
        }
        if (y == null) {
            return x;
        }
        if (x.pri < y.pri) {
            x.r = join(x.r, y);
            return x.mconcat();
        } else {
            y.l = join(x, y.l);
            return y.mconcat();
        }
    }

    /** Depth of each vertex below the DFS root (the root keeps its initial value, 0). */
    static long[] dep;
    /** 0-based adjacency lists of the input tree. */
    static List<Integer>[] es;
    /** Entry tour node of each vertex. */
    static Node[] pre;
    /** Exit tour node of each vertex. */
    static Node[] post;
    /** Treap that accumulates the whole Euler tour while {@link #dfs} runs. */
    static Node tr = null;

    /** Explicit-stack frame for {@link #dfs}: a vertex, its parent, and whether it is being entered. */
    static class NodeDfs {
        int u;
        int p;
        boolean start = true;

        public NodeDfs(int u, int p) {
            this.u = u;
            this.p = p;
        }
    }

    /**
     * Appends the Euler tour of the tree rooted at {@code u} to {@link #tr}, creating
     * {@code pre}/{@code post} nodes and filling {@code dep} for every descendant of {@code u}.
     * Uses an explicit stack so deep trees cannot overflow the call stack; children are therefore
     * visited in reverse adjacency order.
     *
     * @param u root vertex (0-based)
     * @param p parent of {@code u}, or -1 when {@code u} is the tree root
     */
    static void dfs(int u, int p) {
        Deque<NodeDfs> stack = new ArrayDeque<>();
        stack.push(new NodeDfs(u, p));
        while (!stack.isEmpty()) {
            NodeDfs frame = stack.peek();
            if (frame.start) {
                pre[frame.u] = new Node();
                tr = join(tr, pre[frame.u]);
                for (int v : es[frame.u]) {
                    if (v != frame.p) {
                        dep[v] = dep[frame.u] + 1;
                        stack.push(new NodeDfs(v, frame.u));
                    }
                }
                // Frame stays on the stack; it is seen again once all its children are finished.
                frame.start = false;
            } else {
                post[frame.u] = new Node();
                tr = join(tr, post[frame.u]);
                stack.pop();
            }
        }
    }

    /**
     * Splits treap {@code x} into its first {@code k} nodes and the remaining nodes.
     *
     * @param x treap to split, may be {@code null}
     * @param k number of nodes that go to the left part
     * @param l ignored; retained only for source compatibility with older callers
     * @param r ignored; retained only for source compatibility with older callers
     * @return {@code {left, right}}; both roots have a {@code null} parent link
     */
    static Node[] split(Node x, long k, Node l, Node r) {
        return split(x, k);
    }

    /**
     * Splits treap {@code x} into its first {@code k} nodes and the remaining nodes.
     *
     * @param x treap to split, may be {@code null}
     * @param k number of nodes that go to the left part
     * @return {@code {left, right}}; both roots have a {@code null} parent link
     */
    static Node[] split(Node x, long k) {
        if (x == null) {
            return new Node[] {null, null};
        }
        long throughRoot = size(x.l) + 1;
        Node left;
        Node right;
        if (k < throughRoot) {
            // The cut falls inside the left subtree; x and its right subtree go right.
            Node[] parts = split(x.l, k);
            x.l = parts[1];
            left = parts[0];
            right = x;
        } else {
            // x and its left subtree go left; keep splitting the right subtree.
            Node[] parts = split(x.r, k - throughRoot);
            x.r = parts[0];
            left = x;
            right = parts[1];
        }
        x.mconcat();
        x.p = null;
        return new Node[] {left, right};
    }

    /**
     * Cuts tree edge {@code (u, v)}: the subtree of the deeper endpoint becomes its own treap.
     * If that subtree is already detached, the forest is left unchanged.
     */
    static void cut(int u, int v) {
        int child = dep[v] < dep[u] ? u : v;
        long il = orderOf(pre[child]);
        long ir = orderOf(post[child]) + 1;
        // [0, il) | [il, ir) = child's subtree | [ir, end)
        Node[] outer = split(root(pre[child]), ir);
        Node[] inner = split(outer[0], il);
        join(inner[0], outer[1]);
    }

    /**
     * Reads one test case from {@code br} and writes one answer per line to {@code out}.
     *
     * <p>Input: {@code n}, then {@code n-1} lines {@code "u v"} (1-based edges), then {@code m},
     * then {@code m} lines each holding an encoded vertex; the queried vertex is the previous answer
     * (initially 0) XOR the encoded value. Resets all static state, so it may be called repeatedly.
     *
     * @throws IOException if reading or writing fails
     */
    static void run(BufferedReader br, Writer out) throws IOException {
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        dep = new long[n];
        @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
        List<Integer>[] adjacency = new List[n];
        es = adjacency;
        pre = new Node[n];
        post = new Node[n];
        tr = null;
        for (int i = 0; i < n; i++) {
            es[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken()) - 1;
            int v = Integer.parseInt(st.nextToken()) - 1;
            es[u].add(v);
            es[v].add(u);
        }
        dfs(0, -1);
        st = new StringTokenizer(br.readLine());
        int queriesCount = Integer.parseInt(st.nextToken());
        int result = 0;
        for (int i = 0; i < queriesCount; i++) {
            st = new StringTokenizer(br.readLine());
            int u = (result ^ Integer.parseInt(st.nextToken())) - 1;
            // Each vertex contributes two tour nodes, so the component size is half the treap size.
            result = size(root(pre[u])) / 2;
            out.write(String.valueOf(result));
            if (i != queriesCount - 1) {
                out.write('\n');
                // Removing u = cutting all of its edges; skipped after the final query.
                for (int v : es[u]) {
                    cut(u, v);
                }
            }
        }
        out.write('\n');
    }

    /** Entry point: reads standard input and writes answers to the file named by OUTPUT_PATH. */
    public static void main(String[] args) throws IOException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                BufferedWriter bw =
                        new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
            run(br, bw);
        }
    }
}
Problem solution in C++.
#include <bits/stdc++.h> using namespace std; struct node { int size = 1; node *lch = nullptr; node *rch = nullptr; node *parent = nullptr; }; unsigned xor32() { static unsigned z = time(NULL); z ^= z << 13; z ^= z >> 17; z ^= z << 5; return z; } int size(node *x) { return x == nullptr ? 0 : x->size; } node *push(node *x) { x->size = 1 + size(x->lch) + size(x->rch); x->parent = nullptr; if (x->lch != nullptr) x->lch->parent = x; if (x->rch != nullptr) x->rch->parent = x; return x; } node *merge(node *x, node *y) { if (x == nullptr) return y; if (y == nullptr) return x; if (xor32() % (size(x) + size(y)) < size(x)) { x = push(x); x->rch = merge(x->rch, y); return push(x); } else { y = push(y); y->lch = merge(x, y->lch); return push(y); } } pair<node *, node *> split(node *x, int k) { if (x == nullptr) return{ nullptr, nullptr }; x = push(x); if (size(x->lch) >= k) { auto p = split(x->lch, k); x->lch = p.second; return{ p.first, push(x) }; } else { auto p = split(x->rch, k - size(x->lch) - 1); x->rch = p.first; return{ push(x), p.second }; } } node *root(node *x) { if (x->parent == nullptr) return x; return root(x->parent); } int index_of(node *x) { int result = -1; bool l = true; while (x != nullptr) { if (l) result += 1 + size(x->lch); if (x->parent == nullptr) break; l = x->parent->rch == x; x = x->parent; } return result; } vector<int> g[200200]; int depth[200200]; node *L[200200]; node *R[200200]; node *tr = nullptr; void dfs(int curr, int prev) { L[curr] = new node(); tr = merge(tr, L[curr]); for (int next : g[curr]) if (next != prev) { depth[next] = depth[curr] + 1; dfs(next, curr); } R[curr] = new node(); tr = merge(tr, R[curr]); } void cut(int u, int v) { if (depth[u] < depth[v]) swap(u, v); int l = index_of(L[u]); int r = index_of(R[u]); node *rt = root(L[u]); auto x = split(rt, r + 1); auto y = split(x.first, l); merge(y.first, x.second); } int main() { int n; cin >> n; for (int i = 0; i < n - 1; i++) { int u, v; scanf("%d %d", &u, &v); u--; v--; 
g[u].push_back(v); g[v].push_back(u); } int m; cin >> m; dfs(0, -1); int ans = 0; for (int i = 0; i < m; i++) { int x; scanf("%d", &x); int v = (ans ^ x) - 1; ans = size(root(L[v])) / 2; for (int u : g[v]) cut(u, v); printf("%dn", ans); } }
Problem solution in C.
#include <stdlib.h> #include <stdio.h> struct Set { int count; }; typedef struct Set Set; struct node{ int number; struct node * parent; struct node * next; struct node * prev; struct node * first_child; Set * set; }; typedef struct node node; void print_children(node * n){ node * child = n->first_child; while(child){ printf("%dn", child->number); child = child->next; } } void add_child(node * n, node * c){ node * cur = n->first_child; n->first_child = c; if (cur){ cur->prev = c; c->next = cur; } } void fill_children(node * root, node ** nodes, node ** result_nodes){ node * repr = nodes[root->number]; if(repr == 0){ return; } node * child = repr->first_child; while(child){ if (result_nodes[child->number] != 0){ child = child->next; continue; } node * c = calloc(1, sizeof(node)); c->number = child->number; c->parent = root; result_nodes[c->number] = c; add_child(root, c); fill_children(c, nodes, result_nodes); child = child->next; } } void compute_below(node * root, Set * set) { if (set == 0) { set = calloc(1, sizeof(set)); } root->set = set; set->count++; node * child = root -> first_child; while(child){ compute_below(child, set); child = child->next; } } void remove_node(node * item) { // subtract_below(item, item->below+1); int everyChild = item->parent != 0; node * child = item->first_child; int childCount = 0; int toRemove = 1; while (child) { childCount++; if (everyChild || childCount > 1) { compute_below(child, 0); toRemove += child->set->count; } child->parent = 0; child = child->next; } item->set->count -= toRemove; node * parent = item->parent; if(parent){ if(parent->first_child == item){ parent->first_child = item->next; } if(item->next){ item->next->prev = item->prev; } if(item->prev){ item->prev->next = item->next; } } } int main(int argc, char **argv){ int n; scanf("%dn", &n); int i = 0; node ** nodes = calloc(n+1, sizeof(node *)); for(i = 0; i < n-1; i++){ int a,b; scanf("%d %dn", &a, &b); node * node_a = nodes[a]; if(node_a == 0) { node_a = calloc(1, 
sizeof(node)); node_a->number = a; nodes[a] = node_a; } node * x = calloc(1, sizeof(node)); x->number = b; add_child(node_a,x); node * node_b = nodes[b]; if(node_b == 0){ node_b = calloc(1, sizeof(node)); node_b->number = b; nodes[b] = node_b; } x = calloc(1, sizeof(node)); x->number = a; add_child(node_b, x); } node * root = calloc(1, sizeof(node)); root->number = 1; node ** result_nodes = calloc(n+1, sizeof(node *)); result_nodes[1] = root; fill_children(root, nodes, result_nodes); compute_below(result_nodes[1], 0); int ans = 0; int num_queries; scanf("%dn", &num_queries); for(i = 0; i < num_queries; i++){ int m; scanf("%dn", &m); int q = m^ans; node * n = result_nodes[q]; ans = n->set->count; printf("%dn", ans); remove_node(n); } return 0; }