In the HackerEarth Colorful Tree problem, you are given a tree that contains N nodes, where every node i is colored with some color C_i.
The distance of a node V from a node U is defined as the number of edges along the simple path from the node U to the node V. Your task is to answer M queries of the following type:
- K C: Determine the distance from node K to the most distant node of color C. If there is no node of color C in the tree, print -1.
HackerEarth Colorful Tree problem: first solution.
#include <bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// Line break without the flush performed by std::endl. The backslash had been lost
// ("n"), so every printed answer was followed by a literal 'n' instead of a newline.
#define endl "\n"
const int N=5e5+5;   // array bound for node ids and color indices (5e5 plus slack)
const int LG=21;     // binary-lifting levels; parent[i][v] is the 2^i-th ancestor of v
// k: live length of vertices[]; tim: DFS clock; dist/node: farthest distance/vertex found by dfs2
int n, k, q, tim=0, dist=0, node;
int col[N];                                     // col[v]: color of node v
int parent[LG][N];
int tin[N], tout[N], level[N], vertices[2*N];   // vertices[1..k]: working set for one color's virtual tree
vector<int> g[N], contains[N], diam[N];         // g: adjacency; contains[c]: nodes of color c; diam[c]: chosen endpoints
vector<pair<int, int> > tree[N];                // virtual-tree adjacency: (neighbour, edge length)
// Roots the tree at k (with parent `par` and depth `lvl`) and records, for every reached
// vertex, its DFS entry time tin[], exit time tout[] (largest tin in its subtree), depth
// level[] and immediate parent parent[0][].
// Uses an explicit stack instead of recursion: a path-shaped tree with n nodes would
// otherwise need n nested calls, which can overflow the default thread stack.
void dfs(int k, int par, int lvl)
{
    struct Frame { int v, p; size_t next; };   // vertex, its parent, next neighbour index to try
    vector<Frame> stk;
    // Pre-order work: stamp entry time, parent and depth, then schedule the vertex.
    auto enter = [&](int v, int p, int l)
    {
        tin[v]=++tim;
        parent[0][v]=p;
        level[v]=l;
        stk.push_back({v, p, 0});
    };
    enter(k, par, lvl);
    while(!stk.empty())
    {
        Frame &top=stk.back();
        if(top.next<g[top.v].size())
        {
            const int to=g[top.v][top.next++];
            const int from=top.v, fromPar=top.p;
            if(to!=fromPar)
                enter(to, from, level[from]+1);   // may reallocate stk; `top` is not used afterwards
        }
        else
        {
            // Post-order: every descendant has been stamped.
            tout[top.v]=tim;
            stk.pop_back();
        }
    }
}
// Fills the binary-lifting table from row 0 (set by dfs): parent[i][v] becomes the
// 2^i-th ancestor of v. Entries whose ancestor does not exist are left untouched
// (0 in the zero-initialised global table).
void precompute()
{
    for(int lvl=1; lvl<LG; ++lvl)
    {
        for(int v=1; v<=n; ++v)
        {
            const int half=parent[lvl-1][v];   // the 2^(lvl-1)-th ancestor
            if(half==0)
                continue;                      // already above the root
            parent[lvl][v]=parent[lvl-1][half];
        }
    }
}
int LCA(int u, int v)
{
if(level[u]<level[v])
swap(u,v);
int diff=level[u]-level[v];
for(int i=LG-1;i>=0;i--)
{
if((1<<i) & diff)
{
u=parent[i][u];
}
}
if(u==v)
return u;
for(int i=LG-1;i>=0;i--)
{
if(parent[i][u] && parent[i][u]!=parent[i][v])
{
u=parent[i][u];
v=parent[i][v];
}
}
return parent[0][u];
}
// Number of edges on the tree path between u and v, measured through their LCA.
int dist1(int u, int v)
{
    const int meet=LCA(u, v);
    return (level[u]-level[meet])+(level[v]-level[meet]);
}
// True when u is an ancestor of v (or u == v): v's DFS interval [tin, tout] lies inside u's.
bool isancestor(int u, int v)
{
    if(tin[v]<tin[u])
        return false;
    return tout[u]>=tout[v];
}
int dfs2(int k, int par, int dis)
{
//cerr<<k<<endl;
if(dis>dist)
{
dist=dis;
node=k;
}
for(auto it:tree[k])
{
if(it.first==par)
continue;
dfs2(it.first, k, dis+it.second);
}
}
// For the vertices[1..k] of one color (k >= 3), builds their virtual tree (the vertices
// plus the LCAs needed to connect them, edges weighted by tree distance). It then finds a
// diameter of that set with two farthest-vertex sweeps and appends both endpoints to
// diam[color].
// Returns the diameter length in edges. The original fell off the end of this int function.
// Side effects: grows/reorders vertices[] and k; clears the tree[] lists it used.
int work(int color)
{
    auto byTin=[](int a, int b)
    {
        return tin[a]<tin[b];
    };
    sort(vertices+1, vertices+k+1, byTin);
    // Closing the set under LCA only needs the LCAs of tin-adjacent vertices.
    const int colored=k;
    for(int i=1;i<colored;i++)
        vertices[++k]=LCA(vertices[i], vertices[i+1]);
    // After sorting by tin, equal vertices are adjacent, so unique() removes all duplicates.
    sort(vertices+1, vertices+k+1, byTin);
    k=unique(vertices+1, vertices+k+1) - vertices - 1;
    // vertices[1] is now the LCA of the whole set; the stack holds the current root path.
    stack<int> s;
    s.push(vertices[1]);
    for(int i=2;i<=k;i++)
    {
        while(!isancestor(s.top(), vertices[i]))
            s.pop();
        int u=s.top();
        int v=vertices[i];
        int w=dist1(u, v);
        tree[u].push_back({v, w});
        tree[v].push_back({u, w});
        s.push(vertices[i]);
    }
    // Double sweep: the farthest vertex from any start is a diameter endpoint, and the
    // farthest vertex from that endpoint is the other one. Distances start at 1, so
    // dist > 0 always registers the start vertex.
    dist=0;
    dfs2(vertices[1], vertices[1], 1);
    diam[color].push_back(node);
    // Reset before the second sweep. Without this, when vertices[1] is itself a diameter
    // endpoint, no vertex beats the first sweep's maximum and the pair degenerates to
    // (a, a) (e.g. a 3-node path of one color).
    dist=0;
    dfs2(node, node, 1);
    diam[color].push_back(node);
    for(int i=1;i<=k;i++)
        tree[vertices[i]].clear();
    return dist-1;
}
int32_t main()
{
IOS;
cin>>n>>q;
for(int i=1;i<=n;i++)
{
cin>>col[i];
contains[col[i]].push_back(i);
}
for(int i=1;i<=n-1;i++)
{
int u, v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0, 1);
precompute();
for(int i=1;i<=5e5;i++)
{
if(contains[i].size()<=2)
{
for(auto &it:contains[i])
diam[i].push_back(it);
continue;
}
k=0;
for(auto &it:contains[i])
vertices[++k]=it;
work(i);
}
while(q--)
{
int k, c;
cin>>k>>c;
if(!contains[c].size())
cout<<"-1"<<endl;
else
{
int ans=0;
for(auto &it: diam[c])
ans=max(ans, dist1(k, it));
cout<<ans<<endl;
}
}
return 0;
}
Second solution
#ifndef _GLIBCXX_NO_ASSERT
#include <cassert>
#endif
#include <cctype>
#include <cerrno>
#include <cfloat>
#include <ciso646>
#include <climits>
#include <clocale>
#include <cmath>
#include <csetjmp>
#include <csignal>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#if __cplusplus >= 201103L
#include <ccomplex>
#include <cfenv>
#include <cinttypes>
#include <cstdbool>
#include <cstdint>
#include <ctgmath>
#include <cwchar>
#include <cwctype>
#endif
// C++
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <locale>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#if __cplusplus >= 201103L
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <forward_list>
#include <future>
#include <initializer_list>
#include <mutex>
#include <random>
#include <ratio>
#include <regex>
#include <scoped_allocator>
#include <system_error>
#include <thread>
#include <tuple>
#include <typeindex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#endif
#define ll long long
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define vi vector<int>
#define all(a) (a).begin(),(a).end()
#define F first
#define S second
// Parenthesised so that the macro expands safely inside larger expressions.
#define sz(x) ((int)(x).size())
#define hell 1000000007
// Line break without the flush of std::endl. The backslash had been lost ('n'), so every
// printed answer was followed by a literal 'n' instead of a newline.
#define endl '\n'
#define rep(i,a,b) for(int i=a;i<b;i++)
using namespace std;
// Debug formatters: render values as readable text for debug_out().
// Strings are double-quoted.
string to_string(string s) {
    string quoted;
    quoted.reserve(s.size() + 2);
    quoted += '"';
    quoted += s;
    quoted += '"';
    return quoted;
}
// C strings are formatted exactly like std::string.
string to_string(const char* s) {
    return to_string(string(s));
}
// Booleans print as words.
string to_string(bool b) {
    if (b) return "true";
    return "false";
}
// Characters are single-quoted.
string to_string(char ch) {
    return string(1, '\'') + ch + '\'';
}
// Pairs print as "(first, second)".
template <typename A, typename B>
string to_string(pair<A, B> p) {
    string out = "(";
    out += to_string(p.first);
    out += ", ";
    out += to_string(p.second);
    out += ")";
    return out;
}
// An iterator range prints as "{a, b, ...}".
template <class InputIterator>
string to_string (InputIterator first, InputIterator last) {
    string res = "{";
    const char* sep = "";
    for (; first != last; ++first) {
        res += sep;
        res += to_string(*first);
        sep = ", ";
    }
    res += "}";
    return res;
}
// Any other iterable prints as "{a, b, ...}".
template <typename A>
string to_string(A v) {
    string res = "{";
    const char* sep = "";
    for (const auto &x : v) {
        res += sep;
        res += to_string(x);
        sep = ", ";
    }
    res += "}";
    return res;
}
// Base case: every argument has been printed; end the line on stderr.
void debug_out() {
    cerr << endl;
}
// Prints each argument to stderr, preceded by a space, then recurses on the rest.
template <typename Head, typename... Tail>
void debug_out(Head H, Tail... T) {
    const string piece = to_string(H);
    cerr << ' ' << piece;
    debug_out(T...);
}
// Reads a pair as its first element followed by its second.
template <typename A, typename B>
istream& operator>>(istream& input,pair<A,B>& x){
    return input >> x.first >> x.second;
}
// Fills an already-sized vector element by element; the size itself is not read.
template <typename A>
istream& operator>>(istream& input,vector<A>& x){
    for (auto it = x.begin(); it != x.end(); ++it)
        input >> *it;
    return input;
}
#ifndef PRINTERS
// Debug output disabled: debug(...) collapses to an inert expression.
#define debug(...) 42
#else
// Prints "[<argument text>]:" followed by each argument's value on stderr.
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#endif
// Strict input validator: reads one integer from stdin that must be immediately followed
// by the terminator `endd` (' ', '\n' or EOF at the call sites). It asserts the format
// (optional '-', digits only, no leading zeros, no "-0", at most 19 digits) and that the
// value lies in [l, r].
// `endd` and the character read are int so that EOF stays distinct from a 0xFF byte;
// the original stored getchar() in a char, which conflates them.
long long readInt(long long l,long long r,int endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, -1 until one is read; used to reject leading zeros
    bool is_neg=false;
    while(true){
        const int g=getchar();
        if(g=='-'){
            assert(fi==-1);          // a sign may only precede the digits
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            // At most 19 digits starting with 0/1 keeps x below LLONG_MAX.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
#ifdef PRINTERS
            debug(int(g));
#endif
            assert(false);           // any other character means malformed input
        }
    }
}
// Strict input validator: reads characters from stdin up to the terminator `endd`. Each one
// must satisfy islower() (a-z in the default C locale), and the length must be in [l, r].
// The character is kept as int: islower() has undefined behaviour for negative values
// other than EOF, which a plain char holding a byte >= 0x80 would produce.
string readString(int l,int r,int endd){
    string ret="";
    int cnt=0;
    while(true){
        const int g=getchar();
        if(g==endd){
            break;
        }
        else if(islower(g)){
            cnt++;
            ret+=static_cast<char>(g);
        }
        else{
            assert(false);   // anything else is malformed input
        }
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Convenience wrappers that fix the terminator expected after the token.
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
// The newline terminator was garbled to 'n' (lost backslash), which made every real
// line ending fail the validator's assertion.
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
// Node ids and colors are used directly as indices (values up to 5e5, plus slack).
const int kMaxNodes = 500005;
vector<int> colnode[kMaxNodes];              // colnode[c]: nodes of color c in DFS pre-order; later extended with LCAs
int intime[kMaxNodes];                       // DFS entry timestamp
int outtime[kMaxNodes];                      // DFS exit timestamp
int height[kMaxNodes];                       // depth, root = 1; height[0] stays 0 as a sentinel
vector<int> adj[kMaxNodes];                  // tree adjacency lists
int dp[20][kMaxNodes];                       // dp[i][v]: 2^i-th ancestor of v (0 when absent)
vector<pair<int,int> > temptree[kMaxNodes];  // virtual-tree adjacency: (neighbour, edge length)
vector<int> reqdnodes[kMaxNodes];            // reqdnodes[c]: diameter endpoints of color c
int color[kMaxNodes];                        // color[v]: color of node v
// Roots the tree at u (parent p). For each reached node it records the entry/exit
// timestamps, height (height[p] + 1) and parent dp[0][]. It also appends the node to
// colnode[color] in pre-order, so every colnode list ends up sorted by intime.
// Every entry and every exit consumes one tick of the shared clock.
// Uses an explicit stack: a path-shaped tree with up to 5e5 nodes would otherwise need
// that many nested calls, which can overflow the default thread stack.
void dfs(int u,int p=0){
    static int clck = 1;
    struct Frame { int node, par; size_t next; };   // node, its parent, next neighbour index
    vector<Frame> frames;
    auto enter = [&](int v, int par){
        intime[v] = clck;
        colnode[color[v]].emplace_back(v);
        height[v] = height[par]+1;
        dp[0][v] = par;
        clck++;
        frames.push_back({v, par, 0});
    };
    enter(u, p);
    while(!frames.empty()){
        Frame &top = frames.back();
        if(top.next < adj[top.node].size()){
            const int cur = top.node, curPar = top.par;
            const int nxt = adj[cur][top.next++];
            if(nxt != curPar) enter(nxt, cur);   // may reallocate frames; `top` is not used afterwards
        } else {
            outtime[top.node] = clck;
            clck++;
            frames.pop_back();
        }
    }
}
// Lowest common ancestor via the dp[][] binary-lifting table (20 levels).
int lca(int u,int v){
    // Let v be the deeper (or equally deep) node.
    if(height[v]<height[u]) swap(u,v);
    // Lift v until both nodes are at the same depth.
    for(int b=19;b>=0;--b)
        if(height[v]-height[u]>=(1<<b))
            v=dp[b][v];
    if(v==u) return u;
    // Lift both while their 2^b ancestors differ; afterwards they share a parent.
    for(int b=19;b>=0;--b){
        if(dp[b][u]==dp[b][v]) continue;
        u=dp[b][u];
        v=dp[b][v];
    }
    return dp[0][u];
}
struct diameter
{
pii maxdep1,maxdep2;
pair<int,pii> best_res;
diameter(int u){
maxdep1.S=u;
maxdep2.S=u;
best_res.S={u,u};
}
};
// Summarises the virtual-tree subtree rooted at u (entered from p), using edge weights
// from temptree. The summary holds the two longest downward paths and the longest path
// found anywhere in it, with both endpoints. Ties resolve through pair comparison in max().
// NOTE(review): recursion depth equals the virtual-tree height, which can approach the
// number of nodes of one color on a path-shaped input; consider an explicit stack if the
// stack size is tight.
diameter get_diameter(int u,int p){
    diameter cur(u);
    for(const auto &edge : temptree[u]){
        const int child = edge.F;
        if(child == p) continue;
        const diameter sub = get_diameter(child, u);
        // Deepest path of the child's subtree, extended up to u.
        const int reach = sub.maxdep1.F + edge.S;
        const pii cand = mp(reach, sub.maxdep1.S);
        if(reach > cur.maxdep1.F){
            cur.maxdep2 = cur.maxdep1;
            cur.maxdep1 = cand;
        } else if(reach > cur.maxdep2.F){
            cur.maxdep2 = cand;
        }
        cur.best_res = max(cur.best_res, sub.best_res);
    }
    // Candidate path through u: join the two deepest branches.
    cur.best_res = max(cur.best_res, mp(cur.maxdep1.F + cur.maxdep2.F, mp(cur.maxdep1.S, cur.maxdep2.S)));
    return cur;
}
// Reads and validates the whole input. For every color, it precomputes the two endpoints
// of the longest path between nodes of that color (through a virtual tree), then answers
// the queries.
// Rationale: in a tree, the member of a vertex set S farthest from any vertex K is one of
// the endpoints of S's diameter, so each query only measures the distances to those endpoints.
void solve(){
// Orders nodes by DFS entry time.
auto comp = [](int a,int b){return intime[a]<intime[b];};
int N,M;
N = readIntSp(1,500000);
M = readIntLn(1,500000);
// Node colors: space-separated, the last one terminated by a newline.
rep(i,1,N+1){
int col;
if(i==N) col = readIntLn(1,500000);
else col = readIntSp(1,500000);
color[i] = col;
}
// N-1 edges, one "u v" per line.
rep(i,1,N){
int u,v;
u = readIntSp(1,N);
v = readIntLn(1,N);
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
// Root at node 1: fills intime/outtime/height/dp[0] and colnode (in DFS pre-order).
dfs(1);
// Binary lifting; 0 acts as the "above the root" sentinel (dp[*][0] stays 0).
for(int i = 1; i < 20; i++){
for(int j = 1; j <= N; j++){
dp[i][j]=dp[i-1][dp[i-1][j]];
}
}
// Scratch flags so that an LCA is appended to a color's list at most once.
vector<bool>nodes(N+1);
rep(i,1,500005){
if(colnode[i].empty())continue;
// A single node is the only candidate for its color.
if(sz(colnode[i])==1){
reqdnodes[i].emplace_back(colnode[i].front());
continue;
}
// Room for the (at most k-1) LCAs appended below, so no reallocation happens.
colnode[i].reserve(2*sz(colnode[i]));
for(auto j:colnode[i])nodes[j]=1;
int k = sz(colnode[i]);
// Closing the set under LCA only needs the LCAs of intime-adjacent nodes.
rep(j,1,k){
int tmp = lca(colnode[i][j-1],colnode[i][j]);
if(!nodes[tmp]){
nodes[tmp]=1;
colnode[i].emplace_back(tmp);
}
}
for(auto j:colnode[i])nodes[j]=0;
// The first k entries are already in intime order (DFS pre-order); sort the appended
// LCAs and merge, giving the full virtual-tree node list in intime order.
sort(colnode[i].begin()+k,colnode[i].end(),comp);
inplace_merge(colnode[i].begin(),colnode[i].begin()+k,colnode[i].end(),comp);
// Stack of (node, outtime) along the current root path. Sentinel node 0 with INT_MAX
// encloses everything.
stack<pii>stk;
stk.emplace(0,INT_MAX);
for(auto j:colnode[i]){
// Pop until the top's subtree contains j; the top is then j's virtual-tree parent.
while(outtime[j]>stk.top().S)stk.pop();
// Edge length = depth difference (height[0] is 0 for the sentinel).
temptree[stk.top().F].emplace_back(j,abs(height[j]-height[stk.top().F]));
temptree[j].emplace_back(stk.top().F,abs(height[j]-height[stk.top().F]));
stk.emplace(j,outtime[j]);
}
// The set's overall LCA comes first in intime order, so sentinel 0 gets a single child:
// the virtual-tree root.
auto res = get_diameter(temptree[0][0].F,0);
reqdnodes[i].emplace_back(res.best_res.S.F);
reqdnodes[i].emplace_back(res.best_res.S.S);
// Reset the shared virtual-tree adjacency for the next color and release this color's list.
for(auto j:colnode[i]){
temptree[j].clear();
}
temptree[0].clear();
vi().swap(colnode[i]);
}
rep(i,1,M+1){
int K,C;
K = readIntSp(1,N);
// NOTE(review): the last color must be followed directly by EOF; readInt rejects any
// other character, so input ending with a trailing newline fails validation here.
if(i==M) C = readInt(1,500000,EOF);
else C = readIntLn(1,500000);
// Stays -1 when color C has no nodes (reqdnodes[C] is empty).
int ans = -1;
for(auto j:reqdnodes[C]){
int LCA = lca(j,K);
ans = max(ans,height[j]+height[K]-2*height[LCA]);
}
cout << ans << endl;
}
}
int main(){
    // Decouple C++ streams from C stdio. Answers go through cout; input is read with
    // getchar() inside the validating readers.
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // Single test case; multi-test input is disabled.
    int t=1;
    for(; t>0; --t){
        solve();
    }
    return 0;
}