In this HackerRank White Falcon And Tree problem, you are given a tree with N nodes, and each node holds a linear function f(x) = ax + b. There are two kinds of queries: the first assigns a new function ax + b to every node on the path between nodes u and v; the second applies the functions along the path from u to v, in order, to a given value x and reports the result modulo 10^9 + 7.
Problem solution in Java programming.
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * White Falcon And Tree: path-assignment and path-composition queries over per-node linear maps.
 *
 * <p>Node i holds f_i(x) = a_i * x + b_i. Query {@code 1 u v a b} assigns a*x + b to every node on
 * the u-v path; query {@code 2 u v x} prints the result of applying the node maps from u to v (u's
 * map first) to x, modulo 1e9+7. The tree is split into heavy-light chains, and each chain is backed
 * by a lazy segment tree that stores its composition in both directions.
 */
public class Solution {
  static InputStream is;
  static PrintWriter out;
  static String INPUT = "";
  static final int mod = 1000000007;

  /** Reads the tree and the queries from {@link #is} and writes one line per type-2 query to {@link #out}. */
  static void solve() {
    int n = ni();
    int[][] co = new int[n][];
    for (int i = 0; i < n; i++) {
      co[i] = new int[] {ni(), ni()};
    }
    int[] from = new int[n - 1];
    int[] to = new int[n - 1];
    for (int i = 0; i < n - 1; i++) {
      from[i] = ni() - 1;
      to[i] = ni() - 1;
    }
    int[][] g = packU(n, from, to);
    int[][] pars = parents3(g, 0);
    int[] par = pars[0], ord = pars[1], dep = pars[2];
    int[] clus = decomposeToHeavyLight(g, par, ord);
    int[][] cluspath = clusPaths(clus, ord);
    int[] clusiind = clusIInd(cluspath, n);

    // One segment tree per chain; leaf j is the j-th vertex of the chain counted from its head.
    SegmentTreeNodePlus[] sts = new SegmentTreeNodePlus[cluspath.length];
    for (int i = 0; i < cluspath.length; i++) {
      int[][] lco = new int[cluspath[i].length][];
      for (int j = 0; j < cluspath[i].length; j++) {
        lco[j] = co[cluspath[i][j]];
      }
      sts[i] = new SegmentTreeNodePlus(lco);
    }
    int[][] spar = logstepParents(par);

    int Q = ni();
    for (int z = 0; z < Q; z++) {
      int t = ni();
      if (t == 1) {
        int u = ni() - 1, v = ni() - 1, a = ni(), b = ni();
        int lca = lca2(u, v, spar, dep);
        int[][] pr = query2(u, lca, v, clus, cluspath, clusiind, par);
        for (int[] e : pr) {
          sts[e[0]].update(Math.min(e[1], e[2]), Math.max(e[1], e[2]) + 1, a, b);
        }
      } else {
        int u = ni() - 1, v = ni() - 1;
        long x = ni();
        int lca = lca2(u, v, spar, dep);
        int[][] pr = query2(u, lca, v, clus, cluspath, clusiind, par);
        for (int[] e : pr) {
          // e[1] > e[2] means this segment is walked toward the chain head (upward).
          if (e[1] <= e[2]) {
            x = sts[e[0]].apply(e[1], e[2] + 1, x, false);
          } else {
            x = sts[e[0]].apply(e[2], e[1] + 1, x, true);
          }
        }
        out.println(x);
      }
    }
  }

  /**
   * Returns the lowest common ancestor of a and b.
   *
   * @param spar binary-lifting table from {@link #logstepParents}
   * @param depth depth of every vertex
   */
  public static int lca2(int a, int b, int[][] spar, int[] depth) {
    if (depth[a] < depth[b]) {
      b = ancestor(b, depth[b] - depth[a], spar);
    } else if (depth[a] > depth[b]) {
      a = ancestor(a, depth[a] - depth[b], spar);
    }
    if (a == b) {
      return a;
    }
    int sa = a, sb = b;
    // Binary search on the number of levels to climb while the two ancestors still differ.
    for (int low = 0, high = depth[a], t = Integer.highestOneBit(high),
            k = Integer.numberOfTrailingZeros(t);
        t > 0;
        t >>>= 1, k--) {
      if ((low ^ high) >= t) {
        if (spar[k][sa] != spar[k][sb]) {
          low |= t;
          sa = spar[k][sa];
          sb = spar[k][sb];
        } else {
          high = low | t - 1;
        }
      }
    }
    return spar[0][sa];
  }

