"""HackerRank "Heavy Light 2 White Falcon" problem solution.

YASH PAL, 31 July 2024

In this HackerRank Heavy Light 2 White Falcon problem, we are given a tree
with N nodes, and every node's value is initially 0. An update query on the
path between two nodes adds x to the first node of the path, 2x to the
second, and so on; a sum query prints the sum of the node values on the path
between two nodes (modulo 10**9 + 7).

Problem solution in Python programming.
"""
from operator import attrgetter

MOD = 10**9 + 7


def solve(edges, queries):
    """Answer all queries and return the answers of the sum queries.

    edges   -- list of [u, v] pairs describing a tree on nodes 0..len(edges)
    queries -- lists [1, u, v, x] (path update) or [2, u, v] (path sum)
    """
    nodes, leaves = make_tree(edges)
    hld(leaves)
    results = []
    for query in queries:
        if query[0] == 1:
            update(nodes[query[1]], nodes[query[2]], query[3])
        elif query[0] == 2:
            results.append(sum_range(nodes[query[1]], nodes[query[2]]))
    return results


def make_tree(edges):
    """Root the tree at node 0 and return ``(nodes, leaves)``.

    Every node gets its parent and depth, and its ``children`` list loses the
    back-edge to the parent. ``leaves`` lists the childless nodes; a
    single-node tree reports its root as a leaf so hld() still gives it a
    chain (previously the root kept ``chain = None`` and queries crashed).
    """
    nodes = [Node(i) for i in range(len(edges) + 1)]
    # The edges are undirected, so record both directions for now.
    for edge in edges:
        nodes[edge[0]].children.append(nodes[edge[1]])
        nodes[edge[1]].children.append(nodes[edge[0]])

    root = nodes[0]
    root.depth = 0
    leaves = []
    if not root.children:
        leaves.append(root)

    # Iterating a list while appending to it also visits the appended items,
    # so ``pending`` works as a FIFO queue (breadth-first traversal).
    pending = [(child, root, 1) for child in root.children]
    for node, parent, depth in pending:
        node.children.remove(parent)  # drop the back-edge to the parent
        node.parent = parent
        node.depth = depth
        if not node.children:
            leaves.append(node)
            continue
        for child in node.children:
            pending.append((child, node, depth + 1))
    return nodes, leaves


def hld(leaves):
    """Split the tree into vertical chains, starting from the deepest leaves.

    Each leaf opens a chain and walks upward, claiming ancestors until it
    reaches the root or a node that already belongs to a chain; in the latter
    case the new chain records that node as its attachment point. Indices
    inside a chain grow towards the root (the leaf has index 0).
    """
    for leaf in sorted(leaves, key=attrgetter('depth'), reverse=True):
        chain = Chain()
        leaf.chain = chain
        leaf.chain_i = 0
        node = leaf
        while node.parent is not None:
            parent = node.parent
            if parent.chain is not None:
                # Parent already sits on an earlier chain: hang this one off it.
                chain.parent = parent.chain
                chain.parent_i = parent.chain_i
                break
            parent.chain = chain
            parent.chain_i = chain.size
            chain.size += 1
            node = parent
        # The chain's length is final now, so its tree can be allocated.
        chain.init_fenwick_tree()


