In this HackerRank Similar Pair problem, you are given a tree whose nodes are labeled from 1 to n, and you must find the number of similar pairs in the tree. A pair of nodes (a, b) is a similar pair if node a is an ancestor of node b and the absolute difference between a and b is at most k, i.e. $|a - b| \le k$.
Problem solution in Python.
import sys


def count_similar_pairs(n, k, edges):
    """Count pairs (a, b) where a is an ancestor of b and |a - b| <= k.

    Args:
        n: number of nodes; nodes are labeled 1..n.
        k: maximum allowed label difference (negative k yields 0).
        edges: iterable of (parent, child) pairs describing a rooted tree.

    Returns:
        The number of similar pairs.

    The tree is walked depth-first with an explicit stack, so deep trees never
    hit Python's recursion limit.  A Fenwick tree over labels 1..n holds a 1
    for every node on the current root-to-node path, so each newly entered node
    counts its similar ancestors with two prefix-sum queries: O(n log n) total.
    """
    if k < 0:
        return 0
    children = [[] for _ in range(n + 1)]
    has_parent = [False] * (n + 1)
    for parent, child in edges:
        children[parent].append(child)
        has_parent[child] = True
    root = next((v for v in range(1, n + 1) if not has_parent[v]), None)
    if root is None:
        return 0

    fenwick = [0] * (n + 1)

    def add(i, delta):
        while i <= n:
            fenwick[i] += delta
            i += i & -i

    def prefix(i):
        # Sum of positions 1..i; i is clamped to n, and i <= 0 yields 0.
        i = min(i, n)
        total = 0
        while i > 0:
            total += fenwick[i]
            i -= i & -i
        return total

    pairs = 0

    def enter(v):
        # Count path nodes with labels in [v - k, v + k], then put v on the path.
        nonlocal pairs
        pairs += prefix(v + k) - prefix(v - k - 1)
        add(v, 1)

    enter(root)
    stack = [(root, iter(children[root]))]
    while stack:
        node, pending = stack[-1]
        child = next(pending, None)
        if child is None:
            # All children done: node leaves the root-to-node path.
            stack.pop()
            add(node, -1)
        else:
            enter(child)
            stack.append((child, iter(children[child])))
    return pairs