  /** Returns the ancestor m levels above a, or -1 if that walks past the root. */
  protected static int ancestor(int a, int m, int[][] spar) {
    for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
      if ((m & 1) == 1) {
        a = spar[i][a];
      }
    }
    return a;
  }

  /** Builds the binary-lifting table: {@code pars[j][i]} is the 2^j-th ancestor of i, or -1. */
  public static int[][] logstepParents(int[] par) {
    int n = par.length;
    int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
    int[][] pars = new int[m][n];
    pars[0] = par;
    for (int j = 1; j < m; j++) {
      for (int i = 0; i < n; i++) {
        pars[j][i] = pars[j - 1][i] == -1 ? -1 : pars[j - 1][pars[j - 1][i]];
      }
    }
    return pars;
  }

  /**
   * Lazy segment tree over a sequence of linear maps with range assignment.
   *
   * <p>Each node stores the composition of its range: {@code lc} for left-to-right application and
   * {@code rc} for right-to-left application (both share the multiplier {@code co}). {@code
   * cover[cur]} is a pending "whole range assigned to {a, b}" tag.
   */
  public static class SegmentTreeNodePlus {
    public int M, H, N;
    public Node[] node;
    public int[][] cover;
    static final int mod = 1000000007;
    int[] temp = null;

    private static class Node {
      long co;
      long lc;
      long rc;

      public Node() {
        co = 1;
        lc = rc = 0;
      }

      public Node(long co, long lc, long rc) {
        this.co = co;
        this.lc = lc;
        this.rc = rc;
      }

      /** Applies the stored composition to x; dir=false is left-to-right, dir=true right-to-left. */
      public long apply(long x, boolean dir) {
        return (co * x + (dir ? rc : lc)) % mod;
      }
    }

    /** Builds the tree; {@code co[i] = {a, b}} is the i-th map. */
    public SegmentTreeNodePlus(int[][] co) {
      N = co.length;
      M = Integer.highestOneBit(Math.max(N - 1, 1)) << 2;
      H = M >>> 1;
      node = new Node[M];
      cover = new int[H][];
      for (int i = 0; i < N; i++) {
        node[H + i] = new Node(co[i][0], co[i][1], co[i][1]);
      }
      for (int i = H - 1; i >= 1; i--) {
        propagate(i);
      }
    }

    /** Recomputes node[cur] from its children and its pending cover. */
    private void propagate(int cur) {
      node[cur] = prop2(node[2 * cur], node[2 * cur + 1], cover[cur], node[cur],
          H / Integer.highestOneBit(cur));
    }

    /**
     * Stores into C the map {@code cover} composed with itself len times.
     * len is the power-of-two width of the node's range.
     */
    private static void fillCovered(Node C, int[] cover, int len) {
      long co = cover[0], c = cover[1];
      for (int x = len; x > 1; x >>>= 1) {
        long nco = co * co % mod;
        long nc = (co * c + c) % mod;
        co = nco;
        c = nc;
      }
      C.co = co;
      C.lc = C.rc = c;
    }

    private Node prop2(Node L, Node R, int[] cover, Node C, int len) {
      if (L != null && R != null) {
        if (C == null) {
          C = new Node();
        }
        if (cover == null) {
          C.co = L.co * R.co % mod;
          C.lc = (R.co * L.lc + R.lc) % mod; // R(L(x))
          C.rc = (L.co * R.rc + L.rc) % mod; // L(R(x))
        } else {
          fillCovered(C, cover, len);
        }
        return C;
      } else if (L != null) {
        return prop1(L, cover, C, len);
      } else if (R != null) {
        return prop1(R, cover, C, len);
      } else {
        return null;
      }
    }

    private Node prop1(Node L, int[] cover, Node C, int len) {
      if (C == null) {
        C = new Node();
      }
      if (cover == null) {
        C.co = L.co;
        C.lc = L.lc;
        C.rc = L.rc;
      } else {
        fillCovered(C, cover, len);
      }
      return C;
    }

    /** Assigns the map a*x + b to every position in [l, r). */
    public void update(int l, int r, int a, int b) {
      temp = new int[] {a, b};
      if (l < r) {
        update(l, r, a, b, 0, H, 1);
      }
    }

    protected void update(int l, int r, int a, int b, int cl, int cr, int cur) {
      if (cur >= H) {
        node[cur].co = a;
        node[cur].lc = node[cur].rc = b;
      } else if (l <= cl && cr <= r) {
        cover[cur] = temp;
        propagate(cur);
      } else {
        int mid = cl + cr >>> 1;
        boolean bp = false;
        // Push a pending cover one level down before descending.
        if (cover[cur] != null) {
          if (2 * cur < H) {
            cover[2 * cur] = cover[cur];
            cover[2 * cur + 1] = cover[cur];
            cover[cur] = null;
            bp = true;
          } else {
            node[2 * cur].co = cover[cur][0];
            node[2 * cur].lc = node[2 * cur].rc = cover[cur][1];
            node[2 * cur + 1].co = cover[cur][0];
            node[2 * cur + 1].lc = node[2 * cur + 1].rc = cover[cur][1];
            cover[cur] = null;
          }
        }
        if (cl < r && l < mid) {
          update(l, r, a, b, cl, mid, 2 * cur);
        } else if (bp) {
          propagate(2 * cur);
        }
        if (mid < r && l < cr) {
          update(l, r, a, b, mid, cr, 2 * cur + 1);
        } else if (bp) {
          propagate(2 * cur + 1);
        }
        propagate(cur);
      }
    }

    /** Applies the maps in [l, r) to x, left-to-right when dir is false, right-to-left otherwise. */
    public long apply(int l, int r, long x, boolean dir) {
      return apply(l, r, x, dir, 0, H, 1);
    }

    protected long apply(int l, int r, long x, boolean dir, int cl, int cr, int cur) {
      if (l <= cl && cr <= r) {
        return node[cur].apply(x, dir);
      }
      int mid = cl + cr >>> 1;
      if (cover[cur] != null) {
        // Every map in the overlap is the same; apply it h times by repeated squaring.
        long co = cover[cur][0], c = cover[cur][1];
        for (int h = Math.min(r, cr) - Math.max(l, cl); h > 0; h >>>= 1) {
          if ((h & 1) == 1) {
            x = (co * x + c) % mod;
          }
          long nco = co * co % mod;
          long nc = (co * c + c) % mod;
          co = nco;
          c = nc;
        }
        return x;
      }
      if (!dir) {
        if (cl < r && l < mid) {
          x = apply(l, r, x, dir, cl, mid, 2 * cur);
        }
        if (mid < r && l < cr) {
          x = apply(l, r, x, dir, mid, cr, 2 * cur + 1);
        }
      } else {
        if (mid < r && l < cr) {
          x = apply(l, r, x, dir, mid, cr, 2 * cur + 1);
        }
        if (cl < r && l < mid) {
          x = apply(l, r, x, dir, cl, mid, 2 * cur);
        }
      }
      return x;
    }
  }

  /**
   * BFS from root.
   *
   * @return {par, bfsOrder, depth}; par[root] is -1
   */
  public static int[][] parents3(int[][] g, int root) {
    int n = g.length;
    int[] par = new int[n];
    Arrays.fill(par, -1);
    int[] depth = new int[n];
    depth[root] = 0;
    int[] q = new int[n];
    q[0] = root;
    for (int p = 0, r = 1; p < r; p++) {
      int cur = q[p];
      for (int nex : g[cur]) {
        if (par[cur] != nex) {
          q[r++] = nex;
          par[nex] = cur;
          depth[nex] = depth[cur] + 1;
        }
      }
    }
    return new int[][] {par, q, depth};
  }

  /**
   * Assigns every vertex a chain id.
   *
   * <p>Each vertex continues its parent's chain through a child of at least half its subtree size
   * if one exists, otherwise through its first child.
   */
  public static int[] decomposeToHeavyLight(int[][] g, int[] par, int[] ord) {
    int n = g.length;
    int[] size = new int[n];
    Arrays.fill(size, 1);
    for (int i = n - 1; i > 0; i--) {
      size[par[ord[i]]] += size[ord[i]];
    }
    int[] clus = new int[n];
    Arrays.fill(clus, -1);
    int p = 0;
    outer:
    for (int i = 0; i < n; i++) {
      int u = ord[i];
      if (clus[u] == -1) {
        clus[u] = p++;
      }
      for (int v : g[u]) {
        if (par[u] != v && size[v] >= size[u] / 2) {
          clus[v] = clus[u];
          continue outer;
        }
      }
      for (int v : g[u]) {
        if (par[u] != v) {
          clus[v] = clus[u];
          break;
        }
      }
    }
    return clus;
  }

  /** Lists each chain's vertices in BFS order, so index 0 is the chain head (closest to the root). */
  public static int[][] clusPaths(int[] clus, int[] ord) {
    int n = clus.length;
    int[] rp = new int[n];
    int sup = 0;
    for (int i = 0; i < n; i++) {
      rp[clus[i]]++;
      sup = Math.max(sup, clus[i]);
    }
    sup++;
    int[][] row = new int[sup][];
    for (int i = 0; i < sup; i++) {
      row[i] = new int[rp[i]];
    }
    for (int i = n - 1; i >= 0; i--) {
      row[clus[ord[i]]][--rp[clus[ord[i]]]] = ord[i];
    }
    return row;
  }

  /** Returns, for each vertex, its index inside its chain. */
  public static int[] clusIInd(int[][] clusPath, int n) {
    int[] iind = new int[n];
    for (int[] path : clusPath) {
      for (int i = 0; i < path.length; i++) {
        iind[path[i]] = i;
      }
    }
    return iind;
  }

  /**
   * Splits the path x -> anc -> y into chain segments listed in path order.
   *
   * <p>Each entry is {chain, fromIndex, toIndex}; fromIndex > toIndex means the segment is walked
   * upward. anc is included exactly once. The fixed capacity of 60 relies on the logarithmic number
   * of light edges per root path.
   */
  public static int[][] query2(int x, int anc, int y, int[] clus, int[][] cluspath, int[] clusiind,
      int[] par) {
    int[][] stack = new int[60][];
    int sp = 0;
    int cx = clus[x];
    int indx = clusiind[x];
    while (cx != clus[anc]) {
      stack[sp++] = new int[] {cx, indx, 0};
      int con = par[cluspath[cx][0]];
      indx = clusiind[con];
      cx = clus[con];
    }
    stack[sp++] = new int[] {cx, indx, clusiind[anc]};
    int top = sp;
    int cy = clus[y];
    int indy = clusiind[y];
    while (cy != clus[anc]) {
      stack[sp++] = new int[] {cy, 0, indy};
      int con = par[cluspath[cy][0]];
      indy = clusiind[con];
      cy = clus[con];
    }
    if (clusiind[anc] < indy) {
      stack[sp++] = new int[] {cy, clusiind[anc] + 1, indy};
    }
    // The y side was collected bottom-up; reverse it so it reads top-down.
    for (int p = top, q = sp - 1; p < q; p++, q--) {
      int[] dum = stack[p];
      stack[p] = stack[q];
      stack[q] = dum;
    }
    return Arrays.copyOf(stack, sp);
  }

  /** Builds an undirected adjacency list from parallel edge arrays. */
  static int[][] packU(int n, int[] from, int[] to) {
    int[][] g = new int[n][];
    int[] p = new int[n];
    for (int f : from) {
      p[f]++;
    }
    for (int t : to) {
      p[t]++;
    }
    for (int i = 0; i < n; i++) {
      g[i] = new int[p[i]];
    }
    for (int i = 0; i < from.length; i++) {
      g[from[i]][--p[from[i]]] = to[i];
      g[to[i]][--p[to[i]]] = from[i];
    }
    return g;
  }

  public static void main(String[] args) throws Exception {
    is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
    out = new PrintWriter(System.out);
    solve();
    out.flush();
  }

  private static byte[] inbuf = new byte[1024];
  static int lenbuf = 0, ptrbuf = 0;

  /** Returns the next input byte, or -1 at end of input. */
  private static int readByte() {
    if (lenbuf == -1) {
      throw new InputMismatchException();
    }
    if (ptrbuf >= lenbuf) {
      ptrbuf = 0;
      try {
        lenbuf = is.read(inbuf);
      } catch (IOException e) {
        throw new InputMismatchException();
      }
      if (lenbuf <= 0) {
        return -1;
      }
    }
    return inbuf[ptrbuf++];
  }

  /** Reads the next (optionally negative) decimal int, skipping any non-digit separators. */
  private static int ni() {
    int num = 0, b;
    boolean minus = false;
    while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
      // skip separators
    }
    if (b == '-') {
      minus = true;
      b = readByte();
    }
    while (true) {
      if (b >= '0' && b <= '9') {
        num = num * 10 + (b - '0');
      } else {
        return minus ? -num : num;
      }
      b = readByte();
    }
  }
}
Problem solution in C++ programming.
#define _CRT_SECURE_NO_WARNINGS
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>

