In this HackerEarth XOR paths problem solution, you are given a weighted tree (an acyclic, undirected, connected graph) with N nodes, numbered from 1 to N. There are N – 1 edges, each with a weight assigned to it.
You have to process Q queries on it. In each query, you are given three integers u, v, and x. You must determine the maximum value you can obtain by taking the bitwise XOR of x with the weight of any edge on the path from node u to node v.
HackerEarth XOR paths problem solution.
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
// Competitive-programming shorthands.
// NOTE: `int` is textually replaced by `long long` for the rest of the file, so
// every `int` below is 64-bit. `int main` would become `long long main`, which is
// ill-formed, hence `int32_t main` further down. Integer literals such as `1`
// keep type int.
#define int long long
#define mkp make_pair
#define pb push_back
#define ff first
#define ss second
// Debug printers: write the arguments to cout separated by single spaces, then
// end the line with endl (which also flushes). Each expansion already ends in
// ';'. Arguments are parenthesized so expressions like `x & 1` print their value
// instead of being re-associated with `<<`.
#define debug1(a) cout<<(a)<<endl;
#define debug2(a,b) cout<<(a)<<' '<<(b)<<endl;
// Fixed: previously `cout<<a<' '<<...`, a stray `<` that broke compilation.
#define debug3(a,b,c) cout<<(a)<<' '<<(b)<<' '<<(c)<<endl;
// Loop helpers: rep iterates i over [0, n); repre iterates i over [a, b]
// inclusive. Bounds are parenthesized so that arguments such as `x & 1` or
// `a ? b : c` are compared as a whole instead of binding to `<` / `<=`.
#define rep(i,n) for(int i=0;i<(n);i++)
#define repre(i,a,b) for(int i=(a);i<=(b);i++)
// Fill an array object (not a pointer: sizeof must see the array) via memset.
// clr1 sets every byte to 0xFF, so each integer element reads as -1.
#define clr1(arr) memset((arr),-1,sizeof(arr));
#define clr0(arr) memset((arr),0,sizeof(arr));
#define pi pair<int,int>
#define pii pair<int,pi>
// Begin/end iterator pair for handing a whole container to an algorithm.
#define aint(v) (v).begin(),(v).end()
// Faster iostreams: unsynchronize from C stdio and untie cin/cout.
#define fastio ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
using namespace __gnu_pbds;
using namespace std;

// Red-black tree from GCC's policy-based data structures with order-statistic
// queries (find_by_order / order_of_key); keys are kept distinct.
using orderedSet = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;

// Same tree ordered by less_equal, a common trick to retain duplicate keys.
// NOTE(review): with a non-strict comparator, find()/erase() by key do not
// behave like a normal multiset; erase through find_by_order instead.
using orderedMSet = tree<int, null_type, less_equal<int>, rb_tree_tag, tree_order_statistics_node_update>;
// Hash functor for hash tables: mixes the key with the SplitMix64 finalizer
// after offsetting it by a seed read once from the steady clock, so bucket
// placement differs between runs (presumably to resist crafted inputs).
struct custom_hash {
    // SplitMix64 step: add the golden-ratio increment, then two
    // xor-shift/multiply rounds and a closing xor-shift.
    static uint64_t splitmix64(uint64_t x) {
        uint64_t z = x + 0x9e3779b97f4a7c15;
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9;
        z ^= z >> 27;
        z *= 0x94d049bb133111eb;
        z ^= z >> 31;
        return z;
    }

    size_t operator()(uint64_t x) const {
        // Initialized on the first call and then fixed for the process lifetime.
        static const uint64_t seed = std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(seed + x);
    }
};
// Shorthand for a pb_ds gp_hash_table from int to int hashed with custom_hash
// (not used in the code shown here).
#define ump gp_hash_table<int,int,custom_hash>
// Forward declarations of the number-theory helpers defined after main.
// power: x^y by binary exponentiation, no modulus.
int power(int x,int y);
// isPrime: trial division with the 6k +/- 1 wheel.
bool isPrime(int n);
// modInv: inverse of a modulo b via gcdExtended, normalized to [0, b).
int modInv(int a,int b);
// gcdExtended: returns gcd(a, b) and writes Bezout coefficients to *x, *y.
int gcdExtended(int a,int b,int *x,int* y);
// mpower: a^b modulo p by binary exponentiation.
int mpower(int a,int b,int p);
// Capacity of node-indexed arrays (nodes are numbered 1..N).
// NOTE(review): presumably sized for N up to about 1e5; confirm against constraints.
constexpr int maxn = 100009;

