In the HackerEarth problem "Joseph and Tree", Joseph loves games about trees, and his friend Nick has invented one for him!
Initially, there is a rooted, weighted tree with N vertices numbered 1 … N. Nick guarantees that the tree is connected and that there is a unique path between any two vertices. He also gives Q queries of the following type:
- v and k: Let S be the array, sorted in nondecreasing order, of the shortest distances from v to every other vertex in the subtree rooted at v. The answer is the k-th element of S. If no such element exists, i.e. S has fewer than k elements, the answer is -1. Note that v itself is not counted as part of its own subtree, so S never contains the distance 0 from v to itself.
All indices in the queries are 1-based. The root of the tree is node 1.
But it turns out that Joseph has an exam tomorrow and has no time to play the game, so he asks for your help!
HackerEarth Joseph and Tree problem solution — first solution (C++):
#include <bits/stdc++.h>
// Short-hand macros used throughout this solution.
// NOTE(review): the object-like macros `f` and `s` rewrite EVERY later token spelled
// `f` or `s` (not only `.f` / `.s` member accesses), so never use them as identifiers.
#define pb push_back
#define f first
#define s second
#define mp make_pair
#define sz(a) int((a).size())
// printf format for a long long, chosen per platform: "%I64d" under _WIN32, "%lld" elsewhere.
#ifdef _WIN32
# define I64 "%I64d"
#else
# define I64 "%lld"
#endif
// File-name prefix used by the local-only (#ifdef Nick) freopen redirection in main.
#define fname "."
#define pi pair < int, int >
#define pp pop_back
typedef long long ll;
// ld and ull are not used in the visible code.
typedef long double ld;
typedef unsigned long long ull;
// Capacity of every per-vertex / per-position array; vertices are indexed 1..n.
const int MAX_N = (int)1e5 + 123;
// eps and inf are not used in the visible code.
const double eps = 1e-6;
const int inf = (int)1e9 + 123;
using namespace std;
// n: number of vertices (read in main). q: declared but unused; main reads the query
// count into a local variable instead.
int n, q;
// Adjacency lists: g[u] holds (neighbour, edge weight) pairs; edges are stored both ways.
vector < pi > g[MAX_N];
// One node of the persistent segment tree built over compressed depth ranks.
// Nodes refer to their children by index into the pool `t`; -1 means "no child",
// and an absent child is treated as an empty subtree (count 0).
struct tree {
    int sum = 0;  // number of inserted values whose rank falls in this node's range
    int l = -1;   // pool index of the left child, or -1
    int r = -1;   // pool index of the right child, or -1
    tree() {}
};
// Pool that owns the nodes of every version; versions share unchanged subtrees by index.
vector < tree > t;
// Persistent insert. Builds a new version equal to version `v` (-1 = empty tree) plus one
// more occurrence of rank `x`, and returns the pool index of its root. Only the nodes on
// the path to the leaf of `x` are copied; every other subtree is shared with `v`.
// [tl, tr] is the rank range covered by the node being created.
int update(int x, int v, int tl = 0, int tr = n - 1) {
    const int fresh = int(t.size());
    // The conditional expression yields a copy, taken before push_back runs.
    t.push_back(v == -1 ? tree() : t[v]);
    if (tl == tr) {
        ++t[fresh].sum;
        return fresh;
    }
    const int mid = (tl + tr) / 2;
    const bool intoLeft = (x <= mid);
    const int oldChild = (v == -1) ? -1 : (intoLeft ? t[v].l : t[v].r);
    // Recurse before writing through t[fresh]: the recursive call grows `t`, which may
    // reallocate it, so only indices (never references) are kept across the call.
    const int newChild = intoLeft ? update(x, oldChild, tl, mid)
                                  : update(x, oldChild, mid + 1, tr);
    if (intoLeft)
        t[fresh].l = newChild;
    else
        t[fresh].r = newChild;
    const int leftCount = (t[fresh].l == -1) ? 0 : t[t[fresh].l].sum;
    const int rightCount = (t[fresh].r == -1) ? 0 : t[t[fresh].r].sum;
    t[fresh].sum = leftCount + rightCount;
    return fresh;
}
int find_kth(int L, int R, int k, int tl = 0, int tr = n - 1) {
if (tl == tr)
return tl;
int tm = (tl + tr) / 2;
int left = 0;
if (R != -1 && t[R].l != -1)
left += t[t[R].l].sum;
if (L != -1 && t[L].l != -1)
left -= t[t[L].l].sum;
if (k <= left)
return find_kth((L == -1 ? -1 : t[L].l), (R == -1 ? -1 : t[R].l), k, tl, tm);
k -= left;
return find_kth((L == -1 ? -1 : t[L].r), (R == -1 ? -1 : t[R].r), k, tm + 1, tr);
}
vector < int > st;
ll dist[MAX_N];
int l[MAX_N], r[MAX_N];
void dfs(int v, int pr = -1, ll all = 0) {
dist[v] = all;
l[v] = sz(st);
st.pb(v);
for (auto to : g[v])
if (to.f != pr)
dfs(to.f, v, all + to.s);
r[v] = sz(st) - 1;
}
// root[i]: persistent-tree version holding the compressed depths of st[0..i].
int root[MAX_N];
// Sorted distinct depths; uniq[rank] turns a compressed rank back into a depth.
vector < ll > uniq;
// Answers query (v, k): the k-th smallest distance from v to a vertex of its subtree other
// than v itself, or -1 if there are fewer than k such vertices. Euler positions
// l[v]+1 .. r[v] are exactly v's proper descendants, hence the version pair
// (root[l[v]], root[r[v]]). Expects dist[] to already hold compressed ranks.
ll get(int v, int k) {
    const int descendants = r[v] - l[v];
    return (k > descendants) ? -1LL
                             : uniq[find_kth(root[l[v]], root[r[v]], k)] - uniq[dist[v]];
}
// Reads the tree (n, then n-1 edges "u v w") and the queries ("v k") from stdin and
// prints one answer per line.
int main() {
#ifdef Nick
    freopen(fname "in", "r", stdin);
    freopen(fname "out", "w", stdout);
#endif
    scanf("%d", &n);
    for (int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        g[u].pb(mp(v, w)), g[v].pb(mp(u, w));
    }
    // Euler tour from the root (vertex 1): fills dist (depths), st, l and r.
    dfs(1);
    // Coordinate-compress depths: afterwards dist[i] is the index of vertex i's depth in uniq.
    for (int i = 1; i <= n; i++)
        uniq.pb(dist[i]);
    sort(uniq.begin(), uniq.end());
    uniq.resize(unique(uniq.begin(), uniq.end()) - uniq.begin());
    for (int i = 1; i <= n; i++)
        dist[i] = lower_bound(uniq.begin(), uniq.end(), dist[i]) - uniq.begin();
    // root[i] = previous version plus the rank of the i-th vertex in Euler order.
    for (int i = 0, last = -1; i < sz(st); i++) {
        last = root[i] = update(dist[st[i]], last);
    }
    int query;
    scanf("%d", &query);
    for (int i = 1, v, k; i <= query; i++) {
        scanf("%d%d", &v, &k);
        // The format must end in a real newline; it previously ended in a literal "n",
        // which glued all answers together on one line.
        printf(I64 "\n", get(v, k));
    }
    return 0;
}
Second solution (Java):
import java.io.*;
import java.util.*;
/**
 * HackerEarth "Joseph and Tree": for each query (v, k) prints the k-th smallest distance from v
 * to the other vertices of v's subtree (v itself excluded), or -1 if there are fewer than k.
 *
 * <p>Vertices are numbered by a pre-order DFS so each subtree is a contiguous range of entry
 * times. Root depths are coordinate-compressed and inserted in that order into a persistent
 * segment tree; the difference of versions {@code root[timeOut[v]]} and {@code root[timeIn[v]]}
 * then holds exactly the depths of v's proper descendants, and subtracting {@code depth[v]}
 * turns the selected depth into a distance from v.
 */