using namespace std;

typedef vector<int> vi;
typedef pair<int, int> pii;

// Integer modulo MOD, stored in [0, MOD).
template <int MOD>
struct ModInt {
  static const int Mod = MOD;
  unsigned x;
  ModInt() : x(0) {}
  ModInt(signed sig) {
    int sigt = sig % MOD;
    if (sigt < 0) sigt += MOD;
    x = sigt;
  }
  ModInt(signed long long sig) {
    int sigt = sig % MOD;
    if (sigt < 0) sigt += MOD;
    x = sigt;
  }
  int get() const { return (int)x; }
  ModInt &operator+=(ModInt that) {
    if ((x += that.x) >= (unsigned)MOD) x -= MOD;
    return *this;
  }
  ModInt &operator-=(ModInt that) {
    if ((x += MOD - that.x) >= (unsigned)MOD) x -= MOD;
    return *this;
  }
  ModInt &operator*=(ModInt that) {
    x = (unsigned long long)x * that.x % MOD;
    return *this;
  }
  ModInt operator+(ModInt that) const { return ModInt(*this) += that; }
  ModInt operator-(ModInt that) const { return ModInt(*this) -= that; }
  ModInt operator*(ModInt that) const { return ModInt(*this) *= that; }
};
typedef ModInt<1000000007> mint;

