In this HackerRank Kth Ancestor problem, Susan likes to play with graphs, and the tree data structure is one of her favorites. She has designed a problem and wants to know if anyone can solve it: starting from a rooted tree, she sometimes adds or removes a leaf node. Your task is to find the Kth parent (ancestor) of a node at any instant; if that ancestor does not exist, the answer is 0.
Problem solution in Python.
import sys
from collections import defaultdict, deque, namedtuple
from array import array
from itertools import islice

AddNode = namedtuple('AddNode', 'child parent')
RemoveNode = namedtuple('RemoveNode', 'node')
QueryParent = namedtuple('QueryParent', 'node kth')


def log(msg):
    """Write a diagnostic message to stderr (stdout carries the answers)."""
    print(msg, file=sys.stderr)


def find_all_leaf_nodes(tree):
    """Yield every node reachable from the virtual root 0 that has no children.

    Removing a leaf leaves an empty set behind in ``tree.children`` for its
    parent, so emptiness (not key membership) decides whether a node is a leaf.
    """
    queue = {0}
    while queue:
        node = queue.pop()
        kids = tree.children.get(node)  # .get() does not insert into the defaultdict
        if kids:
            queue.update(kids)
        else:
            yield node


def solve_queries(tree, queries):
    """Apply *queries* to *tree* in order, yielding the answer of each QueryParent."""
    for q in queries:
        if isinstance(q, AddNode):
            tree.add_node(q.child, q.parent)
        elif isinstance(q, RemoveNode):
            tree.remove_leaf(q.node)
        elif isinstance(q, QueryParent):
            yield tree.get_kth_parent(q.node, q.kth)


def read_ints(reader):
    """Yield each non-blank line of *reader* as a tuple of ints (blank lines are skipped)."""
    for line in reader:
        fields = line.split()
        if fields:
            yield tuple(int(f) for f in fields)


class Tree(object):
    """Rooted tree supporting leaf insertion/removal and k-th ancestor queries.

    Node 0 is a virtual node at level 0 acting as the parent of the real root;
    a missing ancestor is reported as 0. To shorten long upward walks, a node
    whose level L is a multiple of 10 / 100 / 1000 (with L strictly greater
    than that size) stores a direct pointer to its 10th / 100th / 1000th
    ancestor in ``ten_p`` / ``hundred_p`` / ``thousand_p``.
    """

    def __init__(self):
        self.children = defaultdict(set)  # parent id -> set of child ids
        self.parents = dict()             # child id -> parent id
        self.levels = dict()              # node id -> depth
        self.levels[0] = 0                # virtual root sits at depth 0
        self.ten_p = dict()
        self.hundred_p = dict()
        self.thousand_p = dict()
        self.cache_hits = [0, 0, 0]       # skip-pointer uses: [10, 100, 1000]

    def add_node(self, child, parent):
        """Attach *child* as a leaf under *parent*, which must already be in the tree.

        Raises:
            KeyError: if *parent* is not in the tree.
        """
        level = self.levels[parent] + 1
        self.levels[child] = level
        self.children[parent].add(child)
        self.parents[child] = parent
        # Smallest table first: the 100/1000 pointers are computed by walks
        # that may already use this node's freshly stored shorter pointers.
        for size, table in ((10, self.ten_p), (100, self.hundred_p),
                            (1000, self.thousand_p)):
            if level > size and level % size == 0:
                table[child] = self.get_kth_parent(child, size)

    def remove_leaf(self, node):
        """Detach leaf *node* and drop any skip pointers it owned.

        Raises:
            KeyError: if *node* is not in the tree.
        """
        self.levels.pop(node)
        parent = self.parents.pop(node)
        self.children[parent].remove(node)
        # Skip pointers only point at ancestors, and a leaf is nobody's
        # ancestor, so only the node's own entries need discarding.
        for table in (self.ten_p, self.hundred_p, self.thousand_p):
            table.pop(node, None)

    def get_kth_parent(self, node, max_back):
        """Return the *max_back*-th ancestor of *node*, or 0 if it does not exist."""
        if node not in self.parents:
            return 0
        if self.levels[node] < max_back:
            return 0
        remaining = max_back
        # (jump size, table, cache_hits slot), largest jump first.
        jumps = ((1000, self.thousand_p, 2), (100, self.hundred_p, 1),
                 (10, self.ten_p, 0))
        while node != 0 and remaining != 0:
            for size, table, slot in jumps:
                if remaining > size and node in table:
                    self.cache_hits[slot] += 1
                    node = table[node]
                    remaining -= size
                    break
            else:
                node = self.parents[node]
                remaining -= 1
        return node


