Skip to content
Programming101
Programming101

Learn everything about programming

  • Home
  • CS Subjects
    • IoT – Internet of Things
    • Digital Communication
    • Human Values
  • Programming Tutorials
    • C Programming
    • Data structures and Algorithms
    • 100+ Java Programs
    • 100+ C Programs
  • HackerRank Solutions
    • HackerRank Algorithms Solutions
    • HackerRank C problems solutions
    • HackerRank C++ problems solutions
    • HackerRank Java problems solutions
    • HackerRank Python problems solutions
Programming101
Programming101

Learn everything about programming

HackerRank Fibonacci Numbers Tree problem solution

YASH PAL, 31 July 2024

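In this HackerRank Fibonacci Numbers Tree problem, we are given the configuration of a tree and a list of operations, and we must perform all of them efficiently. An update "U X K" adds F(K) to vertex X, F(K+1) to each of its children, F(K+2) to each grandchild, and so on through the subtree. A query "Q X Y" asks for the sum of the values on the path from X to Y, modulo 10^9+7.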

HackerRank Fibonacci Numbers Tree problem solution

Problem solution in Java Programming.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class H {
    InputStream is;
    PrintWriter out;
    /** Optional built-in input for local testing; when empty, standard input is read. */
    String INPUT = "";

    /**
     * Reads the tree and the operations from {@link #is} and prints one answer per query
     * to {@link #out}.
     *
     * <p>Input layout: {@code n Q}, then the 1-based parents of vertices 2..n, then Q
     * operations. {@code U x K} updates the subtree of x; any other letter is read as a
     * query {@code x y}.
     *
     * <p>Overview of the approach used below: the tree is split into heavy-light chains,
     * and every chain gets a {@link SegmentTreeMatrix} whose leaves are 4x4 matrices.
     * Applying the matrices of the chains on the root-to-x path, top to bottom, to the
     * vector {0, 0, 1, 0} leaves the sum of vertex values on that path in component 3
     * (see {@link #d}). A query x-y is then d(x) + d(y) - d(lca) - d(parent(lca)).
     */
    void solve()
    {
        int n = ni();
        int Q = ni();
        par = new int[n];
        for(int i = 1;i < n;i++){
            par[i] = ni()-1;
        }
        par[0] = -1;
        int[][] g = parentToG(par);
        int[][] pars = parents3(g, 0);
        int[] ord = pars[1], dep = pars[2];
        clus = decomposeToHeavyLight(g,par, ord);
        cluspath = clusPaths(clus, ord);
        clusiind = clusIInd(cluspath, n);
        int m = cluspath.length;
        sts = new SegmentTreeMatrix[m];
        for(int i = 0;i < m;i++){
            sts[i] = new SegmentTreeMatrix(cluspath[i].length);
        }
        int[][] spar = logstepParents(par);

        // Single source of truth for the modulus.
        final int mod = SegmentTreeMatrix.mod;
        for(int z = 0;z < Q;z++){
            char type = nc();
            if(type == 'U'){
                int x = ni()-1;
                long K = nl();
                sts[clus[x]].update(clusiind[x], K);
            }else{
                int x = ni()-1, y = ni()-1;
                int lca = lca2(x, y, spar, dep);
                // Root-prefix sums: the lca is counted twice and everything above it
                // twice, so remove d(lca) and d(parent(lca)) once each.
                long ret = d(x)+d(y)-d(lca)-d(par[lca]);
                ret %= mod;
                if(ret < 0)ret += mod;
                out.println(ret);
            }
        }
    }

    /** Parent of each vertex (0-based); -1 for the root, vertex 0. */
    int[] par;
    /** clus[v]: heavy-light chain id of v; clusiind[v]: v's position inside that chain. */
    int[] clus, clusiind;
    /** cluspath[c]: vertices of chain c, ordered from its top vertex downward. */
    int[][] cluspath;
    /** sts[c]: matrix segment tree of chain c, one leaf per chain position. */
    SegmentTreeMatrix[] sts;

    /**
     * Capacity of the chain stack in {@link #d}. A light child's subtree holds at most half
     * of its parent's subtree (see {@link #decomposeToHeavyLight}), so a root path crosses
     * at most about log2(n) + 1 chains, far below this bound for any int-sized tree.
     */
    private static final int MAX_CHAINS = 100;

    /**
     * Returns the sum of vertex values on the path from the root to {@code x} (mod 1e9+7),
     * or 0 when {@code x == -1}, which is what {@code par[root]} holds.
     */
    long d(int x)
    {
        if(x == -1)return 0;
        // Collect (chain, prefix length) pairs from x's chain up to the root's chain.
        int[] lcx = new int[MAX_CHAINS];
        int[] lto = new int[MAX_CHAINS];
        int p = 0;
        int cx = clus[x];
        int ind = clusiind[x];
        while (true) {
            lcx[p] = cx;
            lto[p] = ind+1;
            p++;
            int con = par[cluspath[cx][0]];
            if(con == -1)break;
            ind = clusiind[con];
            cx = clus[con];
        }
        // Components: (value at current vertex, companion Fibonacci term, constant 1, path sum).
        int[] v = {0, 0, 1, 0};

        // Apply chain prefixes from the root's chain downward.
        for(int i = p-1;i >= 0;i--){
            v = sts[lcx[i]].apply(0, lto[i], v);
        }
        return v[3];
    }

    /**
     * Segment tree over one chain whose leaves are 4x4 matrices mod {@link #mod}.
     *
     * <p>A leaf with accumulated update terms (u0, u1) maps (a, b, c, s) to
     * (a + b + u0*c, a + u1*c, c, s + a + b + u0*c). An internal node stores
     * {@code right * left}, so applying it equals applying the left range first.
     * Nodes with no leaf below them stay {@code null}.
     */
    public static class SegmentTreeMatrix {
        /** M: node array length; H: index of the first leaf; N: number of leaves. */
        public int M, H, N;
        public int[][][] node;
        public static int mod = 1000000007;
        /** Multiple of mod subtracted from long accumulators to prevent overflow. */
        public static long BIG = 8L*mod*mod;
        /** Matrix dimension. */
        public static int S = 4;

        /** Builds a tree with {@code n} leaves, none of which has received an update yet. */
        public SegmentTreeMatrix(int n)
        {
            N = n;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;

            node = new int[M][][];
            for(int i = 0;i < N;i++){
                node[H+i] = new int[S][S];
                node[H+i][0][0] = 1;
                node[H+i][0][1] = 1;
                node[H+i][1][0] = 1;
                node[H+i][3][1] = 1;
                node[H+i][2][2] = 1;
                node[H+i][3][3] = 1;
                node[H+i][3][0] = 1;
            }

            for(int i = H-1;i >= 1;i--)propagate(i);
        }

        /** Recomputes node {@code cur} from its two children. */
        private void propagate(int cur)
        {
            node[cur] = prop2(node[2*cur], node[2*cur+1], node[cur]);
        }

        /** Combines children: right * left if both exist, a copy of the single one otherwise. */
        private int[][] prop2(int[][] L, int[][] R, int[][] C)
        {
            if(L != null && R != null){
                C = mul(R, L, C, mod);
                return C;
            }else if(L != null){
                return prop1(L, C);
            }else if(R != null){
                return prop1(R, C);
            }else{
                return null;
            }
        }

        /** Copies {@code L} into {@code C} (allocated if null) so children are never aliased. */
        private int[][] prop1(int[][] L, int[][] C)
        {
            if(C == null)C = new int[S][];
            for(int i = 0;i < S;i++){
                C[i] = Arrays.copyOf(L[i], S);
            }
            return C;
        }

        /**
         * Adds (F(x), F(x-1)) = [[1,1],[1,0]]^(x-1) * (1, 0) to leaf {@code pos}'s update
         * column (rows 0 and 3 get F(x), row 1 gets F(x-1)), then refreshes its ancestors.
         */
        public void update(int pos, long x) {
            int[][] fib = {{1, 1}, {1, 0}};
            int[] v = {1, 0};
            v = pow(fib, v, x-1);
            node[H+pos][0][2] += v[0];
            if(node[H+pos][0][2] >= mod)node[H+pos][0][2] -= mod;
            node[H+pos][1][2] += v[1];
            if(node[H+pos][1][2] >= mod)node[H+pos][1][2] -= mod;
            node[H+pos][3][2] += v[0];
            if(node[H+pos][3][2] >= mod)node[H+pos][3][2] -= mod;
            for(int i = H+pos>>>1;i >= 1;i>>>=1)propagate(i);
        }

        /** Applies the leaves in the half-open range [l, r) to {@code v}, left to right. */
        public int[] apply(int l, int r, int[] v){
            return apply(l, r, 0, H, 1, v);
        }

        protected int[] apply(int l, int r, int cl, int cr, int cur, int[] v)
        {
            if(l <= cl && cr <= r){
                return mul(node[cur], v, mod);
            }else{
                int mid = cl+cr>>>1;
                if(cl < r && l < mid){
                    v = apply(l, r, cl, mid, 2*cur, v);
                }
                if(mid < r && l < cr){
                    v = apply(l, r, mid, cr, 2*cur+1, v);
                }
                return v;
            }
        }

        /** Returns A * v reduced modulo {@code mod}. */
        public static int[] mul(int[][] A, int[] v, int mod)
        {
            int m = A.length;
            int n = v.length;
            int[] w = new int[m];
            for(int i = 0;i < m;i++){
                long sum = 0;
                for(int k = 0;k < n;k++){
                    sum += (long)A[i][k] * v[k];
                    if(sum >= BIG)sum -= BIG;
                }
                w[i] = (int)(sum % mod);
            }
            return w;
        }

        /** Writes A * B (mod {@code mod}) into C, allocating C when null, and returns it. */
        public static int[][] mul(int[][] A, int[][] B, int[][] C, int mod)
        {
            int m = A.length;
            int n = A[0].length;
            int o = B[0].length;
            if(C == null)C = new int[m][o];
            for(int i = 0;i < m;i++){
                for(int j = 0;j < o;j++){
                    long sum = 0;
                    for(int k = 0;k < n;k++){
                        sum += (long)A[i][k] * B[k][j];
                        if(sum >= BIG)sum -= BIG;
                    }
                    sum %= mod;
                    C[i][j] = (int)sum;
                }
            }
            return C;
        }

        /** Returns A^e * v (mod {@link #mod}) by binary exponentiation; e <= 0 returns v. */
        public static int[] pow(int[][] A, int[] v, long e)
        {
            for(int i = 0;i < v.length;i++){
                if(v[i] >= mod)v[i] %= mod;
            }
            int[][] MUL = A;
            for(;e > 0;e>>>=1) {
                if((e&1)==1)v = mul(MUL, v);
                MUL = p2(MUL);
            }
            return v;
        }

        /** Returns A * v reduced modulo {@link #mod}. */
        public static int[] mul(int[][] A, int[] v)
        {
            int m = A.length;
            int n = v.length;
            int[] w = new int[m];
            for(int i = 0;i < m;i++){
                long sum = 0;
                for(int k = 0;k < n;k++){
                    sum += (long)A[i][k] * v[k];
                    if(sum >= BIG)sum -= BIG;
                }
                w[i] = (int)(sum % mod);
            }
            return w;
        }

        /** Returns A^2 modulo {@link #mod}; entries are assumed non-negative. */
        public static int[][] p2(int[][] A)
        {
            int n = A.length;
            int[][] C = new int[n][n];
            for(int i = 0;i < n;i++){
                long[] sum = new long[n];
                for(int k = 0;k < n;k++){
                    for(int j = 0;j < n;j++){
                        sum[j] += (long)A[i][k] * A[k][j];
                        if(sum[j] >= BIG)sum[j] -= BIG;
                    }
                }
                for(int j = 0;j < n;j++){
                    C[i][j] = (int)(sum[j] % mod);
                }
            }
            return C;
        }

    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}, given
     * {@code spar[k][v]} = 2^k-th ancestor of v (or -1) and the vertex depths.
     */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if (a == b)
            return a;
        int sa = a, sb = b;
        // Binary search over the depth of the answer, jumping while ancestors differ.
        for (int low = 0, high = depth[a], t = Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the m-th ancestor of {@code a}, or -1 when it does not exist. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /**
     * Builds the binary-lifting table: {@code pars[j][i]} is the 2^j-th ancestor of i, or -1.
     * Row 0 is {@code par} itself (not copied).
     */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        // Clamp to 1: for n <= 1, highestOneBit(0) == 0 and numberOfTrailingZeros(0) == 32,
        // which would otherwise allocate 33 useless levels.
        int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(Math.max(n - 1, 1))) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /**
     * Assigns every vertex a chain id. Each vertex's chain continues into its first child
     * whose subtree size is at least half its own (integer division), or into its first
     * child when none qualifies. {@code ord} must list parents before children.
     */
    public static int[] decomposeToHeavyLight(int[][] g, int[] par, int[] ord) {
        int n = g.length;
        int[] size = new int[n];
        Arrays.fill(size, 1);
        for (int i = n - 1; i > 0; i--)
            size[par[ord[i]]] += size[ord[i]];

        int[] clus = new int[n];
        Arrays.fill(clus, -1);
        int p = 0;
        outer: for (int i = 0; i < n; i++) {
            int u = ord[i];
            if (clus[u] == -1)
                clus[u] = p++;
            for (int v : g[u]) {
                if (par[u] != v && size[v] >= size[u] / 2) {
                    clus[v] = clus[u];
                    continue outer;
                }
            }
            for (int v : g[u]) {
                if (par[u] != v) {
                    clus[v] = clus[u];
                    break;
                }
            }
        }
        return clus;
    }

    /** Groups vertices by chain id, each chain listed in the order they appear in {@code ord}. */
    public static int[][] clusPaths(int[] clus, int[] ord) {
        int n = clus.length;
        int[] rp = new int[n];
        int sup = 0;
        for (int i = 0; i < n; i++) {
            rp[clus[i]]++;
            sup = Math.max(sup, clus[i]);
        }
        sup++;

        int[][] row = new int[sup][];
        for (int i = 0; i < sup; i++)
            row[i] = new int[rp[i]];

        for (int i = n - 1; i >= 0; i--) {
            row[clus[ord[i]]][--rp[clus[ord[i]]]] = ord[i];
        }
        return row;
    }

    /** Returns each vertex's index within its chain. */
    public static int[] clusIInd(int[][] clusPath, int n) {
        int[] iind = new int[n];
        for (int[] path : clusPath) {
            for (int i = 0; i < path.length; i++) {
                iind[path[i]] = i;
            }
        }
        return iind;
    }

    /** Builds undirected adjacency lists from a parent array (negative entries mark roots). */
    public static int[][] parentToG(int[] par)
    {
        int n = par.length;
        int[] ct = new int[n];
        for(int i = 0;i < n;i++){
            if(par[i] >= 0){
                ct[i]++;
                ct[par[i]]++;
            }
        }
        int[][] g = new int[n][];
        for(int i = 0;i < n;i++){
            g[i] = new int[ct[i]];
        }
        for(int i = 0;i < n;i++){
            if(par[i] >= 0){
                g[par[i]][--ct[par[i]]] = i;
                g[i][--ct[i]] = par[i];
            }
        }
        return g;
    }

    /** BFS from {@code root}; returns {parent, BFS order, depth}. */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Builds undirected adjacency lists from parallel edge arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /** Wires up I/O (stdin unless {@link #INPUT} is set), runs {@link #solve}, flushes. */
    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }

    public static void main(String[] args) throws Exception { new H().run(); }

    private byte[] inbuf = new byte[1024];
    private int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, or -1 at end of input. */
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char)skip(); }

    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) decimal int. */
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long. */
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}

Problem solution in C++ programming.

#include <cstdio>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <climits>
#include <cctype>
#include <utility>
#include <queue>
#include <cmath>
#include <complex>
using namespace std;

// ---- Numeric, pair and vector type aliases ----
typedef long long LL;
typedef double DB;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
typedef pair<int, LL> PIL;
typedef pair<LL, int> PLI;
typedef vector<int> VI;
typedef vector<PII> VPII;

// ---- Container shorthands ----
#define pb push_back
#define all(x) (x).begin(), (x).end()
#define sz(x) ((int)((x).size()))
#define mset(a, b) memset((a), (b), sizeof (a))

// ---- Bit helpers (GCC/Clang builtins) ----
#define bit(x) (1 << (x))
#define bitl(x) (1LL << (x))
#define cnti(x) (__builtin_popcount(x))
#define cntl(x) (__builtin_popcountll(x))
#define clzi(x) (__builtin_clz(x))
#define clzl(x) (__builtin_clzll(x))
#define ctzi(x) (__builtin_ctz(x))
#define ctzl(x) (__builtin_ctzll(x))

// ---- Arithmetic, pair member access, debugging ----
#define sqr(x) ((x) * (x))
#define X first
#define Y second
#define Error(x) cout << #x << " = " << x << endl

// Raises `best` to `candidate` when the candidate compares greater (uses only operator<).
template <typename T, typename U>
inline void chkmax(T& best, U candidate) {
    if (!(best < candidate)) return;
    best = candidate;
}

// Lowers `best` to `candidate` when the candidate compares smaller (uses only operator<).
template <typename T, typename U>
inline void chkmin(T& best, U candidate) {
    if (!(candidate < best)) return;
    best = candidate;
}

// Modulus applied to every stored coefficient and printed answer.
const int MOD = 1e9 + 7;
// Capacity of the per-vertex arrays; vertices are 1-based, so n must be below MAXN.
const int MAXN = 111111;

// n: vertex count, m: operation count (both read in main).
int n, m;

// adj[u]: children of u, filled from the parent list in main.
VI adj[MAXN];
// f[i][u]: 2^i-th ancestor of u; 0 above the root (vertex 1, whose f[0] is never set).
int f[17][MAXN];
// d[u]: depth (main sets d[1] = 1); st[u]/en[u]: Euler-tour entry time of u and the
// largest entry time inside u's subtree (both assigned by dfs).
int d[MAXN], st[MAXN], en[MAXN];
// Three Fenwick trees indexed by Euler time, written by change() and read by query().
int b[3][MAXN];
// c[e] = [[1,1],[1,0]]^(2^e) modulo MOD; main fills e = 0..50, the rest stays unused.
int c[55][2][2];

// Euler-tour clock advanced by dfs().
int T;

void dfs(int u) {
    st[u] = ++T;
    for (int i = 0; i < sz(adj[u]); i++) {
        d[adj[u][i]] = d[u] + 1;
        dfs(adj[u][i]);
    }
    en[u] = T;
}

// Fenwick point update: adds x (mod MOD) at position p of tree `id`, positions 1..n.
void add(int p, int id, int x) {
    while (p <= n) {
        int &cell = b[id][p];
        cell = (cell + x) % MOD;
        p += p & (-p);
    }
}

// Fenwick prefix query: sum (mod MOD) of positions 1..p in tree `id`.
int get_sum(int p, int id) {
    int acc = 0;
    while (p != 0) {
        acc = (acc + b[id][p]) % MOD;
        p &= p - 1;  // drop the lowest set bit
    }
    return acc;
}

// Returns [[1,1],[1,0]]^e * (1, 0) modulo MOD, i.e. (F(e+1), F(e)), using the
// precomputed squares in c[]. Only bits 0..50 of e are inspected.
PII get(LL e) {
    int vec[2] = {1, 0};
    for (int bitIdx = 50; bitIdx >= 0; --bitIdx) {
        if (((e >> bitIdx) & 1) == 0) continue;
        const int (*mat)[2] = c[bitIdx];
        const LL first = ((LL)vec[0] * mat[0][0] + (LL)vec[1] * mat[0][1]) % MOD;
        const LL second = ((LL)vec[0] * mat[1][0] + (LL)vec[1] * mat[1][1]) % MOD;
        vec[0] = (int)first;
        vec[1] = (int)second;
    }
    return make_pair(vec[0], vec[1]);
}

// Lowest common ancestor via binary lifting (f[i][x] = 2^i-th ancestor, 0 above the root,
// and d[0] = 0 is below every real depth).
inline int lca(int u, int v) {
  if (d[u] < d[v]) swap(u, v);  // u is now the deeper endpoint
  // Lift u onto v's level using the largest jumps that do not go above it.
  for (int lvl = 16; lvl >= 0; --lvl) {
    const int up = f[lvl][u];
    if (d[up] >= d[v]) {
      u = up;
      if (d[u] == d[v]) break;
    }
  }
  if (u == v) return u;
  // Raise both while their ancestors still differ; the answer is then one step up.
  for (int lvl = 16; lvl >= 0; --lvl) {
    if (f[lvl][u] != f[lvl][v]) {
      u = f[lvl][u];
      v = f[lvl][v];
    }
  }
  return f[0][u];
}

void change(int u, LL k) {
    k += 2*(MOD+1);
    PII x = get(k), y = get(k - d[u] + 1);
    add(st[u], 0, MOD-x.first);
    add(st[u], 1, y.first);
    add(st[u], 2, y.second);
    add(en[u]+1, 0, x.first);
    add(en[u]+1, 1, MOD-y.first);
    add(en[u]+1, 2, MOD-y.second);
}

int query(int u) {
    if (!u) {
        return 0;
    }
    PII x = get(d[u]);
    LL s0 = get_sum(st[u], 0), s1 = get_sum(st[u], 1), s2 = get_sum(st[u], 2);
    return (s0 + s1 * x.first + s2 * x.second) % MOD;
}

// Reads the tree and operations and prints one line per query.
int main() {
    // c[0] = [[1,1],[1,0]]; c[e] = c[e-1]^2 (mod MOD), i.e. the Fibonacci matrix to 2^e.
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            c[0][i][j] = !(i & j);
        }
    }
    for (int e = 1; e <= 50; e++) {
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                LL s = 0;
                for (int k = 0; k < 2; k++) {
                    s += (LL)c[e-1][i][k] * c[e-1][k][j];
                }
                c[e][i][j] = s % MOD;
            }
        }
    }
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 2; i <= n; i++) {
        int x; scanf("%d", &x);
        adj[x].push_back(i), f[0][i] = x;
    }
    d[1] = 1;
    dfs(1);
    for (int i = 1; i <= 16; i++) {
        for (int j = 1; j <= n; j++) {
            f[i][j] = f[i-1][f[i-1][j]];
        }
    }
    memset(b, 0, sizeof b);
    for (int i = 0; i < m; i++) {
        char t[5]; int u;
        // %4s bounds the token to t's capacity (4 chars + terminator).
        scanf("%4s%d", t, &u);
        if (t[0] == 'U') {
            LL k; scanf("%lld", &k);
            change(u, k);
        } else {
            int v; scanf("%d", &v);
            int w = lca(u, v), p = f[0][w], ans = 0;
            // Path sum = prefix(u) + prefix(v) - prefix(lca) - prefix(parent(lca)).
            ans = (query(u)+query(v)) % MOD;
            ans = (ans+MOD-query(w)) % MOD;
            ans = (ans+MOD-query(p)) % MOD;
            printf("%d\n", ans);
        }
    }
    return 0;
}

Problem solution in C programming.

#include <stdio.h>
#include <string.h>
#include <stdlib.h>

#define prime_base 1000000007
/* 2 * (prime_base + 1). prime_base is a prime congruent to 2 mod 5, so the Fibonacci
   sequence modulo prime_base repeats with a period dividing this value.
   NOTE(review): the L suffix assumes a 64-bit long (LP64). */
#define fib_cycle_mag 2000000016L

/* Returns self reduced to the canonical range [0, prime_base). */
static inline int mod_prime_base(int self) {
    int rem = self % prime_base;
    return rem < 0 ? rem + prime_base : rem;
}

/* Returns self reduced to the canonical range [0, prime_base). */
static inline int mod_long_prime(long self) {
    long rem = self % (long)prime_base;
    return (int)(rem < 0 ? rem + prime_base : rem);
}

/* Returns self reduced to the canonical range [0, fib_cycle_mag). */
static inline unsigned mod_fib_cycle(long self) {
    long rem = self % fib_cycle_mag;
    return (unsigned)(rem < 0 ? rem + fib_cycle_mag : rem);
}


/* cached_fib[i] = F(i) mod prime_base for every table index, with F(0) = 0 and F(1) = F(2) = 1. */
unsigned cached_fib[1 << 20] = {0};
#define array_cnt(self) (sizeof(self)/sizeof((self)[0]))

/* Fills cached_fib[1..] with consecutive Fibonacci numbers. cached_fib[0] keeps its zero
   initializer. */
void init_cached_fib() {
    unsigned idx;
    for (idx = 1; idx < array_cnt(cached_fib); ++idx)
        cached_fib[idx] = (idx <= 2U)
            ? 1U
            : (cached_fib[idx - 1] + cached_fib[idx - 2]) % prime_base;
}

/* Stores F(n) and F(n + 1) (mod prime_base) in *fn and *fn1 using fast doubling.
   Recursion stops as soon as both values are in cached_fib. One recursive call per level
   gives O(log n), whereas the previous two-call recursion doubled the work per level.
   64-bit arithmetic is explicit: products of two residues do not fit in 32 bits. */
static void fib_pair(unsigned n, unsigned long long *fn, unsigned long long *fn1) {
    if (n < array_cnt(cached_fib) - 1) {
        *fn = cached_fib[n];
        *fn1 = cached_fib[n + 1];
        return;
    }
    unsigned long long a, b;  /* F(k), F(k + 1) with k = n / 2 */
    fib_pair(n >> 1, &a, &b);
    /* F(2k) = F(k) * (2F(k+1) - F(k)); adding prime_base keeps the bracket non-negative. */
    unsigned long long even = a * ((2 * b + prime_base - a) % prime_base) % prime_base;
    /* F(2k+1) = F(k)^2 + F(k+1)^2 */
    unsigned long long odd = (a * a + b * b) % prime_base;
    if (n & 1U) {
        *fn = odd;
        *fn1 = (even + odd) % prime_base;
    } else {
        *fn = even;
        *fn1 = odd;
    }
}

/* Returns F(self) mod prime_base (table lookup when self is within cached_fib). */
unsigned fib(unsigned self) {
    if (self < array_cnt(cached_fib))
        return cached_fib[self];

    unsigned long long value, next;
    fib_pair(self, &value, &next);
    return (unsigned)value;
}

typedef int sum_t;

/* Fenwick prefix query: returns self[1..node] summed through mod_prime_base. */
static inline sum_t bit_sum(unsigned length, sum_t self[length], unsigned node) {
    sum_t acc = 0;
    while (node != 0) {
        acc = mod_prime_base(acc + self[node]);
        node -= node & (0u - node);  /* strip the lowest set bit */
    }
    return acc;
}

/* Fenwick point update: adds delta at `node` and every covering slot below `length`. */
static inline void bit_update(unsigned length, sum_t self[length], unsigned node, sum_t delta) {
    while (node < length) {
        self[node] = mod_prime_base(self[node] + delta);
        node += node & (0u - node);  /* next covering slot */
    }
}

/* Range add on a difference-encoded Fenwick tree: +delta at low and -delta at high + 1,
   so point reads (prefix sums) in [low, high] see delta. */
static inline void bit_point_range_update(unsigned length, sum_t sums[length], unsigned low, unsigned high, sum_t delta) {
    const sum_t undo = mod_prime_base(-delta);
    bit_update(length, sums, low, delta);
    bit_update(length, sums, high + 1, undo);
}

/* floor(log2(self)) for self >= 1 (returns 0 for 0), via a De Bruijn multiply-and-lookup. */
static inline unsigned log2_floor(unsigned self) {
    static const unsigned table[32] = {
        0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
        8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31
    };
    unsigned shift;

    /* Smear the highest set bit downward so self becomes 2^(k+1) - 1. */
    for (shift = 1; shift <= 16; shift <<= 1)
        self |= self >> shift;

    return table[(self * 0x07C4ACDDU) >> 27];
}

/*
 * Returns the lowest common ancestor of source and target, using the binary-lifting
 * table ancestors[k][v] (the 2^k-th ancestor of v) and the per-vertex depths.
 *
 * vertex_cnt is the row length of ancestors, and log2_dists is its number of rows.
 * main passes the vertex count + 1 because vertices are 1-based.
 */
unsigned nearest_common_ancestor(
    unsigned vertex_cnt,
    unsigned log2_dists,
    unsigned ancestors[log2_dists][vertex_cnt],
    unsigned *depths,
    unsigned source,
    unsigned target
) {
    /* XOR swap so that source is the deeper (or equally deep) vertex. */
    if (depths[source] < depths[target]) {
        source ^= target;
        target ^= source;
        source ^= target;
    }

    /* Lift source to target's depth; each step jumps by the largest power of two that does
       not overshoot. */
    for (; depths[target] < depths[source]; source = ancestors[log2_floor(depths[source] - depths[target])][source]);

    if (source == target)
        return source;

    /* Invariant: the answer is strictly above source/target, at depth >= low.
       Jump both while their ancestors differ; otherwise that shared ancestor's depth
       becomes the new lower bound. */
    unsigned low = 0;
    while ((depths[source] - low) > 1) {
        unsigned mid = log2_floor((depths[source] - low) >> 1);

        if (ancestors[mid][source] != ancestors[mid][target]) {
            source = ancestors[mid][source];
            target = ancestors[mid][target];
        } else
            low = depths[ancestors[mid][source]];
    }

    /* The loop ends with the answer exactly one level above source. */
    return ancestors[0][source];
}


/* Root-to-`at` prefix value: combines the three Fenwick prefix sums at ids[at] with
   F(depth) and F(depth + 1) from cached_fib, reduced by mod_long_prime. */
static inline sum_t tree_path_sum(unsigned length, sum_t sums[3][length], unsigned *depths, unsigned *ids, unsigned at) {
    const unsigned depth = depths[at];
    const unsigned slot = ids[at];
    const long term0 = (long)cached_fib[depth] * bit_sum(length, sums[0], slot);
    const long term1 = (long)cached_fib[depth + 1] * bit_sum(length, sums[1], slot);
    const long term2 = bit_sum(length, sums[2], slot);
    return mod_long_prime(term0 + term1 - term2);
}

int main() {
    init_cached_fib();

    unsigned at, vertex_cnt, queries;
    scanf("%u %u", &vertex_cnt, &queries);

    unsigned
        log2_dists = log2_floor(vertex_cnt) + 1,
        ancestors[log2_dists][vertex_cnt + 1],
        *descendants = memset(
            malloc(5 * (vertex_cnt + 1) * sizeof(unsigned)), 0, sizeof(unsigned) * ((vertex_cnt + 1) << 1)
        ),
        *indices = &descendants[vertex_cnt + 1],
        *weights = &indices[vertex_cnt + 1],
        *ids     = &weights[vertex_cnt + 1],
        *depths  = &ids[vertex_cnt + 1];


    ancestors[0][1] = 0;
    indices[0] = 1;
    for (at = 2; at <= vertex_cnt; indices[ancestors[0][at++]]++)
        scanf("%u", &ancestors[0][at]);

    for (at = 0; at < log2_dists; ancestors[at++][0] = 0);
    for (at = 0; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
    for (at = vertex_cnt + 1; --at; descendants[--indices[ancestors[0][at]]] = at);

    {
        unsigned history[vertex_cnt + 1];
        memset(weights, 0, sizeof(weights[0]) * sizeof(unsigned));

        depths[0] = 0;
        history[0] = 0;
        for (at = 1; at;) {
            unsigned
                root = history[at - 1],
                others = indices[root];

            if (weights[root]) {
                for (; ancestors[0][descendants[others]] == root; weights[root] += weights[descendants[others++]]);
                at--;
            } else
                for (weights[root] = 1; ancestors[0][descendants[others]] == root; 
                     history[at++] = descendants[others++])
                    depths[descendants[others]] = depths[root] + 1;
        }

        ids[0] = 0;
        history[0] = 0;
        for (at = 1; at;) {
            unsigned
                weight = 1,
                root = history[--at],
                others = indices[root];

            for (; ancestors[0][descendants[others]] == root; weight += weights[descendants[others++]]) {
                history[at++] = descendants[others];
                ids[descendants[others]] = ids[root] + weight;
            }
        }
    }

    unsigned level;
    for (level = 1; level < log2_dists; level++)
        for (at = 0; ++at <= vertex_cnt;
             ancestors[level][at] = ancestors[level - 1][ancestors[level - 1][at]]);

    sum_t sums[3][++vertex_cnt];
    memset(sums, 0, sizeof(sums));

    long kth_fib;
    for (getchar(); queries--; getchar())
        if (getchar() == 'U') {
            scanf("%u %ld", &at, &kth_fib);

            long kth_fibs[2] = {
                fib(mod_fib_cycle(kth_fib)),
                fib(mod_fib_cycle(kth_fib + 1))
            }, *low_fibs = &(((long [3]){
                [2] = fib(depths[at]),
                [1] = fib(depths[at] - 1),
                [0] = fib(((depths[at] > 1) ? (depths[at] - 2) : 1))
            })[2]);

            #define odd_mask(self) (((int)((self) << 31) >> 30) | 1)

            bit_point_range_update(
                vertex_cnt,
                sums[0],
                ids[at],
                ids[at] + weights[at] - 1,
                mod_long_prime(odd_mask(depths[at]) * (kth_fibs[1] * low_fibs[-1UL] - kth_fibs[0] * low_fibs[0]))
            );

            bit_point_range_update(
                vertex_cnt,
                sums[1],
                ids[at],
                ids[at] + weights[at] - 1,
                mod_long_prime(odd_mask(depths[at]) * (kth_fibs[0] * low_fibs[-1UL] - kth_fibs[1] * low_fibs[-2UL]))
            );

            bit_point_range_update(
                vertex_cnt,
                sums[2],
                ids[at],
                ids[at] + weights[at] - 1,
                (unsigned)kth_fibs[1]
            );

        } else {
            unsigned source, target, common;
            scanf("%u %u", &source, &target);

            common = nearest_common_ancestor(vertex_cnt, log2_dists, ancestors, depths, source, target);
            printf("%dn",
                mod_long_prime(
                    (long)tree_path_sum(vertex_cnt, sums, depths, ids, source)
                        + tree_path_sum(vertex_cnt, sums, depths, ids, target)
                        - tree_path_sum(vertex_cnt, sums, depths, ids, common)
                        - tree_path_sum(vertex_cnt, sums, depths, ids, ancestors[0][common])
                )
            );
        }
    
    free(descendants);

    return 0;
}

coding problems data structure

Post navigation

Previous post
Next post
  • HackerRank Separate the Numbers solution
  • How AI Is Revolutionizing Personalized Learning in Schools
  • GTA 5 is the Game of the Year for 2024 and 2025
  • Hackerrank Day 5 loops 30 days of code solution
  • Hackerrank Day 6 Lets Review 30 days of code solution
How to download udemy paid courses for free

Pages

  • About US
  • Contact US
  • Privacy Policy

Programing Practice

  • C Programs
  • java Programs

HackerRank Solutions

  • C
  • C++
  • Java
  • Python
  • Algorithm

Other

  • Leetcode Solutions
  • Interview Preparation

Programming Tutorials

  • DSA
  • C

CS Subjects

  • Digital Communication
  • Human Values
  • Internet Of Things
©2025 Programming101 | WordPress Theme by SuperbThemes