// HackerEarth "Colorful Tree" problem solution
// (article by YASH PAL, 31 July 2024)
//
// Problem: you are given a tree that contains N nodes, where every node i is
// colored with some color C_i. The distance of a node V from a node U is the
// number of edges along the simple path from node U to node V. Answer M
// queries of the following type:
//   K C: determine the distance of the most distant node of color C from
//        node K. If there is no node of color C in the tree, print -1.
//
// Input, in the order it is read below: "N M", then the N colors C_1..C_N,
// then N-1 edges "u v", then M queries "K C". Nodes are 1-indexed. Colors are
// assumed to be non-negative (the original validator accepted 1..500000).
//
// Idea: in a tree, the node of a set S that is farthest from any node K is an
// endpoint of a diameter of S. So for every color we precompute two diameter
// endpoints on the color's virtual tree (the color's nodes plus the LCAs of
// tin-adjacent pairs) and answer each query with two LCA distance lookups.
//
// The original article gave two solutions that differ in how they find the
// diameter of the virtual tree: two farthest-node sweeps (first solution) and
// a bottom-up tree DP (the "Second solution"). Both are kept below as
// interchangeable strategies; main() uses the first one unless
// COLORFUL_TREE_SECOND_SOLUTION is defined.

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

namespace colorful_tree {

// Tree rooted at a given node, with pre-order entry times (tin), the last
// entry time inside each subtree (tout) and a binary-lifting ancestor table.
// Node 0 is a sentinel meaning "no parent".
class RootedTree {
public:
    // adj[v] lists the neighbours of node v for v in 1..n (adj[0] is unused).
    RootedTree(const std::vector<std::vector<int>>& adj, int root) {
        const int n = static_cast<int>(adj.size()) - 1;
        while ((1 << levels_) <= n) ++levels_;
        depth_.assign(n + 1, 0);
        tin_.assign(n + 1, 0);
        tout_.assign(n + 1, 0);
        up_.assign(levels_, std::vector<int>(n + 1, 0));

        // Iterative pre-order DFS: a recursive one overflows the call stack
        // on path-shaped trees with ~5e5 nodes.
        std::vector<int> order;
        order.reserve(n);
        std::vector<int> pending{root};
        while (!pending.empty()) {
            const int v = pending.back();
            pending.pop_back();
            tin_[v] = static_cast<int>(order.size());
            order.push_back(v);
            for (const int w : adj[v]) {
                if (w == up_[0][v]) continue;
                up_[0][w] = v;
                depth_[w] = depth_[v] + 1;
                pending.push_back(w);
            }
        }

        // Descendants follow their ancestor in pre-order, so a reverse sweep
        // sees every subtree complete before reaching its root.
        std::vector<int> subtree_size(n + 1, 1);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            tout_[v] = tin_[v] + subtree_size[v] - 1;
            if (up_[0][v] != 0) subtree_size[up_[0][v]] += subtree_size[v];
        }

        for (int j = 1; j < levels_; ++j)
            for (int v = 1; v <= n; ++v)
                up_[j][v] = up_[j - 1][up_[j - 1][v]];
    }

    int depth(int v) const { return depth_[v]; }
    int tin(int v) const { return tin_[v]; }

    // True when u is v itself or an ancestor of v.
    bool is_ancestor(int u, int v) const {
        return tin_[u] <= tin_[v] && tout_[v] <= tout_[u];
    }

    int lca(int u, int v) const {
        if (is_ancestor(u, v)) return u;
        if (is_ancestor(v, u)) return v;
        // Climb u to its highest ancestor that is still not above v; the
        // parent of that node is the LCA.
        for (int j = levels_ - 1; j >= 0; --j) {
            const int a = up_[j][u];
            if (a != 0 && !is_ancestor(a, v)) u = a;
        }
        return up_[0][u];
    }

    // Number of edges on the path between u and v.
    int distance(int u, int v) const {
        return depth_[u] + depth_[v] - 2 * depth_[lca(u, v)];
    }

private:
    int levels_ = 1;
    std::vector<int> depth_;
    std::vector<int> tin_;
    std::vector<int> tout_;
    std::vector<std::vector<int>> up_;  // up_[j][v]: 2^j-th ancestor of v, 0 if none
};

// Virtual (auxiliary) tree of a node set. nodes holds the set plus the LCAs
// of tin-adjacent pairs, sorted by tin, so nodes[0] is the root. parent[i] is
// the index of node i's nearest kept ancestor, or -1 for the root. An edge's
// length is the depth difference of its endpoints.
struct VirtualTree {
    std::vector<int> nodes;
    std::vector<int> parent;
};

VirtualTree build_virtual_tree(const RootedTree& tree, std::vector<int> nodes) {
    const auto by_tin = [&tree](int a, int b) { return tree.tin(a) < tree.tin(b); };
    std::sort(nodes.begin(), nodes.end(), by_tin);
    const std::size_t original_count = nodes.size();
    for (std::size_t i = 1; i < original_count; ++i)
        nodes.push_back(tree.lca(nodes[i - 1], nodes[i]));
    // Equal nodes share a tin, so after sorting they are adjacent for unique().
    std::sort(nodes.begin(), nodes.end(), by_tin);
    nodes.erase(std::unique(nodes.begin(), nodes.end()), nodes.end());

    VirtualTree vt;
    vt.parent.assign(nodes.size(), -1);
    std::vector<int> chain;  // indices on the path from the root to the current node
    for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
        while (!chain.empty() && !tree.is_ancestor(nodes[chain.back()], nodes[i]))
            chain.pop_back();
        if (!chain.empty()) vt.parent[i] = chain.back();
        chain.push_back(i);
    }
    vt.nodes = std::move(nodes);
    return vt;
}