def update(node1, node2, x):
    """Add x, 2x, 3x, ... to the nodes on the path from node1 to node2.

    node1 receives x and node2 receives ``path_len * x``, where ``path_len``
    is the number of nodes on the path.
    """
    path_len = 0
    chain1, chain_i1, depth1 = node1.chain, node1.chain_i, node1.depth
    chain2, chain_i2, depth2 = node2.chain, node2.chain_i, node2.depth
    # Chain segments climbed from each end, as (chain, first index) pairs;
    # every segment runs from that index up to the chain head (size - 1).
    chains1 = []
    chains2 = []
    while chain1 is not chain2:
        # Chain indices grow towards the root, so the segment from the
        # current index to the head has ``size - index`` nodes.
        step1 = chain1.size - chain_i1
        step2 = chain2.size - chain_i2
        # ``depth - step`` is one less than the depth of the chain head:
        # climb from the side whose chain head is deeper.
        if depth1 - step1 > depth2 - step2:
            path_len += step1
            chains1.append((chain1, chain_i1))
            depth1 -= step1
            chain_i1 = chain1.parent_i
            chain1 = chain1.parent
        else:
            path_len += step2
            chains2.append((chain2, chain_i2))
            depth2 -= step2
            chain_i2 = chain2.parent_i
            chain2 = chain2.parent
    # Both ends now sit on the same chain; count the nodes between them.
    path_len += abs(chain_i1 - chain_i2) + 1

    # node1's side: values grow by x per node while climbing.
    curr_val1 = 0
    for chain, chain_i in chains1:
        chain.ftree.add(chain_i, chain.size - 1, curr_val1, x)
        curr_val1 += (chain.size - chain_i) * x
    # node2's side: node2 gets path_len * x and values shrink while climbing.
    curr_val2 = (path_len + 1) * x
    for chain, chain_i in chains2:
        chain.ftree.add(chain_i, chain.size - 1, curr_val2, -x)
        curr_val2 -= (chain.size - chain_i) * x
    # Final segment on the shared chain, oriented by which end is deeper.
    if chain_i1 <= chain_i2:
        chain1.ftree.add(chain_i1, chain_i2, curr_val1, x)
    else:
        chain1.ftree.add(chain_i2, chain_i1, curr_val2, -x)


def sum_range(node1, node2):
    """Return the sum of node values on the path node1 -> node2, modulo MOD."""
    sum_ = 0
    chain1, chain_i1, depth1 = node1.chain, node1.chain_i, node1.depth
    chain2, chain_i2, depth2 = node2.chain, node2.chain_i, node2.depth
    while chain1 is not chain2:
        step1 = chain1.size - chain_i1
        step2 = chain2.size - chain_i2
        # Same climbing rule as update(): move the end whose chain head is deeper.
        if depth1 - step1 > depth2 - step2:
            sum_ += chain1.ftree.range_sum(chain_i1, chain1.size - 1)
            depth1 -= step1
            chain_i1 = chain1.parent_i
            chain1 = chain1.parent
        else:
            sum_ += chain2.ftree.range_sum(chain_i2, chain2.size - 1)
            depth2 -= step2
            chain_i2 = chain2.parent_i
            chain2 = chain2.parent
    if chain_i1 > chain_i2:
        chain_i1, chain_i2 = chain_i2, chain_i1
    sum_ += chain1.ftree.range_sum(chain_i1, chain_i2)
    return int(sum_ % MOD)


class Node():
    """A tree vertex together with its position in the chain decomposition."""
    __slots__ = ['i', 'val', 'parent', 'children', 'depth', 'chain', 'chain_i']

    def __init__(self, i):
        self.i = i
        self.val = 0
        self.parent = None
        self.depth = None
        self.children = []
        self.chain = None   # Chain this node belongs to (assigned by hld)
        self.chain_i = -1   # index inside the chain; 0 is the chain's deepest node


class Chain():
    """A vertical path of nodes backed by its own range-update/range-sum tree."""
    __slots__ = ['size', 'ftree', 'parent', 'parent_i']

    def __init__(self):
        self.size = 1
        self.ftree = None
        self.parent = None   # chain holding the node this chain hangs from
        self.parent_i = -1   # index of that node inside ``parent``

    def init_fenwick_tree(self):
        self.ftree = RURQFenwickTree(self.size)


def g(i):
    """Clear the trailing run of 1 bits (start of a 0-based Fenwick range)."""
    return i & (i + 1)


def h(i):
    """Set the lowest 0 bit (next node to touch in a 0-based Fenwick update)."""
    return i | (i + 1)


