HackerRank Number Game on a Tree problem solution

In this HackerRank Number Game on a Tree problem, we are given the number of games and, for each game, a tree described by its number of nodes and its weighted edges; we need to find the number of unordered pairs of nodes we can choose to construct a list that wins the game.

HackerRank Number Game on a Tree problem solution

Problem solution in Java programming.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Random;

/**
 * Solution for HackerRank "Number Game on a Tree".
 *
 * <p>For each test case the edge weights are coordinate-compressed and every compressed weight gets
 * a random 64-bit key. Every node is labelled with the XOR of the keys on its path from the root.
 * Two nodes receive equal labels exactly when (with high probability) every weight occurs an even
 * number of times on the path between them. Those pairs are subtracted from the total number of
 * unordered pairs.
 */
public class D2 {
    /** Source of input bytes: standard input, or {@link #INPUT} when it is non-empty. */
    InputStream is;
    /** Output sink; flushed once at the end of {@link #run()}. */
    PrintWriter out;
    /** Inline input for local testing; when non-empty it replaces standard input. */
    String INPUT = "";

    Random gen = new Random();
    /** One random 64-bit key per compressed edge weight; grown on demand by {@link #ensureKeys}. */
    long[] zh = gen.longs(500005).toArray();

    /** Reads every test case and prints one answer per line. */
    void solve() {
        for (int T = ni(); T > 0; T--) {
            int n = ni();
            int[] from = new int[n - 1];
            int[] to = new int[n - 1];
            int[] w = new int[n - 1];
            for (int i = 0; i < n - 1; i++) {
                from[i] = ni() - 1; // input nodes are 1-based
                to[i] = ni() - 1;
                w[i] = ni();
            }
            // Compressed weights lie in [0, n - 2]; make sure each of them has a key.
            w = shrink(w);
            ensureKeys(n - 1);

            int[][][] g = packWU(n, from, to, w);
            int[][] pars = parents(g, 0);
            int[] par = pars[0], ord = pars[1], pw = pars[4];

            // dp[v] = XOR of the keys of the edge weights on the path root -> v. The shared
            // root -> lca prefix cancels, so dp[u] ^ dp[v] depends only on the u-v path.
            long[] dp = new long[n];
            for (int i = 1; i < n; i++) {
                int cur = ord[i]; // BFS order guarantees dp[par[cur]] is already final
                dp[cur] = dp[par[cur]] ^ zh[pw[cur]];
            }

            // Equal labels form runs after sorting; remove every pair inside a run from the total.
            Arrays.sort(dp);
            long ret = (long) n * (n - 1) / 2;
            for (int i = 0; i < n; ) {
                int j = i;
                while (j < n && dp[i] == dp[j]) {
                    j++;
                }
                ret -= (long) (j - i) * (j - i - 1) / 2;
                i = j;
            }
            out.println(ret);
        }
    }

    /** Grows {@link #zh} with fresh random keys so that it holds at least {@code size} entries. */
    private void ensureKeys(int size) {
        if (size <= zh.length) {
            return;
        }
        int old = zh.length;
        zh = Arrays.copyOf(zh, size);
        for (int i = old; i < size; i++) {
            zh[i] = gen.nextLong();
        }
    }

    /**
     * Coordinate-compresses {@code a}: each value is replaced by its rank among the distinct values,
     * so equal values map to equal ranks in {@code [0, distinct - 1]}.
     *
     * @param a values to compress (not modified)
     * @return a new array of ranks, index-aligned with {@code a}
     */
    public static int[] shrink(int[] a) {
        int n = a.length;
        long[] b = new long[n];
        for (int i = 0; i < n; i++) {
            b[i] = (long) a[i] << 32 | i; // value in the high half, original index in the low half
        }
        Arrays.sort(b);
        int[] ret = new int[n];
        int p = 0;
        for (int i = 0; i < n; i++) {
            if (i > 0 && (b[i] ^ b[i - 1]) >> 32 != 0) {
                p++; // high halves differ: a new distinct value starts
            }
            ret[(int) b[i]] = p;
        }
        return ret;
    }