// The linear map y = a*x + b. "p += q" composes in place: apply p first, then q.
struct LinearExpr {
  mint a, b;
  LinearExpr() : a(1), b(0) {}
  LinearExpr(mint a_, mint b_) : a(a_), b(b_) {}
  LinearExpr &operator+=(const LinearExpr &that) {
    b = b * that.a + that.b;
    a = a * that.a;
    return *this;
  }
  LinearExpr operator+(const LinearExpr &that) const { return LinearExpr(*this) += that; }
  // This map composed with itself k times (identity for k == 0), by repeated squaring.
  LinearExpr operator*(int k) const {
    LinearExpr p = *this, r;
    while (k) {
      if (k & 1) r += p;
      p += p;
      k >>= 1;
    }
    return r;
  }
  mint evalute(mint x) const { return a * x + b; }
};
typedef LinearExpr Val;

// Composition of a range: forward applies left-to-right, backward right-to-left.
struct Sum {
  LinearExpr forward, backward;
  Sum() : forward(), backward() {}
  Sum(const Val &val, int) : forward(val), backward(val) {}
  Sum &operator+=(const Sum &that) {
    forward += that.forward;
    backward = that.backward + backward;
    return *this;
  }
  Sum operator+(const Sum &that) const { return Sum(*this) += that; }
};

// Pending "assign expr to the whole range" tag; a later assignment overrides an earlier one.
struct Laziness {
  bool fill;
  LinearExpr expr;
  Laziness() : fill(false) {}
  Laziness(LinearExpr expr_) : fill(true), expr(expr_) {}
  Laziness &operator+=(const Laziness &that) {
    if (that.fill) *this = that;
    return *this;
  }
  void addToVal(Val &val, int) const {
    if (fill) val = expr;
  }
  void addToSum(Sum &sum, int left, int right) const {
    if (fill) {
      LinearExpr multiplicated = expr * (right - left);
      sum.forward = sum.backward = multiplicated;
    }
  }
};

// Bottom-up lazy segment tree of linear maps (leaves stored separately in `leafs`).
struct SegmentTree {
  vector<Val> leafs;
  vector<Sum> nodes;
  vector<Laziness> laziness;
  vector<int> leftpos, rightpos;  // node -> half-open leaf range [leftpos, rightpos)
  int n, n2;

  void init(const vector<Val> &u) {
    n = 1;
    while (n < (int)u.size()) n *= 2;
    n2 = (n - 1) / 2 + 1;
    leafs = u;
    leafs.resize(n, Val());
    nodes.resize(n);
    for (int i = n - 1; i >= n2; --i)
      nodes[i] = Sum(leafs[i * 2 - n], i * 2 - n) + Sum(leafs[i * 2 + 1 - n], i * 2 + 1 - n);
    for (int i = n2 - 1; i > 0; --i) nodes[i] = nodes[i * 2] + nodes[i * 2 + 1];
    laziness.assign(n, Laziness());
    leftpos.resize(n);
    rightpos.resize(n);
    for (int i = n - 1; i >= n2; --i) {
      leftpos[i] = i * 2 - n;
      rightpos[i] = (i * 2 + 1 - n) + 1;
    }
    for (int i = n2 - 1; i > 0; --i) {
      leftpos[i] = leftpos[i * 2];
      rightpos[i] = rightpos[i * 2 + 1];
    }
  }

  // Composition of [i, j) with parts combined in left-to-right order.
  Sum getRange(int i, int j) {
    int indices[128];
    int k = getIndices(indices, i, j);
    propagateRange(indices, k);
    Sum res = Sum();
    for (; i && i + (i & -i) <= j; i += i & -i) res += sum((n + i) / (i & -i));
    for (k = 0; i < j; j -= j & -j) indices[k++] = (n + j) / (j & -j) - 1;
    while (--k >= 0) res += sum(indices[k]);
    return res;
  }