class RURQFenwickTree():
    """Fenwick tree with arithmetic-progression range updates and range sums.

    ``add(l, r, k, x)`` adds ``k + x * (j - l + 1)`` to every index j in
    [l, r]. Three point-query trees hold the coefficients of
    ``2 * prefix_sum(i) = i**2 * t1(i) + i * t2(i) + t3(i)``; the doubling
    keeps every coefficient an integer.
    """

    def __init__(self, size):
        self.tree1 = RUPQFenwickTree(size)
        self.tree2 = RUPQFenwickTree(size)
        self.tree3 = RUPQFenwickTree(size)

    def add(self, l, r, k, x):
        k2 = k * 2
        self.tree1.add(l, x)
        self.tree1.add(r + 1, -x)
        self.tree2.add(l, (3 - 2 * l) * x + k2)
        self.tree2.add(r + 1, -((3 - 2 * l) * x + k2))
        self.tree3.add(l, (l**2 - 3 * l + 2) * x + k2 * (1 - l))
        # Past r the quadratic collapses to the constant total of the range.
        self.tree3.add(r + 1, (r**2 + 3 * r - 2 * r * l) * x + k2 * r)

    def prefix_sum(self, i):
        """Return the sum of indices 0..i modulo MOD (0 when i == -1)."""
        sum_ = i**2 * self.tree1.point_query(i)
        sum_ += i * self.tree2.point_query(i)
        sum_ += self.tree3.point_query(i)
        # sum_ is twice the prefix sum, hence even, so halving after reducing
        # modulo 2*MOD is exact; integer division keeps the result an int
        # instead of the float the original '/' produced.
        return (sum_ % (2 * MOD)) // 2

    def range_sum(self, l, r):
        """Return the sum of indices l..r; may be negative, callers reduce mod MOD."""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)


class RUPQFenwickTree():
    """0-based Fenwick tree: point add at i, query the total of adds at indices <= i."""

    def __init__(self, size):
        self.size = size
        self.tree = [0] * size

    def add(self, i, x):
        j = i
        while j < self.size:
            self.tree[j] += x
            j = h(j)

    def point_query(self, i):
        res = 0
        j = i
        while j >= 0:
            res += self.tree[j]
            j = g(j) - 1
        return res


if __name__ == '__main__':
    nq = input().split()
    n = int(nq[0])
    q = int(nq[1])
    tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]
    queries = [list(map(int, input().rstrip().split())) for _ in range(q)]
    results = solve(tree, queries)
    # One answer per line (the scraped original joined with a literal 'n').
    print('\n'.join(map(str, results)))

# Problem solution in Java Programming.
import java.io.*;
import java.util.*;

/**
 * Heavy Light 2 White Falcon: heavy-light decomposition of the tree plus a
 * lazy segment tree over DFS positions whose tags add an arithmetic
 * progression (first value, step) to a range.
 */
public class Solution {
    static List<Integer>[] adj;  // adjacency lists of the tree
    static int[] chain;          // after dfs(): heavy child or -1; after hld(): chain head
    static int[] dep;            // depth of every node, root = 0
    static int[] par;            // parent of every node, -1 for the root

    /** Frame of the iterative DFS that computes subtree sizes and heavy children. */
    static class NodeDfs {
        long size = 1;
        long maxs = 0;          // largest child subtree size seen so far
        int u;
        int p;
        boolean start = true;   // true until the children have been pushed
        NodeDfs nodep = null;   // frame of the parent

        public NodeDfs(int u, int p, NodeDfs nodep) {
            this.u = u;
            this.p = p;
            this.nodep = nodep;
        }
    }

    /** Iterative DFS from u: fills par, dep and the heavy child in chain[]. */
    static void dfs(int u, int p) {
        Deque<NodeDfs> deque = new LinkedList<>();
        deque.add(new NodeDfs(u, p, null));
        while (!deque.isEmpty()) {
            NodeDfs node = deque.peekLast();
            if (node.start) {
                par[node.u] = node.p;
                chain[node.u] = -1;
                for (int v : adj[node.u]) {
                    if (v != node.p) {
                        dep[v] = dep[node.u] + 1;
                        deque.add(new NodeDfs(v, node.u, node));
                    }
                }
                node.start = false;
            } else {
                // Subtree finished: fold its size into the parent and keep
                // the largest child as the parent's heavy child.
                if (node.nodep != null) {
                    node.nodep.size += node.size;
                    if (node.size > node.nodep.maxs) {
                        node.nodep.maxs = node.size;
                        chain[node.nodep.u] = node.u;
                    }
                }
                deque.removeLast();
            }
        }
    }