    /**
     * Breadth-first traversal of a weighted tree from {@code root}.
     *
     * @param g adjacency lists as built by {@link #packWU}: {@code g[v][k] = {neighbour, weight}}
     * @param root start vertex
     * @return {@code {par, ord, dep, dw, pw}}: parent (-1 for the root), BFS visiting order, depth in
     *     edges, sum of weights from the root, and weight of the edge to the parent
     */
    public static int[][] parents(int[][][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);
        int[] dw = new int[n];
        int[] pw = new int[n];
        int[] dep = new int[n];

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int[] nex : g[cur]) {
                if (par[cur] != nex[0]) { // skip the edge back to the parent
                    q[r++] = nex[0];
                    par[nex[0]] = cur;
                    dep[nex[0]] = dep[cur] + 1;
                    dw[nex[0]] = dw[cur] + nex[1];
                    pw[nex[0]] = nex[1];
                }
            }
        }
        return new int[][] {par, q, dep, dw, pw};
    }

    /**
     * Builds undirected weighted adjacency lists: edge {@code i} is stored in both endpoint lists as
     * {@code {otherEndpoint, w[i]}}.
     */
    public static int[][][] packWU(int n, int[] from, int[] to, int[] w) {
        int[][][] g = new int[n][][];
        int[] p = new int[n];
        for (int f : from) {
            p[f]++;
        }
        for (int t : to) {
            p[t]++;
        }
        for (int i = 0; i < n; i++) {
            g[i] = new int[p[i]][2];
        }
        // p now counts down, giving each edge the next free slot from the end of each list.
        for (int i = 0; i < from.length; i++) {
            --p[from[i]];
            g[from[i]][p[from[i]]][0] = to[i];
            g[from[i]][p[from[i]]][1] = w[i];
            --p[to[i]];
            g[to[i]][p[to[i]]][0] = from[i];
            g[to[i]][p[to[i]]][1] = w[i];
        }
        return g;
    }

    /** Wires input/output, solves, and (for inline input only) prints the elapsed time. */
    void run() throws Exception {
        is = INPUT.isEmpty()
                ? System.in
                : new ByteArrayInputStream(INPUT.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if (!INPUT.isEmpty()) {
            tr(System.currentTimeMillis() - s + "ms");
        }
    }

    public static void main(String[] args) throws Exception {
        new D2().run();
    }

    private byte[] inbuf = new byte[1024];
    /** Number of valid bytes in {@code inbuf} (-1 once the stream is exhausted) and read cursor. */
    public int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte as an unsigned value 0..255, or -1 when the stream is exhausted. */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                // InputMismatchException has no cause constructor; attach the I/O failure explicitly.
                InputMismatchException ex = new InputMismatchException();
                ex.initCause(e);
                throw ex;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        // Mask to unsigned: a raw 0xFF byte would otherwise read as -1 and look like end of input.
        return inbuf[ptrbuf++] & 0xFF;
    }

    /** Returns {@code true} for anything outside printable ASCII ({@code '!'}..{@code '~'}). */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separator bytes; returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    /** Reads the next whitespace-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal int; returns 0 at end of input. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip non-numeric bytes
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long; returns 0 at end of input. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip non-numeric bytes
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) {
        System.out.println(Arrays.deepToString(o));
    }
}

Problem solution in C++ programming.

#include <bits/stdc++.h>

using namespace std ;

// Competitive-programming shorthand macros.
// Arguments are parenthesized so that expression arguments (e.g. ASST(x & 1, 0, 0)) keep their
// intended precedence after expansion.
#define ft first
#define sd second
#define pb push_back
#define all(x) (x).begin(), (x).end()

#define ll long long int
#define vi vector<int>
#define vii vector<pair<int,int> >
#define pii pair<int,int>
#define plii pair<pair<ll, int>, int>
#define piii pair<pii, int>
#define viii vector<pair<pii, int> >
#define vl vector<ll>
#define vll vector<pair<ll,ll> >
#define pll pair<ll,ll>
#define pli pair<ll,int>
#define mp make_pair
#define ms(x, v) memset((x), (v), sizeof(x))

#define sc1(x) scanf("%d", &(x))
#define sc2(x, y) scanf("%d%d", &(x), &(y))
#define sc3(x, y, z) scanf("%d%d%d", &(x), &(y), &(z))

#define scll1(x) scanf("%lld", &(x))
#define scll2(x, y) scanf("%lld%lld", &(x), &(y))
#define scll3(x, y, z) scanf("%lld%lld%lld", &(x), &(y), &(z))

// Each print macro ends its line with a newline.
#define pr1(x) printf("%d\n", (x))
#define pr2(x, y) printf("%d %d\n", (x), (y))
#define pr3(x, y, z) printf("%d %d %d\n", (x), (y), (z))

#define prll1(x) printf("%lld\n", (x))
#define prll2(x, y) printf("%lld %lld\n", (x), (y))
#define prll3(x, y, z) printf("%lld %lld %lld\n", (x), (y), (z))

#define pr_vec(v) for (int i = 0; i < (int)(v).size(); i++) cout << (v)[i] << " ";

#define f_in(st) freopen((st), "r", stdin)
#define f_out(st) freopen((st), "w", stdout)

// Inclusive ascending / descending loops over an already-declared index variable.
#define fr(i, a, b) for ((i) = (a); (i) <= (b); (i)++)
#define fb(i, a, b) for ((i) = (a); (i) >= (b); (i)--)
#define ASST(x, l, r) assert((x) <= (r) && (x) >= (l))

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

const int mod = 1e9 + 7;

// Returns (a + b) mod m for a, b already reduced into [0, m).
// The sum is formed in 64 bits so that moduli above INT_MAX / 2 cannot overflow int.
int ADD(int a, int b, int m = mod) {
    long long s = 1LL * a + b;
    if (s >= m)
        s -= m;
    return static_cast<int>(s);
}

// Returns (a * b) % m using a 64-bit intermediate product (C++ % semantics for negative products).
int MUL(int a, int b, int m = mod) {
    return static_cast<int>(1LL * a * b % m);
}

// Returns a^b mod m in [0, m) by binary exponentiation. Requires b >= 0 and m >= 1.
// a is normalized into [0, m) first, and the initial 1 is reduced mod m so that m == 1 yields 0.
int power(int a, int b, int m = mod) {
    long long base = a % m;
    if (base < 0)
        base += m;
    long long res = 1 % m;
    while (b > 0) {
        if (b & 1)
            res = res * base % m;
        base = base * base % m;
        b >>= 1;
    }
    return static_cast<int>(res);
}

// Number of unordered pairs that can be formed from x items: x * (x - 1) / 2.
long long nC2(long long x) {
    return x * (x - 1) / 2;
}

// Capacity of the per-node / per-weight arrays (5 * 1e5 + 5 == 500005).
const int maxn = 500005;

int t;          // number of test cases
int n;          // node count of the current tree
int vis[maxn];  // vis[id]: parity (0/1) of compressed weight id on the current DFS path
int cnt;        // next unused compressed weight id

std::map<int, int> M;                        // edge weight -> compressed id
std::vector<std::pair<int, int>> adj[maxn];  // adj[u]: (neighbour, edge weight) pairs

int prime1 = 23;
int prime2 = 7;
int base[2][maxn];  // base[0][i] = prime1^i mod mod1, base[1][i] = prime2^i mod mod2

int mod1 = 1589917477;
int mod2 = 1897266401;

std::vector<std::pair<int, int>> a;  // a[u - 1]: the two path-parity hashes of node u
// Labels every node x of the subtree hanging from u (entered from parent p) with a pair of modular
// hashes, stored in a[x - 1], starting from the hashes (cst1, cst2) given for u itself.
// Edge weights get dense ids through M on first use. vis[id] tracks the parity of weight id on the
// current path. Each time a weight's parity flips to odd its base power is added, and each time it
// flips back to even the base power is subtracted (mod mod1 / mod2).
// An explicit stack replaces recursion: a path-shaped tree with ~5e5 nodes would overflow the call
// stack. Nodes are visited, and weight ids assigned, in the same order as a recursive DFS.
void dfs(int u, int p = 0, long long cst1 = 0, long long cst2 = 0) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next adjacency entry of `node` to examine
        int enteredBy;     // id of the weight on the edge used to enter `node`; -1 for u
        long long h1, h2;  // hashes of `node`
    };
    // Adds delta (|delta| < m) to h (0 <= h < m) and renormalizes into [0, m).
    auto shift = [](long long h, long long delta, long long m) {
        h += delta;
        if (h >= m) h -= m;
        if (h < 0) h += m;
        return h;
    };

    std::vector<Frame> frames;
    frames.push_back({u, p, 0, -1, cst1, cst2});
    a[u - 1].first = static_cast<int>(cst1);
    a[u - 1].second = static_cast<int>(cst2);

    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next == adj[top.node].size()) {
            // Leaving the subtree: undo the parity flip made when it was entered.
            if (top.enteredBy >= 0) vis[top.enteredBy] = 1 - vis[top.enteredBy];
            frames.pop_back();
            continue;
        }
        const std::pair<int, int> edge = adj[top.node][top.next++];
        if (edge.first == top.parent) continue;

        std::map<int, int>::iterator it = M.find(edge.second);
        if (it == M.end()) it = M.emplace(edge.second, cnt++).first;
        const int id = it->second;

        const int cur = top.node;
        vis[id] = 1 - vis[id];
        const long long h1 = shift(top.h1, vis[id] ? base[0][id] : -base[0][id], mod1);
        const long long h2 = shift(top.h2, vis[id] ? base[1][id] : -base[1][id], mod2);
        a[edge.first - 1].first = static_cast<int>(h1);
        a[edge.first - 1].second = static_cast<int>(h2);
        // `top` may dangle after this push; everything needed was copied above.
        frames.push_back({edge.first, cur, 0, id, h1, h2});
    }
}

int main() {
    cin >> t;
    int sum = 0;
    while( t-- ) {
        cin >> n; sum += n;
            assert(sum <= 500000);
        int i; base[0][0] = base[1][0] = 1;
        fr(i, 1, n) {
            base[0][i] = 1LL * base[0][i-1] * prime1 % mod1;
            base[1][i] = 1LL * base[1][i-1] * prime2 % mod2;
        }
        fr(i, 1, n-1) {
            int u, v, cst; 
            cin >> u >> v >> cst;
            adj[u].pb( {v, cst} );
            adj[v].pb( {u, cst} );
        }
        cnt = 0;
        a.resize(n);
        dfs(1, 0, 0);
        assert(a.size() == n);
        sort(all(a));
        i = 0;
        ll ans = 0;
        while( i < n ) {
            pii x = a[i]; int c = 0;
            while( i < n && x == a[i] ) {
                c ++; i ++;
            }
            ans += 1LL * c * (c-1) / 2;
        }
        ans = nC2(n) - ans;
        cout << ans << "n";
        M.clear(); a.clear();
        fr(i, 0, n) {
            adj[i].clear(); 
            vis[i] = base[0][i] = base[1][i] = 0;
        }
    }
    assert(n <= 500000);
    return 0;
}