  void addToRange(int i, int j, const Laziness &x) {
    if (i >= j) return;
    int indices[128];
    int k = getIndices(indices, i, j);
    propagateRange(indices, k);
    int l = i + n, r = j + n;
    if (l & 1) {
      int p = (l++) - n;
      x.addToVal(leafs[p], p);
    }
    if (r & 1) {
      int p = (--r) - n;
      x.addToVal(leafs[p], p);
    }
    for (l >>= 1, r >>= 1; l < r; l >>= 1, r >>= 1) {
      if (l & 1) laziness[l++] += x;
      if (r & 1) laziness[--r] += x;
    }
    mergeRange(indices, k);
  }

 private:
  // Ancestors of the boundary leaves of [i, j), deepest first.
  int getIndices(int indices[], int i, int j) const {
    int k = 0, l, r;
    if (i >= j) return 0;
    for (l = (n + i) >> 1, r = (n + j - 1) >> 1; l != r; l >>= 1, r >>= 1) {
      indices[k++] = l;
      indices[k++] = r;
    }
    for (; l; l >>= 1) indices[k++] = l;
    return k;
  }
  void propagateRange(int indices[], int k) {
    for (int i = k - 1; i >= 0; --i) propagate(indices[i]);
  }
  void mergeRange(int indices[], int k) {
    for (int i = 0; i < k; ++i) merge(indices[i]);
  }
  inline void propagate(int i) {
    if (i >= n) return;
    laziness[i].addToSum(nodes[i], leftpos[i], rightpos[i]);
    if (i * 2 < n) {
      laziness[i * 2] += laziness[i];
      laziness[i * 2 + 1] += laziness[i];
    } else {
      laziness[i].addToVal(leafs[i * 2 - n], i * 2 - n);
      laziness[i].addToVal(leafs[i * 2 + 1 - n], i * 2 + 1 - n);
    }
    laziness[i] = Laziness();
  }
  inline void merge(int i) {
    if (i >= n) return;
    nodes[i] = sum(i * 2) + sum(i * 2 + 1);
  }
  inline Sum sum(int i) {
    propagate(i);
    return i < n ? nodes[i] : Sum(leafs[i - n], i - n);
  }
};

// Path decomposition: each vertex continues its chain through the child with the largest subtree.
struct CentroidPathDecomposition {
  vector<int> colors, positions;           // vertex -> chain id, vertex -> position in chain
  vector<int> lengths, parents, branches;  // chain -> length, parent chain, attach position
  vector<int> parentnodes, depths;         // vertex -> parent vertex, vertex -> depth
  vector<int> sortednodes, offsets;        // index -> vertex, chain -> first index
  vector<int> lefts, rights;               // vertex -> [lefts, rights) subtree index range

  struct BuildDFSState {
    int i, len, parent;
    BuildDFSState() {}
    BuildDFSState(int i_, int l, int p) : i(i_), len(l), parent(p) {}
  };

  // Iterative DFS; parent == -2 continues the current chain, -3 closes a vertex's subtree.
  void build(const vector<vi> &g, int root) {
    int n = g.size();
    colors.assign(n, -1);
    positions.assign(n, -1);
    lengths.clear();
    parents.clear();
    branches.clear();
    parentnodes.assign(n, -1);
    depths.assign(n, -1);
    sortednodes.clear();
    offsets.clear();
    lefts.assign(n, -1);
    rights.assign(n, -1);
    vector<int> subtreesizes;
    measure(g, root, subtreesizes);
    typedef BuildDFSState State;
    depths[root] = 0;
    vector<State> s;
    s.push_back(State(root, 0, -1));
    while (!s.empty()) {
      State t = s.back();
      s.pop_back();
      int i = t.i, len = t.len;
      int index = sortednodes.size();
      int color = lengths.size();
      if (t.parent == -3) {
        rights[i] = index;
        continue;
      }
      if (t.parent != -2) {
        assert((int)parents.size() == color);
        parents.push_back(t.parent);
        branches.push_back(len);
        offsets.push_back(index);
        len = 0;
      }
      colors[i] = color;
      positions[i] = len;
      lefts[i] = index;
      sortednodes.push_back(i);
      int maxsize = -1, maxj = -1;
      for (size_t k = 0; k < g[i].size(); ++k) {
        int j = g[i][k];
        if (colors[j] != -1) continue;
        if (maxsize < subtreesizes[j]) {
          maxsize = subtreesizes[j];
          maxj = j;
        }
        parentnodes[j] = i;
        depths[j] = depths[i] + 1;
      }
      s.push_back(State(i, -1, -3));
      if (maxj == -1) {
        lengths.push_back(len + 1);
      } else {
        for (size_t k = 0; k < g[i].size(); ++k) {
          int j = g[i][k];
          if (colors[j] == -1 && j != maxj) s.push_back(State(j, len, color));
        }
        // Pushed last so it is popped next and stays contiguous with i.
        s.push_back(State(maxj, len + 1, -2));
      }
    }
  }

  void get(int v, int &c, int &p) const {
    c = colors[v];
    p = positions[v];
  }

  // Moves (c, p) to the attach point in the parent chain; returns false above the root chain.
  bool go_up(int &c, int &p) const {
    p = branches[c];
    c = parents[c];
    return c != -1;
  }

  inline const int *nodesBegin(int c) const { return &sortednodes[0] + offsets[c]; }

 private:
  void measure(const vector<vi> &g, int root, vector<int> &out_subtreesizes) const {
    out_subtreesizes.assign(g.size(), -1);
    vector<int> s;
    s.push_back(root);
    while (!s.empty()) {
      int i = s.back();
      s.pop_back();
      if (out_subtreesizes[i] == -2) {
        int total = 1;
        for (size_t k = 0; k < g[i].size(); ++k) {
          int j = g[i][k];
          if (out_subtreesizes[j] != -2) total += out_subtreesizes[j];
        }
        out_subtreesizes[i] = total;
      } else {
        s.push_back(i);
        for (size_t k = 0; k < g[i].size(); ++k) {
          int j = g[i][k];
          if (out_subtreesizes[j] == -1) s.push_back(j);
        }
        out_subtreesizes[i] = -2;
      }
    }
  }
};