    /** Frame of the iterative decomposition; start: 0 new, 1 heavy child pushed, 2 done. */
    static class NodeHld {
        int u;
        int p;
        int top;
        int start = 0;

        public NodeHld(int u, int p, int top) {
            this.u = u;
            this.p = p;
            this.top = top;
        }
    }

    static int[] dfn;    // DFS position; every chain occupies consecutive positions
    static int tick = 0;

    /** Assigns dfn (heavy child first) and overwrites chain[] with chain heads. */
    static void hld(int u, int p, int top) {
        Deque<NodeHld> deque = new LinkedList<>();
        deque.add(new NodeHld(u, p, top));
        while (!deque.isEmpty()) {
            NodeHld node = deque.peekLast();
            if (node.start == 0) {
                dfn[node.u] = tick++;
                if (chain[node.u] >= 0) {
                    // Heavy child continues the current chain.
                    deque.add(new NodeHld(chain[node.u], node.u, node.top));
                    node.start = 1;
                } else {
                    node.start = 2;
                }
            } else if (node.start == 1) {
                // Light children each start a new chain headed by themselves.
                for (int v : adj[node.u]) {
                    if (v != node.p && v != chain[node.u]) {
                        deque.add(new NodeHld(v, node.u, v));
                    }
                }
                node.start = 2;
            } else {
                chain[node.u] = node.top;
                deque.removeLast();
            }
        }
    }

    static class Pair {
        int first = 0;
        int second = 0;

        Pair() {
        }

        Pair(int first, int second) {
            this.first = first;
            this.second = second;
        }
    }

    /**
     * Splits the path u -> v into half-open dfn ranges in path order. Ranges
     * walked against dfn order (the u side) are stored bit-complemented (~).
     */
    static List<Pair> path(int u, int v) {
        List<Pair> ps0 = new ArrayList<>();
        List<Pair> ps1 = new ArrayList<>();
        while (chain[u] != chain[v]) {
            if (dep[chain[u]] > dep[chain[v]]) {
                ps0.add(new Pair(~dfn[chain[u]], ~(dfn[u] + 1)));
                u = par[chain[u]];
            } else {
                ps1.add(new Pair(dfn[chain[v]], dfn[v] + 1));
                v = par[chain[v]];
            }
        }
        if (dep[u] > dep[v]) {
            ps0.add(new Pair(~dfn[v], ~(dfn[u] + 1)));
        } else {
            ps1.add(new Pair(dfn[u], dfn[v] + 1));
        }
        for (int i = ps1.size() - 1; i >= 0; i--) {
            ps0.add(ps1.get(i));
        }
        return ps0;
    }

    static final int LN = 63 - Long.numberOfLeadingZeros(50000 - 1) + 1;
    static final int NN = 1 << LN;              // number of segment-tree leaves
    static final int MOD = 1_000_000_007;
    static final int INV2 = (MOD + 1) / 2;      // modular inverse of 2
    static Pair[] ap = new Pair[2 * NN];        // pending progression tag per node
    static long[] sum = new long[2 * NN];       // range sum per node

    static long sum(long a, long b) {
        return (a + b) % MOD;
    }

    static long mult(long a, long b) {
        return (a * b) % MOD;
    }

    /** Adds the progression x (value x.first at leaf position start, step x.second) to node i. */
    static void apply(int i, long start, Pair x) {
        long h = LN - (63 - Long.numberOfLeadingZeros(i));
        long k = 1L << h;   // number of leaves under node i
        long first = sum(x.first, mult((i << h) - NN - start, x.second));
        sum[i] = sum(sum[i], mult(mult(sum(2 * first, mult(k - 1, x.second)), k), INV2));
        ap[i].first = (int) sum(ap[i].first, first);
        ap[i].second = (int) sum(ap[i].second, x.second);
    }