def _add_initial_edges(tree, edges):
    """Insert (child, parent) *edges* into *tree* parents-first, whatever their input order.

    Raises:
        ValueError: if some edges are not connected to the virtual root 0.
    """
    pending = defaultdict(list)
    for child, parent in edges:
        pending[parent].append(child)
    queue = deque([0])
    while queue:
        parent = queue.popleft()
        for child in pending.pop(parent, ()):
            tree.add_node(child, parent)
            queue.append(child)
    if pending:
        raise ValueError('edges not connected to root 0 (parents: %s)' % sorted(pending))


def _parse_query(vals):
    """Convert one query line (tuple of ints) into its query namedtuple."""
    if vals[0] == 0:
        # input is "0 parent child"; AddNode stores (child, parent)
        return AddNode(vals[2], vals[1])
    if vals[0] == 1:
        return RemoveNode(vals[1])
    if vals[0] == 2:
        return QueryParent(vals[1], vals[2])
    raise Exception('Do not know how to handle query of type %d in %s' % (vals[0], vals))


def read_instructions(int_lines):
    """Yield (tree, queries) for each test case read from *int_lines* (tuples of ints)."""
    int_lines = iter(int_lines)
    number_cases = next(int_lines)[0]
    for _ in range(number_cases):
        nodes_in_tree = next(int_lines)[0]
        tree = Tree()
        # islice consumes exactly the requested count, so a zero count does
        # not eat the next section's header line.
        _add_initial_edges(tree, islice(int_lines, nodes_in_tree))
        number_queries = next(int_lines)[0]
        queries = [_parse_query(vals) for vals in islice(int_lines, number_queries)]
        yield (tree, queries)


def main():
    """Read all test cases from stdin and print one answer per line."""
    for tree, queries in read_instructions(read_ints(sys.stdin)):
        answers = list(solve_queries(tree, queries))
        if answers:
            sys.stdout.write('\n'.join(map(str, answers)) + '\n')
        log('Cached: %s' % (tree.cache_hits))


if __name__ == '__main__':
    main()
Problem solution in Java.
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * HackerRank "Kth Ancestor": answers k-th ancestor queries on a rooted tree that
 * gains and loses leaves over time, using a binary-lifting ancestor table.
 * Node 0 is the virtual parent of the root; a missing ancestor is printed as 0.
 */
public class Solution {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";

    /** Size of every per-node table; node ids are assumed to be below this bound. */
    static final int MAXN = 100001;
    /** Number of lifting levels: spar[d][v] is the 2^d-th ancestor of v, or -1 if none. */
    static final int LOG = 17;

    static void solve() {
        for (int T = ni(); T >= 1; T--) {
            int m = ni();
            int[] from = new int[m];
            int[] to = new int[m];
            for (int i = 0; i < m; i++) {
                from[i] = ni(); // child
                to[i] = ni();   // parent (0 for the root)
            }
            int[][] spar = new int[LOG][];
            // BFS from the virtual root orients the edges regardless of input order.
            spar[0] = parents(packU(MAXN, from, to), 0);
            // Level-major fill: every row d-1 is complete before row d is built.
            for (int d = 1; d < LOG; d++) {
                spar[d] = new int[MAXN];
                for (int i = 0; i < MAXN; i++) {
                    int mid = spar[d - 1][i];
                    spar[d][i] = mid == -1 ? -1 : spar[d - 1][mid];
                }
            }
            int Q = ni();
            for (int z = 0; z < Q; z++) {
                int type = ni();
                if (type == 0) {
                    // "0 Y X": add leaf X under Y
                    int y = ni(), x = ni();
                    spar[0][x] = y;
                    fillJumps(spar, x);
                } else if (type == 1) {
                    // "1 X": remove leaf X
                    int y = ni();
                    for (int d = 0; d < LOG; d++) {
                        spar[d][y] = -1;
                    }
                } else if (type == 2) {
                    // "2 X K": print the K-th ancestor of X
                    int y = ni(), k = ni();
                    out.println(kthAncestor(spar, y, k));
                }
            }
        }
    }