int lowest_common_ancestor(const CentroidPathDecomposition &cpd, int x, int y) {
  int cx, px, cy, py;
  cpd.get(x, cx, px);
  cpd.get(y, cy, py);
  while (cx != cy) {
    if (cpd.depths[*cpd.nodesBegin(cx)] < cpd.depths[*cpd.nodesBegin(cy)])
      cpd.go_up(cy, py);
    else
      cpd.go_up(cx, px);
  }
  return cpd.nodesBegin(cx)[min(px, py)];
}

int main() {
  int N;
  scanf("%d", &N);
  vector<Val> initval(N);
  for (int i = 0; i < N; ++i) {
    int a, b;
    scanf("%d%d", &a, &b);
    initval[i] = LinearExpr(a, b);
  }
  vector<vi> g(N);
  for (int i = 0; i < N - 1; ++i) {
    int x, y;
    scanf("%d%d", &x, &y);
    --x;
    --y;
    g[x].push_back(y);
    g[y].push_back(x);
  }
  CentroidPathDecomposition cpd;
  cpd.build(g, 0);
  vector<Val> permutatedInitval(N);
  for (int i = 0; i < N; ++i) permutatedInitval[i] = initval[cpd.sortednodes[i]];
  SegmentTree segt;
  segt.init(permutatedInitval);
  vector<pii> path;
  int Q;
  scanf("%d", &Q);
  for (int ii = 0; ii < Q; ++ii) {
    int ty;
    scanf("%d", &ty);
    if (ty == 1) {
      int u, v, a, b;
      scanf("%d%d%d%d", &u, &v, &a, &b);
      --u;
      --v;
      Laziness laziness = Laziness(LinearExpr(a, b));
      int w = lowest_common_ancestor(cpd, u, v), wc, wp;
      cpd.get(w, wc, wp);
      for (int uv = 0; uv < 2; ++uv) {
        int c, p;
        cpd.get(uv == 0 ? u : v, c, p);
        while (true) {
          // On w's chain the u side includes w, the v side starts just below it.
          int top = c == wc ? wp + uv : 0;
          int o = cpd.offsets[c];
          segt.addToRange(o + top, o + p + 1, laziness);
          if (c == wc) break;
          cpd.go_up(c, p);
        }
      }
    } else if (ty == 2) {
      int u, v, x;
      scanf("%d%d%d", &u, &v, &x);
      --u;
      --v;
      LinearExpr expr;
      int w = lowest_common_ancestor(cpd, u, v), wc, wp;
      cpd.get(w, wc, wp);
      for (int uv = 0; uv < 2; ++uv) {
        path.clear();
        int c, p;
        cpd.get(uv == 0 ? u : v, c, p);
        while (true) {
          int top = c == wc ? wp + uv : 0;
          int o = cpd.offsets[c];
          // Index range [o + top, o + p] of this chain that lies on the path.
          path.push_back(make_pair(o + top, o + p));
          if (c == wc) break;
          cpd.go_up(c, p);
        }
        if (uv == 0) {
          // u side: climb from u, applying each segment bottom-to-top.
          for (int i = 0; i < (int)path.size(); ++i) {
            int top = path[i].first, bottom = path[i].second;
            expr += segt.getRange(top, bottom + 1).backward;
          }
        } else {
          // v side: descend from below w toward v, applying each segment top-to-bottom.
          for (int i = (int)path.size() - 1; i >= 0; --i) {
            int top = path[i].first, bottom = path[i].second;
            expr += segt.getRange(top, bottom + 1).forward;
          }
        }
      }
      mint ans = expr.evalute(x);
      printf("%d\n", ans.get());
    } else {
      return 1;
    }
  }
  return 0;
}
Problem solution in C programming.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Adjacency-list entry. */
typedef struct _lnode {
  int x;                /* neighbouring vertex */
  int w;                /* edge weight (always 1; not read) */
  struct _lnode *next;
} lnode;

/*
 * Segment-tree node over one heavy chain. A map f(x) = a*x + b is the matrix [[a, b], [0, 1]].
 * sum = left * right (the higher-index, deeper part is applied first, i.e. walking upward);
 * sum1 = right * left (walking downward). offset1/offset2 hold a pending "assign a*x + b to the
 * whole range" tag, or -1 when there is none.
 */
typedef struct _tree {
  long long sum[2][2];
  long long sum1[2][2];
  long long offset1;
  long long offset2;
} tree;

#define MOD 1000000007

void insert_edge(int x, int y, int w);
void dfs0(int u);
void dfs1(int u, int c);
void preprocess();
int lca(int a, int b);
void sum(int v, int tl, int tr, int l, int r, tree *t, long long ans[][2], int f);
void range_update(int v, int tl, int tr, int pos1, int pos2, long long o1, long long o2, tree *t);
void merge(long long a[][2], long long b[][2], long long ans[][2]);
void push(int v, int tl, int tr, tree *t);
void range_solve(int x, int y, int a, int b);
int min(int x, int y);
int max(int x, int y);
long long solve(int x, int y, int a);
void one(long long *a, int SIZE);
void mul(long long *a, long long *b, int SIZE);
void powm(long long *a, int n, long long *res, int SIZE);

/* Fixed-size tables: assumes N <= 100000 and depth < 2^18. */
int N, cn, A[100000], B[100000], level[100000], DP[18][100000], subtree_size[100000],
    special[100000], node_chain[100000], node_idx[100000], chain_head[100000],
    chain_len[100000] = {0};