// ---- Heavy-light decomposition state ----
int chainNo, ptr;        // current chain id; next free slot in baseArray
int chainHead[maxn];     // chainHead[c]: first (topmost) node placed on chain c, -1 until set
int posInBase[maxn];     // posInBase[v]: index of node v inside baseArray
int chainInd[maxn];      // chainInd[v]: id of the chain containing v
int baseArray[maxn];     // baseArray[posInBase[v]]: weight of the edge from v to its parent

// ---- Rooted-tree data ----
int dp[maxn][25];        // dp[v][j]: 2^j-th ancestor of v (0 or -1 beyond the root)
int level[maxn];         // depth of v, root at depth 0
int sz[maxn];            // number of nodes in v's subtree

// ---- Adjacency lists: costs[u][i] is the weight of edge (u, adj[u][i]) ----
vector<int> adj[maxn];
vector<int> costs[maxn];
// Binary trie node: child[b] continues the path for bit value b;
// a branch that does not exist is a null pointer.
class node
{
public:
    node* child[2];
    node() : child{nullptr, nullptr} {}
};
// Segment tree over baseArray positions (children of slot i at 2*i and 2*i+1).
// Each slot holds the root of a binary trie containing every weight in its range.
node* trieArr[4*maxn+1];
// Stores in `root` a trie holding the union of the values in root1 and root2.
// A subtree present in only one input is shared, not copied, so the input tries
// must not be modified afterwards. `root` is left untouched when both inputs are
// null.
void combine(node* &root,node* root1,node* root2)
{
    if(root1==NULL && root2==NULL)return;
    // Only one side present: reuse it directly. A node is allocated only when
    // both sides exist; allocating before these checks would leak it.
    if(root1==NULL){root=root2;return;}
    if(root2==NULL){root=root1;return;}
    root=new node();
    combine(root->child[0],root1->child[0],root2->child[0]);
    combine(root->child[1],root1->child[1],root2->child[1]);
}
// Inserts bits 31..0 of `val` (most significant first) into the binary trie
// rooted at `root`, creating missing nodes along the path. `root` must be
// non-null.
void insertInTrie(node* root,int val)
{
    node* temp=root;
    for(int i=31;i>=0;i--)
    {
        // Shift the value rather than the literal: `1<<31` is INT_MIN, which
        // sign-extends into a mask covering bits 31..63 of a 64-bit `val`.
        const bool bitAt=(val>>i)&1;
        if(temp->child[bitAt]==NULL)temp->child[bitAt]=new node();
        temp=temp->child[bitAt];
    }
}
// Builds the segment tree over baseArray[low..high] rooted at trieArr[index].
// A leaf gets a fresh single-value trie; an internal node gets the union
// (via combine) of its two children's tries.
void build(int index,int low,int high)
{
    if(low>high)return;
    if(low<high)
    {
        const int mid=low+(high-low)/2;
        const int leftChild=2*index;
        const int rightChild=leftChild+1;
        build(leftChild,low,mid);
        build(rightChild,mid+1,high);
        combine(trieArr[index],trieArr[leftChild],trieArr[rightChild]);
        return;
    }
    // low == high: leaf for a single position.
    trieArr[index]=new node();
    insertInTrie(trieArr[index],baseArray[low]);
}
// Returns the maximum of (pref XOR w), restricted to bits 31..0, over all
// values w stored in the trie rooted at `root`; 0 for an empty trie.
// Greedy from the top bit down: taking the child whose bit differs from
// pref's bit sets bit i of the result, which outweighs all lower bits.
int queryInTrie(node* root,int pref){
    if(root==NULL)return 0;
    int maxVal = 0;
    node* curr = root;
    for(int i=31; i>=0; i--){
        const int b = (pref>>i)&1;
        node* opposite = curr->child[b^1];
        if(opposite){
            // 1LL: `1<<31` would be INT_MIN and subtract instead of add.
            maxVal |= (1LL<<i);
            curr = opposite;
        }else{
            curr = curr->child[b];
        }
    }
    return maxVal;
}
// Maximum of (x XOR w) over weights w in baseArray[l..r], searched through the
// segment-tree slot `index` that covers [low, high]. Returns 0 when the ranges
// do not overlap.
int query(int index,int low,int high,int l,int r,int x)
{
    const bool noOverlap = low>high || r<low || l>high;
    if(noOverlap) return 0;
    if(l<=low && high<=r) return queryInTrie(trieArr[index],x);
    const int mid=low+(high-low)/2;
    const int fromLeft=query(2*index,low,mid,l,r,x);
    const int fromRight=query(2*index+1,mid+1,high,l,r,x);
    return fromLeft<fromRight ? fromRight : fromLeft;
}
// Recursive depth-first walk from v (parent p, depth d). Records each node's
// depth in level[], its parent in dp[][0] and its subtree size in sz[].
void dfs(int v,int p,int d)
{
    level[v]=d;
    dp[v][0]=p;
    int subtree=1;
    for(const auto& next:adj[v])
    {
        if(next==p) continue;
        dfs(next,v,d+1);
        subtree+=sz[next];
    }
    sz[v]=subtree;
}
void pre(int n)
{
repre(j,1,20)
{
repre(i,1,n)
{
if(dp[i][j-1]!=-1)dp[i][j]=dp[dp[i][j-1]][j-1];
}
}
}
// Lowest common ancestor of a and b by binary lifting over level[] and dp[][].
int lca(int a,int b)
{
    // Let `a` be the deeper node, then lift it to b's depth.
    if(level[b]>level[a]) swap(a,b);
    const int gap=level[a]-level[b];
    for(int bit=0;bit<=20;++bit)
        if((gap>>bit)&1) a=dp[a][bit];
    if(a==b) return a;
    // Lift both while their 2^bit-th ancestors differ; they stop just below the LCA.
    for(int bit=20;bit>=0;--bit)
    {
        const int upA=dp[a][bit];
        const int upB=dp[b][bit];
        if(upA!=upB)
        {
            a=upA;
            b=upB;
        }
    }
    return dp[a][0];
}
void hld(int curNode,int cost,int prev)
{
if(chainHead[chainNo]==-1)chainHead[chainNo]=curNode;
chainInd[curNode]=chainNo;
posInBase[curNode]=ptr;
baseArray[ptr++]=cost;
int sc=-1;
int ncost;
for(int i=0;i<adj[curNode].size();i++)
{
if(adj[curNode][i]!=prev)
{
if(sc==-1 || sz[sc]<sz[adj[curNode][i]])
{
sc=adj[curNode][i];
ncost=costs[curNode][i];
}
}
}
if(sc!=-1)hld(sc,ncost,curNode);
for(int i=0;i<adj[curNode].size();i++)
{
if(adj[curNode][i]!=prev)
{
if(sc!=adj[curNode][i])
{
chainNo++;
hld(adj[curNode][i],costs[curNode][i],curNode);
}
}
}
}
// Maximum of (x XOR w) over the edge weights w on the path from u up to v,
// excluding v's own edge to its parent. Returns 0 when u == v.
// NOTE(review): assumes v is an ancestor of u; getAns ensures this via lca.
int queryUp(int u,int v,int x)
{
    int best=0;
    const int targetChain=chainInd[v];
    while(u!=v)
    {
        const int chain=chainInd[u];
        if(chain==targetChain)
        {
            // Same chain: positions strictly below v, up to and including u.
            best=max(best,query(1,0,ptr-1,posInBase[v]+1,posInBase[u],x));
            break;
        }
        // Whole chain segment from its head down to u (including the head's
        // edge to its parent), then continue from the head's parent.
        const int head=chainHead[chain];
        best=max(best,query(1,0,ptr-1,posInBase[head],posInBase[u],x));
        u=dp[head][0];
    }
    return best;
}
int getAns(int u,int v,int x){
int lc=lca(u,v);
int ans1=queryUp(u,lc,x);
int ans2=queryUp(v,lc,x);
return max(ans1,ans2);
}
// Reads n, q, then n-1 weighted edges "u v w" and q queries "u v x". For each
// query, prints on its own line the maximum of (w XOR x) over the edge weights
// w on the path from u to v (0 when u == v).
int32_t main()
{
    fastio
    clr1(chainHead);
    clr1(dp);
    int n,q;
    cin>>n>>q;
    repre(i,1,n-1)
    {
        int u,v,w;
        cin>>u>>v>>w;
        adj[u].pb(v);
        adj[v].pb(u);
        costs[u].pb(w);
        costs[v].pb(w);
    }
    chainNo=0;
    ptr=0;
    // Root the tree at node 1 (parent 0, depth 0), then build the LCA table,
    // the heavy-light decomposition and the segment tree of tries.
    dfs(1,0,0);
    pre(n);
    hld(1,0,0);
    build(1,0,ptr-1);
    while(q--)
    {
        int u,v,x;
        cin>>u>>v>>x;
        // Newline separator ('\n', not the letter 'n'); no flush per answer.
        cout<<getAns(u,v,x)<<'\n';
    }
    return 0;
}
// Modular inverse of a modulo m using the extended Euclidean algorithm, with
// the result normalized into [0, m) for m > 0.
// Only meaningful when gcd(a, m) == 1; otherwise the value returned is not an
// inverse (the gcd is not checked).
int modInv(int a, int m)
{
    int x, y;
    gcdExtended(a, m, &x, &y);  // only the coefficient of a is needed
    return (x % m + m) % m;
}
// Extended Euclid: returns g = gcd(a, b) and writes coefficients with
// a*(*x) + b*(*y) == g.
int gcdExtended(int a, int b, int *x, int *y)
{
    if (a != 0) {
        // Solve for (b mod a, a), then back-substitute b mod a = b - (b/a)*a.
        int subX, subY;
        const int g = gcdExtended(b % a, a, &subX, &subY);
        *x = subY - (b / a) * subX;
        *y = subX;
        return g;
    }
    // Base case: gcd(0, b) = b = 0*0 + b*1.
    *x = 0;
    *y = 1;
    return b;
}
// Computes x^y mod p (y >= 0, p > 0) by binary exponentiation. The result is
// always in [0, p):
//  - the base is normalized first, so a negative x gives a non-negative result;
//  - the accumulator starts at 1 % p, so mpower(x, 0, 1) is 0, not 1.
// NOTE: intermediate products res*x and x*x are below p*p and must fit in int.
int mpower(int x, int y, int p)
{
    int res = 1 % p;
    x = ((x % p) + p) % p;
    while(y > 0){
        if(y & 1) res = (res*x) % p;
        y = y>>1;
        x = (x*x) % p;
    }
    return res;
}
// Computes x^y (y >= 0) without a modulus, by binary exponentiation.
// The caller must ensure the result fits in int.
int power(int x, int y)
{
    int res = 1;
    while (y > 0){
        if (y & 1) res = res*x;
        y = y>>1;
        // Square only while more bits remain: squaring after the last bit is
        // unused and can overflow (signed overflow is UB) when the result fits.
        if (y > 0) x = x*x;
    }
    return res;
}
// Deterministic primality test by trial division over 2, 3 and 6k +/- 1.
// Returns false for n <= 1.
// The loop bound i <= n / i is exact integer arithmetic: a floating sqrt can
// round below the true root for large 64-bit n and skip a factor, and i*i
// could overflow near the top of the range.
bool isPrime(int n)
{
    if (n <= 1) return false;
    if (n <= 3) return true;
    if (n % 2 == 0 || n % 3 == 0) return false;
    for (int i = 5; i <= n / i; i += 6) {
        if (n % i == 0 || n % (i + 2) == 0) return false;
    }
    return true;
}