    /** Pushes pending tags down along the root-to-leaf path of position i. */
    static void untag(int i) {
        if (i < 0 || i >= NN) {
            return;
        }
        i += NN;
        for (int j, h = LN; h > 0; h--) {
            // j = i >> h is always >= 1 here (NN <= i < 2*NN, h <= LN), so the
            // explicit grouping matches the original's && / || precedence.
            if ((j = i >> h) > 0 && (ap[j].first != 0 || ap[j].second != 0)) {
                apply(2 * j, (j << h) - NN, ap[j]);
                apply(2 * j + 1, (j << h) - NN, ap[j]);
                ap[j].first = 0;
                ap[j].second = 0;
            }
        }
    }

    static void mconcat(int i) {
        sum[i] = sum(sum[2 * i], sum[2 * i + 1]);
    }

    /** Sum of positions [l, r) modulo MOD (possibly negative). */
    static long getSum(int l, int r) {
        long s = 0;
        untag(l - 1);
        untag(r);
        for (l += NN, r += NN; l < r; l >>= 1, r >>= 1) {
            if ((l & 1) > 0) {
                s = sum(s, sum[l++]);
            }
            if ((r & 1) > 0) {
                s = sum(s, sum[--r]);
            }
        }
        return s;
    }

    /** Adds the progression x, anchored at position l, to positions [l, r). */
    static void modify(int l, int r, Pair x) {
        int start = l;
        boolean lf = false;
        boolean rf = false;
        untag(l - 1);
        untag(r);
        for (l += NN, r += NN; l < r; ) {
            if ((l & 1) > 0) {
                lf = true;
                apply(l++, start, x);
            }
            l >>= 1;
            if (lf) {
                mconcat(l - 1);
            }
            if ((r & 1) > 0) {
                rf = true;
                apply(--r, start, x);
            }
            r >>= 1;
            if (rf) {
                mconcat(r);
            }
        }
        // Recompute the ancestors of the touched boundary nodes.
        for (l--; (l >>= 1) > 0 && (r >>= 1) > 0; ) {
            if (lf || l == r) {
                mconcat(l);
            }
            if (rf && l != r) {
                mconcat(r);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int q = Integer.parseInt(st.nextToken());
        adj = new List[n];
        for (int i = 0; i < n; i++) {
            adj[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            adj[u].add(v);
            adj[v].add(u);
        }
        chain = new int[n];
        dep = new int[n];
        par = new int[n];
        dfs(0, -1);
        dfn = new int[n];
        hld(0, -1, 0);
        for (int i = 0; i < ap.length; i++) {
            ap[i] = new Pair();
        }
        while (q-- > 0) {
            st = new StringTokenizer(br.readLine());
            int op = Integer.parseInt(st.nextToken());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            List<Pair> ps = path(u, v);
            if (op == 1) {
                int x = Integer.parseInt(st.nextToken());
                int y = x;  // value for the next node on the path
                for (Pair p : ps) {
                    u = p.first;
                    v = p.second;
                    if (u >= 0) {
                        modify(u, v, new Pair(y, x));
                        y = (int) sum(y, mult(v - u, x));
                    } else {
                        // Walked against dfn order: anchor at the range start
                        // with the last value and a negative step.
                        modify(~u, ~v, new Pair((int) sum(y, mult(u - v - 1, x)), -x));
                        y = (int) sum(y, mult(u - v, x));
                    }
                }
            } else {
                long ans = 0;
                for (Pair p : ps) {
                    u = p.first;
                    v = p.second;
                    if (u < 0) {
                        u = ~u;
                        v = ~v;
                    }
                    ans = sum(ans, getSum(u, v));
                }
                // sum(ans, MOD) lifts a negative remainder into [0, MOD);
                // one answer per line (the scraped original wrote a literal "n").
                bw.write(sum(ans, MOD) + "\n");
            }
        }
        bw.newLine();
        bw.close();
        br.close();
    }
}

// Problem solution in C++ programming.
#include <bits/stdc++.h> using namespace std; using ll = long long; const int mod = 1e9 + 7; template<int MOD> struct mod_int { static const int Mod = MOD; unsigned x; mod_int() : x(0) { } mod_int(int sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; } mod_int(long long sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; } int get() const { return (int)x; } mod_int &operator+=(mod_int that) { if ((x += that.x) >= MOD) x -= MOD; return *this; } mod_int &operator-=(mod_int that) { if ((x += MOD - that.x) >= MOD) x -= MOD; return *this; } mod_int &operator*=(mod_int that) { x = (unsigned long long)x * that.x % MOD; return *this; } mod_int &operator/=(mod_int that) { return *this *= that.inverse(); } mod_int operator+(mod_int that) const { return mod_int(*this) += that; } mod_int operator-(mod_int that) const { return mod_int(*this) -= that; } mod_int operator*(mod_int that) const { return mod_int(*this) *= that; } mod_int operator/(mod_int that) const { return mod_int(*this) /= that; } mod_int inverse() const { long long a = x, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } return mod_int(u); } }; using mint = mod_int<mod>; struct RS { using type = mint; static type id() { return 0; } static type op(const type& l, const type & r) { return l + r; } }; class lct_node { using M = RS; using T = typename M::type; using U = pair<mint, mint>; lct_node *l, *r, *p; bool rev; T val, all; int size; bool flag; U lazy; int pos() { if (p && p->l == this) return 1; if (p && p->r == this) return 3; return 0; } void update() { size = (l ? l->size : 0) + (r ? r->size : 0) + 1; all = M::op(l ? l->all : M::id(), M::op(val, r ? r->all : M::id())); } void update_lazy(const U& v) { if (!flag) lazy = make_pair(0, 0); int ls = !rev ? (l ? l->size : 0) : (r ? 
r->size : 0); val += v.first + v.second * ls; all += v.first * size + ((v.second * (size - 1)) * size) / 2; lazy = make_pair(M::op(lazy.first, v.first), M::op(lazy.second, v.second)); flag = true; } void rev_data() { lazy = make_pair(lazy.first + lazy.second * (size - 1), mint(0) - lazy.second); } void push() { if (pos()) p->push(); if (rev) { swap(l, r); if (l) l->rev ^= true, l->rev_data(); if (r) r->rev ^= true, r->rev_data(); rev = false; } if (flag) { if (l) l->update_lazy(lazy); if (r) r->update_lazy(make_pair(lazy.first + lazy.second * (l ? l->size + 1 : 1), lazy.second)); flag = false; } } void rot() { lct_node *par = p; lct_node *mid; if (p->l == this) { mid = r; r = par; par->l = mid; } else { mid = l; l = par; par->r = mid; } if (mid) mid->p = par; p = par->p; par->p = this; if (p && p->l == par) p->l = this; if (p && p->r == par) p->r = this; par->update(); update(); } void splay() { push(); while (pos()) { int st = pos() ^ p->pos(); if (!st) p->rot(), rot(); else if (st == 2) rot(), rot(); else rot(); } } public: lct_node() : l(nullptr), r(nullptr), p(nullptr), rev(false), val(M::id()), all(M::id()), size(1), flag(false) {} void expose() { for (lct_node *x = this, *y = nullptr; x; y = x, x = x->p) x->splay(), x->r = y, x->update(); splay(); } void link(lct_node *x) { x->expose(); expose(); p = x; } void evert() { expose(); rev = true; rev_data(); } T find() { expose(); return all; } void update(U v) { expose(); update_lazy(v); } }; const int MAX = 5e4; lct_node lct[MAX]; void build(int v, int prev, const vector<vector<int>>& G) { for (int to : G[v]) if (to != prev) { lct[to].link(&lct[v]); build(to, v, G); } } int main() { ios::sync_with_stdio(false), cin.tie(0); int N, Q; cin >> N >> Q; vector<vector<int>> G(N); for (int i = 0; i < N - 1; i++) { int u, v; cin >> u >> v; G[u].push_back(v); G[v].push_back(u); } build(0, -1, G); while (Q--) { int com, u, v; cin >> com >> u >> v; if (com == 1) { int x; cin >> x; lct[u].evert(); 
lct[v].update(make_pair(mint(x), mint(x))); } else { lct[u].evert(); printf("%dn", lct[v].find().get()); } } return 0; } Problem solution in C programming. #include <stdio.h> #include <stdlib.h> #include <string.h> typedef struct _lnode{ int x; int w; struct _lnode *next; } lnode; typedef struct _tree{ long long sum; long long offset1; long long offset2; } tree; #define MOD 1000000007 void insert_edge(int x,int y,int w); void dfs0(int u); void dfs1(int u,int c); void preprocess(); int lca(int a,int b); long long sum(int v,int tl,int tr,int l,int r,tree *t); void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t); void push(int v,int tl,int tr,tree *t); void range_solve(int x,int y,int z); int min(int x,int y); int max(int x,int y); long long solve(int x,int ancestor); int N,cn,level[100000],DP[18][100000],subtree_size[100000],special[100000],node_chain[100000],node_idx[100000],chain_head[100000],chain_len[100000]={0}; lnode *table[100000]={0}; tree *chain[100000]; int main(){ int Q,x,y,i; scanf("%d%d",&N,&Q); for(i=0;i<N-1;i++){ scanf("%d%d",&x,&y); insert_edge(x,y,1); } preprocess(); while(Q--){ scanf("%d",&x); switch(x){ case 1: scanf("%d%d%d",&x,&y,&i); range_solve(x,y,i); break; default: scanf("%d%d",&x,&y); i=lca(x,y); printf("%lldn",(solve(x,i)+solve(y,i)-sum(1,0,chain_len[node_chain[i]]-1,node_idx[i],node_idx[i],chain[node_chain[i]])+MOD)%MOD); } } return 0; } void insert_edge(int x,int y,int w){ lnode *t=malloc(sizeof(lnode)); t->x=y; t->w=w; t->next=table[x]; table[x]=t; t=malloc(sizeof(lnode)); t->x=x; t->w=w; t->next=table[y]; table[y]=t; return; } void dfs0(int u){ lnode *x; subtree_size[u]=1; special[u]=-1; for(x=table[u];x;x=x->next) if(x->x!=DP[0][u]){ DP[0][x->x]=u; level[x->x]=level[u]+1; dfs0(x->x); subtree_size[u]+=subtree_size[x->x]; if(special[u]==-1 || subtree_size[x->x]>subtree_size[special[u]]) special[u]=x->x; } return; } void dfs1(int u,int c){ lnode *x; node_chain[u]=c; node_idx[u]=chain_len[c]++; 
for(x=table[u];x;x=x->next) if(x->x!=DP[0][u]) if(x->x==special[u]) dfs1(x->x,c); else{ chain_head[cn]=x->x; dfs1(x->x,cn++); } return; } void preprocess(){ int i,j; level[0]=0; DP[0][0]=0; dfs0(0); for(i=1;i<18;i++) for(j=0;j<N;j++) DP[i][j] = DP[i-1][DP[i-1][j]]; cn=1; chain_head[0]=0; dfs1(0,0); for(i=0;i<cn;i++){ chain[i]=(tree*)malloc(4*chain_len[i]*sizeof(tree)); memset(chain[i],0,4*chain_len[i]*sizeof(tree)); } return; } int lca(int a,int b){ int i; if(level[a]>level[b]){ i=a; a=b; b=i; } int d = level[b]-level[a]; for(i=0;i<18;i++) if(d&(1<<i)) b=DP[i][b]; if(a==b)return a; for(i=17;i>=0;i--) if(DP[i][a]!=DP[i][b]) a=DP[i][a],b=DP[i][b]; return DP[0][a]; } long long sum(int v,int tl,int tr,int l,int r,tree *t){ push(v,tl,tr,t); if(l>r) return 0; if(l==tl && r==tr) return t[v].sum; int tm=(tl+tr)/2; return (sum(v*2,tl,tm,l,min(r,tm),t)+sum(v*2+1,tm+1,tr,max(l,tm+1),r,t))%MOD; } void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t){ push(v,tl,tr,t); if(pos2<tl || pos1>tr) return; int tm=(tl+tr)/2; if(pos1<=tl && pos2>=tr){ t[v].offset1=(o1+o2*(tl-pos1))%MOD; t[v].offset2=o2; } else{ range_update(v*2,tl,tm,pos1,pos2,o1,o2,t); range_update(v*2+1,tm+1,tr,pos1,pos2,o1,o2,t); push(v*2,tl,tm,t); push(v*2+1,tm+1,tr,t); t[v].sum=(t[v*2].sum+t[v*2+1].sum)%MOD; } return; } void push(int v,int tl,int tr,tree *t){ if(!t[v].offset1 && !t[v].offset2) return; t[v].sum=(t[v].sum+(t[v].offset1*2+t[v].offset2*(tr-tl))*(tr-tl+1)/2%MOD)%MOD; if(tl!=tr){ int tm=(tl+tr)/2; t[v*2].offset1=(t[v*2].offset1+t[v].offset1)%MOD; t[v*2+1].offset1=(t[v*2+1].offset1+t[v].offset1+t[v].offset2*(tm-tl+1))%MOD; t[v*2].offset2=(t[v*2].offset2+t[v].offset2)%MOD; t[v*2+1].offset2=(t[v*2+1].offset2+t[v].offset2)%MOD; } t[v].offset1=t[v].offset2=0; return; } void range_solve(int x,int y,int z){ int ca=lca(x,y),ty=y; long long cac=0,cay=0; while(node_chain[x]!=node_chain[ca]){ cac+=node_idx[x]+1; 
range_update(1,0,chain_len[node_chain[x]]-1,0,node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]); x=DP[0][chain_head[node_chain[x]]]; } cac+=node_idx[x]-node_idx[ca]+1; range_update(1,0,chain_len[node_chain[x]]-1,node_idx[ca],node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]); cac=z*cac%MOD; while(node_chain[ty]!=node_chain[ca]){ cay+=node_idx[ty]+1; ty=DP[0][chain_head[node_chain[ty]]]; } cay+=node_idx[ty]-node_idx[ca]; cay=(cac+z*cay)%MOD; while(node_chain[y]!=node_chain[ca]){ cay=(cay-z*(long long)node_idx[y]%MOD+MOD)%MOD; range_update(1,0,chain_len[node_chain[y]]-1,0,node_idx[y],cay,z,chain[node_chain[y]]); cay=(cay-z+MOD)%MOD; y=DP[0][chain_head[node_chain[y]]]; } cay=(cay-z*(long long)(node_idx[y]-node_idx[ca]-1)%MOD+MOD)%MOD; if((cay-z+MOD)%MOD!=cac) while(1); if(node_idx[y]!=node_idx[ca]) range_update(1,0,chain_len[node_chain[y]]-1,node_idx[ca]+1,node_idx[y],cay,z,chain[node_chain[y]]); return; } int min(int x,int y){ return (x<y)?x:y; } int max(int x,int y){ return (x>y)?x:y; } long long solve(int x,int ancestor){ long long ans=0; while(node_chain[x]!=node_chain[ancestor]){ ans=(ans+sum(1,0,chain_len[node_chain[x]]-1,0,node_idx[x],chain[node_chain[x]]))%MOD; x=DP[0][chain_head[node_chain[x]]]; } ans=(ans+sum(1,0,chain_len[node_chain[x]]-1,node_idx[ancestor],node_idx[x],chain[node_chain[x]]))%MOD; return ans; } coding problems data structure