lnode *table[100000] = {0};
tree *chain[100000];

/* Reads the tree and the queries; prints one line per type-2 query. */
int main() {
  int Q, x, y, a, b, i;
  scanf("%d", &N);
  for (i = 0; i < N; i++) scanf("%d%d", A + i, B + i);
  for (i = 0; i < N - 1; i++) {
    scanf("%d%d", &x, &y);
    insert_edge(x - 1, y - 1, 1);
  }
  preprocess();
  scanf("%d", &Q);
  while (Q--) {
    scanf("%d", &x);
    switch (x) {
      case 1:
        scanf("%d%d%d%d", &x, &y, &a, &b);
        range_solve(x - 1, y - 1, a, b);
        break;
      default:
        scanf("%d%d%d", &x, &y, &a);
        printf("%lld\n", solve(x - 1, y - 1, a));
    }
  }
  return 0;
}

/* Adds the undirected edge x-y. */
void insert_edge(int x, int y, int w) {
  lnode *t = malloc(sizeof(lnode));
  t->x = y;
  t->w = w;
  t->next = table[x];
  table[x] = t;
  t = malloc(sizeof(lnode));
  t->x = x;
  t->w = w;
  t->next = table[y];
  table[y] = t;
}

/* Computes parent, depth, subtree size and heaviest child (special) for the subtree of u. */
void dfs0(int u) {
  lnode *x;
  subtree_size[u] = 1;
  special[u] = -1;
  for (x = table[u]; x; x = x->next) {
    if (x->x != DP[0][u]) {
      DP[0][x->x] = u;
      level[x->x] = level[u] + 1;
      dfs0(x->x);
      subtree_size[u] += subtree_size[x->x];
      if (special[u] == -1 || subtree_size[x->x] > subtree_size[special[u]]) special[u] = x->x;
    }
  }
}

/* Assigns chain ids and in-chain indices; the heavy child continues chain c. */
void dfs1(int u, int c) {
  lnode *x;
  node_chain[u] = c;
  node_idx[u] = chain_len[c]++;
  for (x = table[u]; x; x = x->next) {
    if (x->x == DP[0][u]) continue;
    if (x->x == special[u]) {
      dfs1(x->x, c);
    } else {
      chain_head[cn] = x->x;
      dfs1(x->x, cn++);
    }
  }
}

/* Builds the lifting table, the chains and one segment tree per chain. */
void preprocess() {
  int i, j;
  level[0] = 0;
  DP[0][0] = 0;
  dfs0(0);
  for (i = 1; i < 18; i++)
    for (j = 0; j < N; j++) DP[i][j] = DP[i - 1][DP[i - 1][j]];
  cn = 1;
  chain_head[0] = 0;
  dfs1(0, 0);
  for (i = 0; i < cn; i++) {
    chain[i] = (tree *)malloc(4 * chain_len[i] * sizeof(tree));
    memset(chain[i], 0, 4 * chain_len[i] * sizeof(tree));
    for (j = 0; j < 4 * chain_len[i]; j++) chain[i][j].offset1 = chain[i][j].offset2 = -1;
  }
  for (i = 0; i < N; i++)
    range_update(1, 0, chain_len[node_chain[i]] - 1, node_idx[i], node_idx[i], A[i], B[i],
                 chain[node_chain[i]]);
}

/* Lowest common ancestor by binary lifting. */
int lca(int a, int b) {
  int i;
  if (level[a] > level[b]) {
    i = a;
    a = b;
    b = i;
  }
  int d = level[b] - level[a];
  for (i = 0; i < 18; i++)
    if (d & (1 << i)) b = DP[i][b];
  if (a == b) return a;
  for (i = 17; i >= 0; i--)
    if (DP[i][a] != DP[i][b]) a = DP[i][a], b = DP[i][b];
  return DP[0][a];
}

/* Composition of chain positions [l, r] into ans: upward (f == 0) or downward (f != 0). */
void sum(int v, int tl, int tr, int l, int r, tree *t, long long ans[][2], int f) {
  long long a[2][2], b[2][2];
  push(v, tl, tr, t);
  if (l > r) {
    ans[0][0] = 1;
    ans[0][1] = 0;
    ans[1][0] = 0;
    ans[1][1] = 1;
    return;
  }
  if (l == tl && r == tr) {
    if (f)
      memcpy(ans, t[v].sum1, sizeof(t[v].sum1));
    else
      memcpy(ans, t[v].sum, sizeof(t[v].sum));
    return;
  }
  int tm = (tl + tr) / 2;
  sum(v * 2, tl, tm, l, min(r, tm), t, a, f);
  sum(v * 2 + 1, tm + 1, tr, max(l, tm + 1), r, t, b, f);
  if (f)
    merge(b, a, ans);
  else
    merge(a, b, ans);
}

/* Assigns o1*x + o2 to chain positions [pos1, pos2]. */
void range_update(int v, int tl, int tr, int pos1, int pos2, long long o1, long long o2, tree *t) {
  push(v, tl, tr, t);
  if (pos2 < tl || pos1 > tr) return;
  int tm = (tl + tr) / 2;
  if (pos1 <= tl && pos2 >= tr) {
    t[v].offset1 = o1;
    t[v].offset2 = o2;
  } else {
    range_update(v * 2, tl, tm, pos1, pos2, o1, o2, t);
    range_update(v * 2 + 1, tm + 1, tr, pos1, pos2, o1, o2, t);
    push(v * 2, tl, tm, t);
    push(v * 2 + 1, tm + 1, tr, t);
    merge(t[v * 2].sum, t[v * 2 + 1].sum, t[v].sum);
    merge(t[v * 2 + 1].sum1, t[v * 2].sum1, t[v].sum1);
  }
}