def main():
    """Read "n k" followed by n-1 "parent child" lines from stdin; print the count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    values = [int(tok) for tok in data[2:2 + 2 * (n - 1)]]
    edges = zip(values[0::2], values[1::2])
    print(count_similar_pairs(n, k, edges))


if __name__ == "__main__":
    main()
Problem solution in Java.
import java.io.*;
import java.util.*;

/**
 * HackerRank "Similar Pair": counts pairs (a, b) where a is an ancestor of b and |a - b| <= k.
 *
 * <p>The tree is walked depth-first while a segment tree over node labels 1..n records which
 * labels lie on the current root-to-node path; every newly entered node queries that tree for
 * ancestors in [node - k, node + k]. The walk uses an explicit stack so that deep (for example
 * chain-shaped) trees cannot overflow the JVM call stack.
 */
public class Solution {
    /** Child lists indexed by node label (1..n); rebuilt by {@link #countSimilarPairs}. */
    @SuppressWarnings("unchecked")
    public static LinkedList<Integer>[] nodes = new LinkedList[0];

    static int n, t, root;

    /** Reads "n k" and n-1 "parent child" pairs from standard input and prints the count. */
    public static void main(String[] args) {
        try {
            StreamTokenizer in =
                    new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
            int nodeCount = nextInt(in);
            int k = nextInt(in);
            int[][] edges = new int[Math.max(0, nodeCount - 1)][];
            for (int i = 0; i < edges.length; i++) {
                int par = nextInt(in);
                int chd = nextInt(in);
                edges[i] = new int[] {par, chd};
            }
            System.out.println(countSimilarPairs(nodeCount, k, edges));
        } catch (IOException e) {
            throw new UncheckedIOException("failed to read tree from standard input", e);
        }
    }

    /**
     * Counts similar pairs in a rooted tree.
     *
     * @param nodeCount number of nodes, labeled 1..nodeCount
     * @param k maximum allowed label difference between ancestor and descendant
     * @param edges (parent, child) pairs; each child has exactly one parent
     * @return number of pairs (a, b) with a an ancestor of b and |a - b| <= k
     */
    public static long countSimilarPairs(int nodeCount, int k, int[][] edges) {
        n = nodeCount;
        t = k;
        @SuppressWarnings("unchecked")
        LinkedList<Integer>[] adjacency = new LinkedList[Math.max(0, n) + 1];
        for (int i = 1; i <= n; i++) {
            adjacency[i] = new LinkedList<Integer>();
        }
        int[] idegree = new int[Math.max(0, n) + 1];
        for (int[] edge : edges) {
            adjacency[edge[0]].addFirst(edge[1]);
            idegree[edge[1]]++;
        }
        nodes = adjacency;

        root = 0;
        for (int i = 1; i <= n; i++) {
            if (idegree[i] == 0) {
                root = i;
                break;
            }
        }
        if (root == 0) {
            return 0; // empty tree
        }
        long[] stree = new long[4 * n + 1];
        long[] pairs = new long[1];
        depthSearch(root, stree, pairs);
        return pairs[0];
    }

    /**
     * Adds to {@code pairs[0]} the similar pairs found in the subtree of {@code nodeval}.
     *
     * <p>Iterative: one iterator over pending children is kept per node on the current path,
     * so memory grows with tree depth on the heap rather than on the call stack.
     */
    public static void depthSearch(int nodeval, long[] stree, long[] pairs) {
        Deque<Integer> path = new ArrayDeque<Integer>();
        Deque<Iterator<Integer>> pendingChildren = new ArrayDeque<Iterator<Integer>>();
        enter(nodeval, stree, pairs);
        path.push(nodeval);
        pendingChildren.push(nodes[nodeval].iterator());
        while (!path.isEmpty()) {
            Iterator<Integer> it = pendingChildren.peek();
            if (it.hasNext()) {
                int chd = it.next();
                enter(chd, stree, pairs);
                path.push(chd);
                pendingChildren.push(nodes[chd].iterator());
            } else {
                // Every child is finished: this node leaves the root-to-node path.
                pendingChildren.pop();
                updateTree(stree, 1, 1, n, path.pop(), -1);
            }
        }
    }

    /** Counts path nodes labeled within [nodeval - t, nodeval + t], then adds nodeval to the path. */
    private static void enter(int nodeval, long[] stree, long[] pairs) {
        // long arithmetic so that a large t cannot overflow int
        int min = (int) Math.max(1L, (long) nodeval - t);
        int max = (int) Math.min((long) n, (long) nodeval + t);
        pairs[0] += query(stree, 1, 1, n, min, max);
        updateTree(stree, 1, 1, n, nodeval, 1);
    }

    /** Reads the next integer token, failing if the input is exhausted or not numeric. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected an integer token, got " + in);
        }
        return (int) in.nval;
    }

    /** Adds {@code opt} at position {@code val} of the segment tree node covering [tl, tr]. */
    public static void updateTree(long[] tree, int node, int tl, int tr, int val, long opt) {
        if (val < tl || val > tr || tl > tr) return;
        tree[node] += opt;
        int m = (tl + tr) >> 1;
        if (tl == tr) return;
        else if (val <= m) updateTree(tree, node << 1, tl, m, val, opt);
        else updateTree(tree, node << 1 | 1, m + 1, tr, val, opt);
    }

    /** Returns the sum over positions [min, max] within the segment tree node covering [tl, tr]. */
    public static long query(long[] tree, int node, int tl, int tr, int min, int max) {
        if (max < tl || min > tr) return 0;
        else if (max == tr && min == tl) return tree[node];
        else {
            int mid = (tl + tr) >> 1;
            int lmax = (mid < max) ? mid : max;
            int rmin = (min > mid) ? min : mid + 1;
            return query(tree, node << 1, tl, mid, min, lmax)
                    + query(tree, node << 1 | 1, mid + 1, tr, rmin, max);
        }
    }
}
Problem solution in C++.
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

// HackerRank "Similar Pair": count pairs (a, b) where a is an ancestor of b and |a - b| <= T.
// A segment tree over labels 1..N counts the nodes on the current root-to-node path. The tree
// is walked with an explicit stack so deep (chain-shaped) inputs cannot overflow the call stack.

vector<vector<int> > graph;  // graph[p] = children of p, labels 1..N
int T, N;
vector<int> deg;             // deg[v] = number of parents of v (0 for the root)
vector<long long> ST;        // segment tree; node idx has children 2*idx and 2*idx+1

// Adds val at position `node` of the segment-tree node idx, which covers [b, e].
void update(int node, int b, int e, int idx, int val) {
    if (b > node || e < node) return;
    if (b == e) {
        ST[idx] += val;
        return;
    }
    int m = (b + e) >> 1;
    int q = idx << 1;
    update(node, b, m, q, val);
    update(node, m + 1, e, q + 1, val);
    ST[idx] = ST[q] + ST[q + 1];
}

// Returns the sum over [l, r] intersected with the segment-tree node idx covering [b, e].
long long Query(int l, int r, int b, int e, int idx) {
    if (l > e || r < b) return 0;
    if (l <= b && r >= e) return ST[idx];
    int m = (b + e) >> 1;
    int q = idx << 1;
    return Query(l, r, b, m, q) + Query(l, r, m + 1, e, q + 1);
}

// Counts path nodes labeled within [v - T, v + T], then puts v on the path.
static long long enterNode(int v) {
    long long lo = max(1LL, (long long)v - T);       // long long: large T cannot overflow
    long long hi = min((long long)N, (long long)v + T);
    long long found = lo <= hi ? Query((int)lo, (int)hi, 1, N, 1) : 0;
    update(v, 1, N, 1, 1);
    return found;
}

// Counts similar pairs in the subtree rooted at `node`, assuming the path above it is empty.
long long SimilarPairs(int node) {
    long long res = enterNode(node);
    vector<pair<int, size_t> > path;  // (node, index of its next child to visit)
    path.push_back(make_pair(node, (size_t)0));
    while (!path.empty()) {
        int v = path.back().first;
        if (path.back().second < graph[v].size()) {
            int child = graph[v][path.back().second++];
            res += enterNode(child);
            path.push_back(make_pair(child, (size_t)0));
        } else {
            // All children done: v leaves the root-to-node path.
            update(v, 1, N, 1, -1);
            path.pop_back();
        }
    }
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    if (!(cin >> N >> T)) return 1;
    if (N < 1) {
        cout << 0 << endl;
        return 0;
    }
    graph.assign(N + 1, vector<int>());
    deg.assign(N + 1, 0);
    ST.assign(4 * (N + 1), 0);
    for (int i = 0; i < N - 1; i++) {
        int x, y;
        if (!(cin >> x >> y) || x < 1 || x > N || y < 1 || y > N) return 1;
        graph[x].push_back(y);
        deg[y]++;
    }
    int root = 0;
    for (int i = 1; i <= N; i++) {
        if (!deg[i]) {
            root = i;
            break;
        }
    }
    long long result = root ? SimilarPairs(root) : 0;
    cout << result << endl;
    return 0;
}
Problem solution in C.
#include <stdio.h>
#include <stdlib.h>

/*
 * HackerRank "Similar Pair": count pairs (a, b) where a is an ancestor of b
 * and |a - b| <= k.
 *
 * The tree is walked depth-first with an explicit stack, so deep trees cannot
 * overflow the call stack.  A Fenwick tree over labels 1..n holds a 1 for
 * every node on the current root-to-node path.  Each newly entered node counts
 * its similar ancestors with two prefix queries: O(n log n) in total.
 */

/* Adds delta at position i of the 1-indexed Fenwick tree fen[1..n]. */
static void fenwick_add(int *fen, int n, int i, int delta)
{
    for (; i <= n; i += i & -i)
        fen[i] += delta;
}

/* Returns the sum of positions 1..i of fen; i is clamped to n, i < 1 gives 0. */
static long long fenwick_prefix(const int *fen, int n, long long i)
{
    long long sum = 0;
    if (i < 1)
        return 0;
    if (i > n)
        i = n;
    for (int j = (int)i; j > 0; j -= j & -j)
        sum += fen[j];
    return sum;
}

/* Counts path nodes labeled within [v - k, v + k], then puts v on the path. */
static long long enter_node(int *fen, int n, int k, int v)
{
    long long found = 0;
    if (k >= 0)
        found = fenwick_prefix(fen, n, (long long)v + k)
              - fenwick_prefix(fen, n, (long long)v - k - 1);
    fenwick_add(fen, n, v, 1);
    return found;
}

/* Reads "n k" and n-1 "parent child" lines from stdin, prints the pair count. */
int main(void)
{
    int n, k, root, top;
    int status = 1;
    long long count = 0;
    size_t slots;
    int *head = NULL;          /* head[p]: first child of p, 0 = none */
    int *next_sibling = NULL;  /* next_sibling[c]: next child of c's parent */
    int *has_parent = NULL;
    int *fen = NULL;           /* Fenwick tree marking the current path */
    int *path = NULL;          /* explicit DFS stack (root-to-node path) */
    int *cursor = NULL;        /* cursor[v]: next child of v still to visit */

    if (scanf("%d %d", &n, &k) != 2) {
        fprintf(stderr, "expected \"n k\" on the first line\n");
        return 1;
    }
    if (n < 1) {
        printf("0\n");
        return 0;
    }

    slots = (size_t)n + 1;
    head = calloc(slots, sizeof *head);
    next_sibling = calloc(slots, sizeof *next_sibling);
    has_parent = calloc(slots, sizeof *has_parent);
    fen = calloc(slots, sizeof *fen);
    path = malloc(slots * sizeof *path);
    cursor = malloc(slots * sizeof *cursor);
    if (!head || !next_sibling || !has_parent || !fen || !path || !cursor) {
        fprintf(stderr, "out of memory\n");
        goto cleanup;
    }

    for (int i = 0; i < n - 1; i++) {
        int a, b;
        /* Each node may have at most one parent; otherwise the sibling links break. */
        if (scanf("%d %d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n
                || has_parent[b]) {
            fprintf(stderr, "invalid edge on input line %d\n", i + 2);
            goto cleanup;
        }
        next_sibling[b] = head[a];
        head[a] = b;
        has_parent[b] = 1;
    }

    root = 0;
    for (int v = 1; v <= n; v++) {
        if (!has_parent[v]) {
            root = v;
            break;
        }
    }

    if (root) {
        top = 0;
        count += enter_node(fen, n, k, root);
        path[top++] = root;
        cursor[root] = head[root];
        while (top > 0) {
            int v = path[top - 1];
            int c = cursor[v];
            if (c) {
                cursor[v] = next_sibling[c];
                count += enter_node(fen, n, k, c);
                path[top++] = c;
                cursor[c] = head[c];
            } else {
                /* All children done: v leaves the root-to-node path. */
                fenwick_add(fen, n, v, -1);
                top--;
            }
        }
    }

    printf("%lld\n", count);
    status = 0;

cleanup:
    free(head);
    free(next_sibling);
    free(has_parent);
    free(fen);
    free(path);
    free(cursor);
    return status;
}