// First solution: two farthest-node sweeps. Requires a non-empty tree.
// Returns the node ids of two diameter endpoints.
std::pair<int, int> diameter_by_double_sweep(const RootedTree& tree, const VirtualTree& vt) {
    const int k = static_cast<int>(vt.nodes.size());
    std::vector<std::vector<int>> neighbours(k);
    for (int i = 1; i < k; ++i) {
        const int p = vt.parent[i];
        if (p < 0) continue;
        neighbours[i].push_back(p);
        neighbours[p].push_back(i);
    }

    // Index of a node farthest from `start`. Every sweep starts from fresh
    // distances; the original reused the previous sweep's maximum and could
    // return the starting endpoint twice.
    const auto farthest_from = [&](int start) {
        std::vector<int> dist(k, -1);
        std::vector<int> pending{start};
        dist[start] = 0;
        int best = start;
        while (!pending.empty()) {
            const int i = pending.back();
            pending.pop_back();
            if (dist[i] > dist[best]) best = i;
            for (const int j : neighbours[i]) {
                if (dist[j] >= 0) continue;
                dist[j] = dist[i] + std::abs(tree.depth(vt.nodes[i]) - tree.depth(vt.nodes[j]));
                pending.push_back(j);
            }
        }
        return best;
    };

    const int a = farthest_from(0);
    const int b = farthest_from(a);
    return {vt.nodes[a], vt.nodes[b]};
}

// Second solution: bottom-up DP keeping, for every node, the longest
// downward path and its endpoint. Requires a non-empty tree.
// Returns the node ids of two diameter endpoints.
std::pair<int, int> diameter_by_tree_dp(const RootedTree& tree, const VirtualTree& vt) {
    const int k = static_cast<int>(vt.nodes.size());
    std::vector<std::pair<int, int>> deepest(k);  // (length, endpoint) going down from i
    for (int i = 0; i < k; ++i) deepest[i] = {0, vt.nodes[i]};
    int best_length = 0;
    std::pair<int, int> best_ends{vt.nodes[0], vt.nodes[0]};

    // Children have larger tin than their parent, so walking indices
    // downwards finishes every subtree before merging it into its parent.
    for (int i = k - 1; i > 0; --i) {
        const int p = vt.parent[i];
        if (p < 0) continue;
        const int edge = tree.depth(vt.nodes[i]) - tree.depth(vt.nodes[p]);
        const std::pair<int, int> via{deepest[i].first + edge, deepest[i].second};
        // Join this branch with the best branch merged into p so far.
        const int joined = deepest[p].first + via.first;
        if (joined > best_length) {
            best_length = joined;
            best_ends = {deepest[p].second, via.second};
        }
        if (via.first > deepest[p].first) deepest[p] = via;
    }
    return best_ends;
}

using DiameterFinder = std::pair<int, int> (*)(const RootedTree&, const VirtualTree&);

// Reads one test (format in the header comment) from `in`, writes one answer
// per query to `out`, using `find_diameter` for colors with 3+ nodes.
void solve(std::istream& in, std::ostream& out, DiameterFinder find_diameter) {
    int n = 0;
    int q = 0;
    if (!(in >> n >> q) || n <= 0) return;

    std::vector<int> color(n + 1, 0);
    int max_color = 0;
    for (int v = 1; v <= n; ++v) {
        in >> color[v];
        max_color = std::max(max_color, color[v]);
    }
    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        in >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    const RootedTree tree(adj, 1);

    std::vector<std::vector<int>> by_color(max_color + 1);
    for (int v = 1; v <= n; ++v) by_color[color[v]].push_back(v);

    // endpoints[c]: nodes whose distances suffice to answer any query on c.
    std::vector<std::vector<int>> endpoints(max_color + 1);
    for (int c = 0; c <= max_color; ++c) {
        std::vector<int>& members = by_color[c];
        if (members.size() <= 2) {
            endpoints[c] = std::move(members);
        } else {
            const VirtualTree vt = build_virtual_tree(tree, std::move(members));
            const auto [a, b] = find_diameter(tree, vt);
            endpoints[c] = {a, b};
        }
    }

    for (int i = 0; i < q; ++i) {
        int k = 0;
        int c = 0;
        in >> k >> c;
        int answer = -1;  // stays -1 when color c has no nodes
        if (c >= 0 && c <= max_color)
            for (const int e : endpoints[c]) answer = std::max(answer, tree.distance(k, e));
        out << answer << '\n';
    }
}

}  // namespace colorful_tree

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
#ifdef COLORFUL_TREE_SECOND_SOLUTION
    colorful_tree::solve(std::cin, std::cout, colorful_tree::diameter_by_tree_dp);
#else
    colorful_tree::solve(std::cin, std::cout, colorful_tree::diameter_by_double_sweep);
#endif
    return 0;
}