/* ans = a * b (2x2, mod MOD). */
void merge(long long a[][2], long long b[][2], long long ans[][2]) {
  ans[0][0] = (a[0][0] * b[0][0] + a[0][1] * b[1][0]) % MOD;
  ans[0][1] = (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % MOD;
  ans[1][0] = (a[1][0] * b[0][0] + a[1][1] * b[1][0]) % MOD;
  ans[1][1] = (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % MOD;
}

/* Resolves node v's pending assignment: its sums become the map raised to the range length. */
void push(int v, int tl, int tr, tree *t) {
  long long a[2][2];
  if (t[v].offset1 == -1 || t[v].offset2 == -1) return;
  a[0][0] = t[v].offset1;
  a[0][1] = t[v].offset2;
  a[1][0] = 0;
  a[1][1] = 1;
  powm(&a[0][0], tr - tl + 1, &t[v].sum[0][0], 2);
  memcpy(t[v].sum1, t[v].sum, sizeof(t[v].sum));
  if (tl != tr) {
    t[v * 2].offset1 = t[v * 2 + 1].offset1 = t[v].offset1;
    t[v * 2].offset2 = t[v * 2 + 1].offset2 = t[v].offset2;
  }
  t[v].offset1 = t[v].offset2 = -1;
}

/* Assigns a*x + b to every vertex on the x-y path. */
void range_solve(int x, int y, int a, int b) {
  int ca = lca(x, y);
  while (node_chain[x] != node_chain[ca]) {
    range_update(1, 0, chain_len[node_chain[x]] - 1, 0, node_idx[x], a, b, chain[node_chain[x]]);
    x = DP[0][chain_head[node_chain[x]]];
  }
  range_update(1, 0, chain_len[node_chain[x]] - 1, node_idx[ca], node_idx[x], a, b,
               chain[node_chain[x]]);
  while (node_chain[y] != node_chain[ca]) {
    range_update(1, 0, chain_len[node_chain[y]] - 1, 0, node_idx[y], a, b, chain[node_chain[y]]);
    y = DP[0][chain_head[node_chain[y]]];
  }
  if (node_idx[y] != node_idx[ca])
    range_update(1, 0, chain_len[node_chain[y]] - 1, node_idx[ca] + 1, node_idx[y], a, b,
                 chain[node_chain[y]]);
}

int min(int x, int y) { return (x < y) ? x : y; }

int max(int x, int y) { return (x > y) ? x : y; }

/* Applies the maps along the x-y path (x's first) to a, modulo MOD. */
long long solve(int x, int y, int a) {
  int ca = lca(x, y);
  long long t1[2][2], t2[2][2] = {{1, 0}, {0, 1}}, t3[2][2], ans[2][2];
  /* x side: each higher chain is applied after everything collected so far. */
  while (node_chain[x] != node_chain[ca]) {
    sum(1, 0, chain_len[node_chain[x]] - 1, 0, node_idx[x], chain[node_chain[x]], t1, 0);
    memcpy(t3, t2, sizeof(t2));
    merge(t1, t3, t2);
    x = DP[0][chain_head[node_chain[x]]];
  }
  sum(1, 0, chain_len[node_chain[x]] - 1, node_idx[ca], node_idx[x], chain[node_chain[x]], t1, 0);
  memcpy(t3, t2, sizeof(t2));
  merge(t1, t3, ans);
  /* y side: each higher chain is applied before everything collected so far. */
  t2[0][0] = 1;
  t2[0][1] = 0;
  t2[1][0] = 0;
  t2[1][1] = 1;
  while (node_chain[y] != node_chain[ca]) {
    sum(1, 0, chain_len[node_chain[y]] - 1, 0, node_idx[y], chain[node_chain[y]], t1, 1);
    memcpy(t3, t2, sizeof(t2));
    merge(t3, t1, t2);
    y = DP[0][chain_head[node_chain[y]]];
  }
  if (node_idx[y] != node_idx[ca]) {
    sum(1, 0, chain_len[node_chain[y]] - 1, node_idx[ca] + 1, node_idx[y], chain[node_chain[y]],
        t1, 1);
    memcpy(t3, t2, sizeof(t2));
    merge(t3, t1, t2);
  }
  merge(t2, ans, t1);
  return (a * t1[0][0] + t1[0][1]) % MOD;
}

/* Sets the SIZE x SIZE matrix a to the identity. */
void one(long long *a, int SIZE) {
  int i, j;
  for (i = 0; i < SIZE; i++)
    for (j = 0; j < SIZE; j++) a[i * SIZE + j] = (i == j);
}

/* a = a * b (SIZE x SIZE, mod MOD). */
void mul(long long *a, long long *b, int SIZE) {
  int i, j, k;
  long long res[SIZE][SIZE];
  for (i = 0; i < SIZE; i++)
    for (j = 0; j < SIZE; j++) res[i][j] = 0;
  for (i = 0; i < SIZE; i++)
    for (j = 0; j < SIZE; j++)
      for (k = 0; k < SIZE; k++) res[i][j] = (res[i][j] + a[i * SIZE + k] * b[k * SIZE + j]) % MOD;
  for (i = 0; i < SIZE; i++)
    for (j = 0; j < SIZE; j++) a[i * SIZE + j] = res[i][j];
}

/* res = a^n (mod MOD); a is overwritten. */
void powm(long long *a, int n, long long *res, int SIZE) {
  one(res, SIZE);
  while (n > 0) {
    if (n % 2 == 0) {
      mul(a, a, SIZE);
      n /= 2;
    } else {
      mul(res, a, SIZE);
      n--;
    }
  }
}