public class AugClashJosephAndTree {
    /** Number of vertices. */
    static int N;
    /** adj[v].get(i) is the i-th neighbour of v; weight[v].get(i) is that edge's weight. */
    static ArrayList<Integer> adj[], weight[];
    /** depth[v]: weighted distance from vertex 1 to v. */
    static long depth[];
    /** dfsOrder[t]: the vertex whose entry time is t (entry times run 1..N). */
    static int dfsOrder[];
    /** timeIn[v]: entry time of v; timeOut[v]: largest entry time inside v's subtree. */
    static int timeIn[], timeOut[];
    /** Entry-time counter advanced by dfs. */
    static int time;
    /** Each distinct depth mapped to its 1-based rank in increasing order. */
    static TreeMap<Long, Integer> dist;
    /** revMap[rank]: the depth that has this rank. */
    static long revMap[];
    /** Node pool of the persistent segment tree; index 0 is an empty version. */
    static ArrayList<Node> nodes;
    /** root[i]: version holding the depth ranks of dfsOrder[1..i]; root[0] is empty. */
    static int root[];

    /** Reads the tree and the queries from standard input and prints one answer per line. */
    @SuppressWarnings("unchecked") // generic array creation for adj / weight
    public static void main(String[] args) {
        InputReader in = new InputReader(System.in);
        PrintWriter out = new PrintWriter(System.out);
        N = in.nextInt();
        check(1, N, (int) 1e5);
        adj = new ArrayList[N + 1];
        weight = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) {
            adj[i] = new ArrayList<Integer>();
            weight[i] = new ArrayList<Integer>();
        }
        for (int i = 1; i < N; i++) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            check(1, a, N);
            check(1, b, N);
            check(1, w, (int) 1e9);
            adj[a].add(b);
            weight[a].add(w);
            adj[b].add(a);
            weight[b].add(w);
        }
        depth = new long[N + 1];
        timeIn = new int[N + 1];
        timeOut = new int[N + 1];
        dfsOrder = new int[N + 1];
        time = 0;
        dfs(1, -1, 0);

        // Coordinate-compress the depths: rank 1 is the smallest distinct depth.
        dist = new TreeMap<Long, Integer>();
        for (int i = 1; i <= N; i++)
            dist.put(depth[i], 0);
        revMap = new long[dist.size() + 1];
        int cnt = 0;
        for (Map.Entry<Long, Integer> entry : dist.entrySet()) {
            entry.setValue(++cnt);
            revMap[cnt] = entry.getKey();
        }

        // Version i = version i-1 plus the rank of the i-th vertex in DFS order.
        root = new int[N + 1];
        root[0] = 0;
        nodes = new ArrayList<Node>();
        nodes.add(new Node());
        for (int i = 1; i <= N; i++) {
            root[i] = update(root[i - 1], 1, cnt, dist.get(depth[dfsOrder[i]]));
        }
        int Q = in.nextInt();
        check(1, Q, (int) 1e5);
        while (Q-- > 0) {
            int v = in.nextInt();
            int k = in.nextInt();
            check(1, v, N);
            check(1, k, (int) 1e9);
            out.println(solve(v, k, cnt));
        }
        out.close();
    }

    /**
     * Answers one query.
     *
     * @param v query vertex
     * @param k 1-based position in the sorted distance list
     * @param cnt number of distinct depths (the rank range of the segment tree)
     * @return the k-th smallest distance from v to a proper descendant, or -1 if v has fewer
     *     than k proper descendants
     */
    static long solve(int v, int k, int cnt) {
        int size = timeOut[v] - timeIn[v];
        if (size < k)
            return -1;
        // Entry times timeIn[v]+1 .. timeOut[v] are exactly v's proper descendants.
        int ans = findKth(1, cnt, root[timeIn[v]], root[timeOut[v]], k);
        return revMap[ans] - depth[v];
    }

    /**
     * Finds the k-th smallest rank among values present in version {@code rightRoot} but not in
     * version {@code leftRoot}; [start, end] is the rank range of the current node.
     * Requires 1 <= k <= (count in rightRoot) - (count in leftRoot).
     */
    static int findKth(int start, int end, int leftRoot, int rightRoot, int k) {
        if (start == end)
            return start;
        int mid = (start + end) >> 1;
        int leftSum = sum(left(rightRoot)) - sum(left(leftRoot)); // values with rank in [start, mid]
        if (leftSum >= k)
            return findKth(start, mid, left(leftRoot), left(rightRoot), k);
        else
            return findKth(mid + 1, end, right(leftRoot), right(rightRoot), k - leftSum);
    }

    /**
     * Persistent insert: returns the index of a new version equal to {@code prevRoot} plus one
     * occurrence of rank {@code x}. Only nodes on the path to x are copied; the rest are shared.
     */
    static int update(int prevRoot, int start, int end, int x) {
        Node now = new Node();
        now.sum = sum(prevRoot);
        now.left = left(prevRoot);
        now.right = right(prevRoot);
        nodes.add(now);
        int idx = nodes.size() - 1;
        if (start == end) {
            now.sum++;
            return idx;
        }
        int mid = (start + end) >> 1;
        if (x <= mid) {
            now.left = update(left(prevRoot), start, mid, x);
        } else {
            now.right = update(right(prevRoot), mid + 1, end, x);
        }
        now.sum = sum(now.left) + sum(now.right);
        return idx;
    }

    /**
     * Pre-order DFS from {@code curr}, filling depth, timeIn, timeOut and dfsOrder.
     *
     * <p>Children are visited in adjacency-list order, exactly as a recursive DFS would, but an
     * explicit stack is used so that a path-shaped tree with up to 1e5 vertices cannot throw
     * {@link StackOverflowError}.
     *
     * @param curr start vertex
     * @param parent vertex {@code curr} was reached from, or -1 for the root
     * @param pathLength weighted distance from the root to {@code curr}
     */
    static void dfs(int curr, int parent, long pathLength) {
        int n = adj.length;
        int[] stack = new int[n];    // each vertex is pushed at most once
        int[] parentOf = new int[n];
        int[] nextEdge = new int[n]; // next adjacency index to examine per vertex
        int top = 0;
        enter(curr, pathLength);
        parentOf[curr] = parent;
        stack[top++] = curr;
        while (top > 0) {
            int v = stack[top - 1];
            if (nextEdge[v] < adj[v].size()) {
                int i = nextEdge[v]++;
                int child = adj[v].get(i);
                if (child != parentOf[v]) {
                    enter(child, depth[v] + weight[v].get(i));
                    parentOf[child] = v;
                    stack[top++] = child;
                }
            } else {
                // All of v's subtree has been numbered.
                timeOut[v] = time;
                top--;
            }
        }
    }

    /** Records the depth and pre-order entry time of a newly reached vertex. */
    private static void enter(int v, long pathLength) {
        depth[v] = pathLength;
        timeIn[v] = ++time;
        dfsOrder[time] = v;
    }

    /** Persistent segment tree node; children are indices into {@code nodes}, -1 if absent. */
    static class Node {
        int sum, left, right;

        Node() {
            sum = 0;
            left = -1;
            right = -1;
        }
    }

    /** Count stored under node {@code idx}; an absent node (-1) counts as 0. */
    static int sum(int idx) {
        return idx == -1 ? 0 : nodes.get(idx).sum;
    }

    /** Left child of node {@code idx}, or -1 if the node or child is absent. */
    static int left(int idx) {
        return idx == -1 ? -1 : nodes.get(idx).left;
    }

    /** Right child of node {@code idx}, or -1 if the node or child is absent. */
    static int right(int idx) {
        return idx == -1 ? -1 : nodes.get(idx).right;
    }

    /** Throws if {@code key} lies outside the inclusive range [start, end]. */
    static void check(int start, int key, int end) {
        if (key < start || key > end)
            throw new RuntimeException("value " + key + " outside [" + start + ", " + end + "]");
    }

    /** Minimal buffered reader of whitespace-separated integers. */
    static class InputReader {
        private final InputStream stream;
        private final byte[] buf = new byte[8192];
        private int curChar, snumChars;
        private SpaceCharFilter filter;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /** Returns the next input byte, or -1 when no more bytes could be read. */
        public int snext() {
            if (snumChars == -1)
                throw new InputMismatchException();
            if (curChar >= snumChars) {
                curChar = 0;
                try {
                    snumChars = stream.read(buf);
                } catch (IOException e) {
                    throw new InputMismatchException();
                }
                if (snumChars <= 0)
                    return -1;
            }
            return buf[curChar++];
        }

        /** Skips whitespace and parses an optionally negative decimal integer. */
        public int nextInt() {
            int c = snext();
            while (isSpaceChar(c)) {
                c = snext();
            }
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = snext();
            }
            int res = 0;
            do {
                if (c < '0' || c > '9')
                    throw new InputMismatchException();
                res *= 10;
                res += c - '0';
                c = snext();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        /** True for space, newline, carriage return, tab, or end of input (-1). */
        public boolean isSpaceChar(int c) {
            if (filter != null)
                return filter.isSpaceChar(c);
            // These must be escape sequences: comparing with the bare letters 'n', 'r', 't'
            // made a newline after a number fail parsing with InputMismatchException.
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }

        public interface SpaceCharFilter {
            public boolean isSpaceChar(int ch);
        }
    }
}