    /** Rebuild rows 1..LOG-1 for node x from spar[0][x]; x's ancestors must already be complete. */
    private static void fillJumps(int[][] spar, int x) {
        for (int d = 1; d < LOG; d++) {
            int mid = spar[d - 1][x];
            spar[d][x] = mid == -1 ? -1 : spar[d - 1][mid];
        }
    }

    /** Return the k-th ancestor of v, or 0 when it does not exist. */
    private static int kthAncestor(int[][] spar, int v, int k) {
        // Tree depth stays below MAXN < 2^LOG, so a larger k can never be satisfied;
        // without this guard its high bits would be silently ignored.
        if (k < 0 || k >= (1 << LOG)) {
            return 0;
        }
        for (int d = 0; d < LOG && v != -1; d++) {
            if (((k >> d) & 1) != 0) {
                v = spar[d][v];
            }
        }
        return v == -1 ? 0 : v;
    }

    /** Build an undirected adjacency list (packed arrays) over n vertices. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from) p[f]++;
        for (int t : to) p[t]++;
        for (int i = 0; i < n; i++) g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /** BFS from root; returns each vertex's parent, -1 for the root and unreachable vertices. */
    public static int[] parents(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);
        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                }
            }
        }
        return par;
    }

    public static void main(String[] args) throws Exception {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G - S + "ms");
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    /** Buffered single-byte read; returns -1 at end of input. */
    private static int readByte() {
        if (lenbuf == -1) throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0) return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Read the next (optionally negative) decimal int, skipping any non-numeric bytes. */
    private static int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Debug trace, printed only when running against the embedded INPUT. */
    private static void tr(Object... o) {
        if (INPUT.length() != 0) System.out.println(Arrays.deepToString(o));
    }
}
Problem solution in C++.
#include<iostream> #include<cstdio> #include<cmath> #include<string> #include<algorithm> #include<vector> using namespace std; const int MAXN = 100100; const int MAXK = 20; int parent[MAXN][MAXK]; int cntChild[MAXN]; void solve() { for(int i=0; i<MAXN; i++) { for(int j=0; j<MAXK; j++) parent[i][j] = 0; cntChild[i] = 0; } int N; cin>>N; for(int i=0; i<N; i++) { int x,y; // scanf("%d%d",&x,&y); cin>>x>>y; parent[x][0] = y; cntChild[y]++; } for(int i=1; i<MAXK; i++) { for(int v=1; v<MAXN; v++) { parent[v][i] = parent[parent[v][i-1]][i-1]; // if(parent[v][i]!=0) // cout<<v<<" "<<i<<" "<<parent[v][i]<<endl; } } int Q; cin>>Q; for(int _=0; _<Q; _++) { /* for(int i=0; i<20; i++) { cout<<i<<" : "; for(int j=0; j<20; j++) cout<<parent[i][j]<<" "; cout<<endl; }*/ int kind; cin>>kind; // scanf("%d",&kind); if(kind==1) { int x; cin>>x; // scanf("%d",&x); cntChild[parent[x][0]]--; if(cntChild[x]!=0) for(;;); for(int j=0; j<MAXK; j++) parent[x][j] = 0; } if(kind==0) { int x,y; // scanf("%d%d",&y,&x); cin>>y>>x; // if(y==0) for(;;); parent[x][0]= y; cntChild[y]++; for(int i=1; i<MAXK; i++) { parent[x][i] = parent[parent[x][i-1]][i-1]; } } if(kind==2) { int x,k; // scanf("%d%d",&x,&k); cin>>x>>k; while(k!=0) { int t = 1; int cnt = 0; while(t<=k) { t *= 2; cnt++; } t/=2;cnt--; x = parent[x][cnt]; k -= t; // cout<<x<<" "<<k<<endl; } cout<<x<<endl; } } } int main() { int T; cin>>T; for(int i=0; i<T; i++) solve(); return 0; }
Problem solution in C.
#include <assert.h> #include <limits.h> #include <math.h> #include <stdbool.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #define MAXNODE 300000 #define LOGMAX 20 char* readline(); char** split_string(char*); int* kthparent(int n, int** edges, int q, char** queries, int* result_count){ int pow2parent[LOGMAX][MAXNODE]; for(int i = 0; i < LOGMAX; i++){ for(int j = 0; j < MAXNODE; j++){ pow2parent[i][j] = 0; } } for(int i = 0; i < n; i++){ pow2parent[0][edges[i][0]] = edges[i][1]; } for(int i = 1; i < LOGMAX; i++){ for(int j = 0; j < MAXNODE; j++){ pow2parent[i][j] = pow2parent[i - 1][pow2parent[i - 1][j]]; } } int *toreturn = NULL; *result_count = 0; for(int i = 0; i < q; i++){ char** splitquery = split_string(queries[i]); if(queries[i][0] == '0'){ int parent = atoi(splitquery[1]); int leaf = atoi(splitquery[2]); pow2parent[0][leaf] = parent; for(int j = 1; j < LOGMAX; j++){ pow2parent[j][leaf] = pow2parent[j - 1][pow2parent[j - 1][leaf]]; } } else if(queries[i][0] == '1'){ int leaf = atoi(splitquery[1]); for(int j = 0; j < LOGMAX; j++){ pow2parent[j][leaf] = 0; } } else if(queries[i][0] == '2'){ *result_count += 1; toreturn = realloc(toreturn, (*result_count)*sizeof(int)); int currnode = atoi(splitquery[1]); int target = atoi(splitquery[2]); for(int j = 0; j < LOGMAX; j++){ if(((target>>j)&1) == 1){ currnode = pow2parent[j][currnode]; } } toreturn[(*result_count) - 1] = currnode; } else{ exit(EXIT_FAILURE); } } return toreturn; } int main() { int t; scanf("%dn", &t); for(int i = 0; i < t; i++){ int n; scanf("%dn", &n); int** edges = malloc(n*sizeof(int*)); for(int j = 0; j < n; j++){ edges[j] = malloc(2*sizeof(int)); scanf("%d %dn", edges[j], edges[j] + 1); } int q; scanf("%dn", &q); char** queries = malloc(q*sizeof(char*)); for(int j = 0; j < q; j++){ queries[j] = readline(); } int result_count; int* result = kthparent(n, edges, q, queries, &result_count); for(int j = 0; j < result_count; j++){ printf("%dn", result[j]); } } return 0; } char* 
readline() { size_t alloc_length = 1024; size_t data_length = 0; char* data = malloc(alloc_length); while (true) { char* cursor = data + data_length; char* line = fgets(cursor, alloc_length - data_length, stdin); if (!line) { break; } data_length += strlen(cursor); if (data_length < alloc_length - 1 || data[data_length - 1] == 'n') { break; } size_t new_length = alloc_length << 1; data = realloc(data, new_length); if (!data) { break; } alloc_length = new_length; } if (data[data_length - 1] == 'n') { data[data_length - 1] = ' '; } data = realloc(data, data_length); return data; } char** split_string(char* str) { char** splits = NULL; char* token = strtok(str, " "); int spaces = 0; while (token) { splits = realloc(splits, sizeof(char*) * ++spaces); if (!splits) { return splits; } splits[spaces - 1] = token; token = strtok(NULL, " "); } return splits; }