In this HackerRank Coprime Paths problem, you are given an undirected, connected graph, G, with n nodes and m edges, where m = n - 1 (that is, G is a tree). Each node i is initially assigned a value, node_i, that has at most 3 prime divisors.
You must answer q queries of the form u v. For each query, consider the path between u and v whose length is minimal among all paths from u to v, and find and print the number of pairs of nodes (x, y) on that path such that gcd(node_x, node_y) = 1.
Problem solution in Java.
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Counts, for each query (u, v) on a tree, the pairs of nodes on the u-v path whose values are
 * coprime.
 *
 * <p>Approach, as implemented below:
 * <ol>
 *   <li>Every node value is reduced to the product of its distinct prime factors.</li>
 *   <li>The tree is flattened with an Euler tour that records each node twice (entry and exit),
 *       so a path becomes a range in which the path's nodes appear exactly once.</li>
 *   <li>Queries are ordered by sqrt-decomposition buckets and answered with a sliding window.</li>
 *   <li>When a node enters or leaves the window, the running pair count is adjusted by a
 *       Moebius-weighted sum over the divisors of its reduced value.</li>
 * </ol>
 */
public class F {
    /** Size bound used for the sieve tables and the per-divisor counter table. */
    private static final int SIEVE_LIMIT = 10000005;

    InputStream is;
    PrintWriter out;
    /** When non-empty, this text is used as input instead of {@code System.in}. */
    String INPUT = "";

    /** Current number of coprime pairs among the nodes that are active in the window. */
    long ret;
    /** freq[v] = how many times node v currently occurs inside the window (0, 1 or 2). */
    int[] freq;
    /** pfreq[d] = number of active nodes whose reduced value is divisible by d. */
    int[] pfreq;
    EulerTour et;
    int[] lpf = enumLowestPrimeFactors(SIEVE_LIMIT);
    int[] mob = enumMobiusByLPF(SIEVE_LIMIT, lpf);
    /** Node values, replaced in {@link #solve()} by the product of their distinct primes. */
    int[] a;

    /** Reads the whole input, answers every query and prints one answer per line. */
    void solve() {
        int n = ni(), Q = ni();
        a = na(n);

        // Keep each prime factor once: coprimality only depends on the set of primes.
        for (int i = 0; i < n; i++) {
            int pre = -1;
            int mul = 1;
            for (int j = a[i]; j > 1; j /= lpf[j]) {
                if (pre != lpf[j]) {
                    mul *= lpf[j];
                    pre = lpf[j];
                }
            }
            a[i] = mul;
        }

        int[] from = new int[n - 1];
        int[] to = new int[n - 1];
        for (int i = 0; i < n - 1; i++) {
            from[i] = ni() - 1;
            to[i] = ni() - 1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], dep = pars[2];
        et = nodalEulerTour(g, 0);
        int[][] spar = logstepParents(par);

        // Translate every query into a range over the Euler tour.
        int[][] qs = new int[Q][];
        int[] special = new int[Q];
        Arrays.fill(special, -1);
        for (int i = 0; i < Q; i++) {
            int x = ni() - 1, y = ni() - 1;
            int lca = lca2(x, y, spar, dep);
            if (lca == x) {
                qs[i] = new int[] {et.first[x], et.first[y]};
            } else if (lca == y) {
                qs[i] = new int[] {et.first[y], et.first[x]};
            } else if (et.first[x] < et.first[y]) {
                // The range last[x]..first[y] does not contain the LCA; it is added separately.
                qs[i] = new int[] {et.last[x], et.first[y]};
                special[i] = lca;
            } else {
                qs[i] = new int[] {et.last[y], et.first[x]};
                special[i] = lca;
            }
        }

        long[] pqs = sqrtSort(qs, 2 * n - 1);
        int L = 0, R = -1;
        freq = new int[n];
        long[] ans = new long[Q];
        pfreq = new int[SIEVE_LIMIT];
        for (long pa : pqs) {
            // The low 25 bits of a packed key hold the original query index.
            int ind = (int) (pa & ((1 << 25) - 1));
            int ql = qs[ind][0], qr = qs[ind][1];
            while (R < qr) {
                change(++R, 1);
            }
            while (L > ql) {
                change(--L, 1);
            }
            while (R > qr) {
                change(R--, -1);
            }
            while (L < ql) {
                change(L++, -1);
            }
            // Temporarily activate the LCA for this query only.
            if (special[ind] != -1) {
                change(et.first[special[ind]], 1);
            }
            ans[ind] = ret;
            if (special[ind] != -1) {
                change(et.first[special[ind]], -1);
            }
        }
        for (long v : ans) {
            out.println(v);
        }
    }

    /** Debug helper: prints {@code index:value} for every non-zero entry, then a newline. */
    public static void trnz(int... o) {
        for (int i = 0; i < o.length; i++) {
            if (o[i] != 0) {
                System.out.print(i + ":" + o[i] + " ");
            }
        }
        System.out.println();
    }

    /**
     * Builds the Moebius function for 1..n from a lowest-prime-factor table.
     *
     * @param n largest argument (must be at least 1)
     * @param lpf table with lpf[i] = lowest prime factor of i and lpf[1] = 0
     * @return array of length n + 1 with mob[i] in {-1, 0, 1}; index 0 is unused
     */
    public static int[] enumMobiusByLPF(int n, int[] lpf) {
        int[] mob = new int[n + 1];
        mob[1] = 1;
        for (int i = 2; i <= n; i++) {
            int j = i / lpf[i];
            // If the lowest prime divides i twice, mob[i] stays at its default 0.
            if (lpf[j] != lpf[i]) {
                mob[i] = -mob[j];
            }
        }
        return mob;
    }

    /**
     * Visits every divisor of a squarefree value and applies an insertion or removal to it.
     *
     * <p>Each recursion level takes the lowest prime of {@code n} and branches on keeping it in
     * {@code cur} or dividing it out; when {@code n} reaches 1, {@code cur} is one divisor.
     *
     * @param cur divisor built so far (starts as the full value)
     * @param n part of the value whose primes have not been decided yet
     * @param d +1 to insert the value into the active set, -1 to remove it
     */
    void dfs(int cur, int n, int d) {
        if (n == 1) {
            // On insert, pair with the nodes present before this one; on removal, with the
            // nodes remaining after it. Hence the different order around the counter update.
            if (d > 0) {
                ret += mob[cur] * pfreq[cur];
            }
            pfreq[cur] += d;
            if (d < 0) {
                ret -= mob[cur] * pfreq[cur];
            }
            return;
        }
        dfs(cur, n / lpf[n], d);
        dfs(cur / lpf[n], n / lpf[n], d);
    }

    /**
     * Moves one Euler-tour position into or out of the window.
     *
     * <p>A node is active exactly while it occurs once in the window, so its contribution is
     * withdrawn before the occurrence count changes and re-applied afterwards if the count is 1.
     *
     * @param x position in the Euler tour
     * @param d +1 when the position enters the window, -1 when it leaves
     */
    void change(int x, int d) {
        int ind = et.vs[x];
        if (freq[ind] == 1) {
            dfs(a[ind], a[ind], -1);
        }
        freq[ind] += d;
        if (freq[ind] == 1) {
            dfs(a[ind], a[ind], 1);
        }
    }

    /**
     * Orders range queries by bucket of the left end, alternating the direction of the right end
     * between consecutive buckets.
     *
     * @param qs queries as {left, right} pairs; right ends must be below 2^25
     * @param n length of the underlying sequence, used to choose the bucket width
     * @return sorted keys; the low 25 bits of each key are the query's index in {@code qs}
     */
    public static long[] sqrtSort(int[][] qs, int n) {
        int m = qs.length;
        long[] pack = new long[m];
        // Clamp so that a degenerate length cannot cause a division by zero.
        int S = Math.max(1, (int) Math.sqrt(n));
        for (int i = 0; i < m; i++) {
            long bucket = (long) qs[i][0] / S;
            long right = (bucket & 1) == 0 ? qs[i][1] : (1 << 25) - 1 - qs[i][1];
            pack[i] = bucket << 50 | right << 25 | i;
        }
        Arrays.sort(pack);
        return pack;
    }

    /**
     * Lowest common ancestor of two nodes using a binary-lifting table.
     *
     * @param a first node
     * @param b second node
     * @param spar table from {@link #logstepParents(int[])}
     * @param depth depth of every node
     * @return the lowest common ancestor of a and b
     */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }
        if (a == b) {
            return a;
        }
        int sa = a, sb = b;
        // Binary search, bit by bit, for the largest climb that keeps the two nodes distinct.
        for (int low = 0, high = depth[a], t = Integer.highestOneBit(high),
                k = Integer.numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | (t - 1);
                }
            }
        }
        return spar[0][sa];
    }

    /**
     * Climbs {@code m} levels from node {@code a}.
     *
     * @return the ancestor reached, or -1 if the climb leaves the tree
     */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1) {
                a = spar[i][a];
            }
        }
        return a;
    }

    /**
     * Builds the binary-lifting table: row j holds each node's 2^j-th ancestor, or -1.
     *
     * @param par direct parents, -1 for the root; used as row 0 without copying
     */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1 : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /** Euler tour in which every node appears twice: at {@code first[v]} and {@code last[v]}. */
    public static class EulerTour {
        public int[] vs;
        public int[] first;
        public int[] last;

        public EulerTour(int[] vs, int[] f, int[] l) {
            this.vs = vs;
            this.first = f;
            this.last = l;
        }
    }

    /**
     * Computes the entry/exit Euler tour of a tree iteratively, without recursion.
     *
     * @param g adjacency lists
     * @param root node to start from
     * @return tour of length 2 * g.length with the entry and exit position of each node
     */
    public static EulerTour nodalEulerTour(int[][] g, int root) {
        int n = g.length;
        int[] vs = new int[2 * n];
        int[] f = new int[n];
        int[] l = new int[n];
        int p = 0;
        Arrays.fill(f, -1);

        int[] stack = new int[n];
        int[] inds = new int[n];
        int sp = 0;
        stack[sp++] = root;
        outer:
        while (sp > 0) {
            int cur = stack[sp - 1], ind = inds[sp - 1];
            if (ind == 0) {
                // First time on top of the stack: record the entry position.
                vs[p] = cur;
                f[cur] = p;
                p++;
            }
            while (ind < g[cur].length) {
                int nex = g[cur][ind++];
                if (f[nex] == -1) {
                    inds[sp - 1] = ind;
                    stack[sp] = nex;
                    inds[sp] = 0;
                    sp++;
                    continue outer;
                }
            }
            inds[sp - 1] = ind;
            if (ind == g[cur].length) {
                // All neighbours handled: record the exit position and pop.
                vs[p] = cur;
                l[cur] = p;
                p++;
                sp--;
            }
        }
        return new EulerTour(vs, f, l);
    }

    /**
     * Breadth-first traversal of a tree from {@code root}.
     *
     * @return {parent array (-1 for the root), visiting order, depth array}
     */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] {par, q, depth};
    }

    /** Packs an undirected edge list into exactly-sized adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from) {
            p[f]++;
        }
        for (int t : to) {
            p[t]++;
        }
        for (int i = 0; i < n; i++) {
            g[i] = new int[p[i]];
        }
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /**
     * Linear sieve producing the lowest prime factor of every integer in 2..n.
     *
     * @return array of length n + 1; entries 0 and 1 are left at 0
     */
    public static int[] enumLowestPrimeFactors(int n) {
        int tot = 0;
        int[] lpf = new int[n + 1];
        int u = n + 32;
        double lu = Math.log(u);
        int[] primes = new int[(int) (u / lu + u / lu / lu * 1.5)];
        for (int i = 2; i <= n; i++) {
            lpf[i] = i;
        }
        for (int p = 2; p <= n; p++) {
            if (lpf[p] == p) {
                primes[tot++] = p;
            }
            for (int i = 0; i < tot && primes[i] <= lpf[p]; i++) {
                // Multiply in long so the bound check cannot be fooled by int overflow.
                long multiple = (long) primes[i] * p;
                if (multiple > n) {
                    break;
                }
                lpf[(int) multiple] = primes[i];
            }
        }
        return lpf;
    }

    /** Wires up input and output, runs {@link #solve()} and flushes the result. */
    void run() throws Exception {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if (!INPUT.isEmpty()) {
            tr(System.currentTimeMillis() - s + "ms");
        }
    }

    public static void main(String[] args) throws Exception {
        new F().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, or -1 at end of input. */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                InputMismatchException wrapped = new InputMismatchException(e.getMessage());
                wrapped.initCause(e);
                throw wrapped;
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++];
    }

    /** Treats everything outside the printable ASCII range 33..126 as a separator. */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips separators and returns the first non-separator byte, or -1 at end of input. */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b)) {
            // keep skipping
        }
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    /** Reads the next separator-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!isSpaceChar(b)) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads at most {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !isSpaceChar(b)) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    /** Reads {@code n} tokens of at most {@code m} characters each. */
    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} integers. */
    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal integer. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip everything that cannot start a number
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal integer as a long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {
            // skip everything that cannot start a number
        }
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) {
        System.out.println(Arrays.deepToString(o));
    }
}
Problem solution in C++.
#include <bits/stdc++.h>
using namespace std;

// Coprime Paths: for every query (u, v) on a tree, count the pairs of nodes on
// the u-v path whose values are coprime. The tree is flattened into an
// entry/exit Euler tour, queries are processed with a sqrt-ordered sliding
// window, and coprime counts are maintained by inclusion-exclusion over the
// (at most three) distinct primes of each value.

const int N = 25123, LN = 15;
const int MaxVal = 10000000;
using ll = long long;

vector<int> adj[N];
int depth[N], parent[LN][N];
int ST[N], EN[N], cur_time = 0;  // entry / exit positions in the Euler tour
int vec[2 * N];                  // vec[t] = node at tour position t

// Records depth, direct parent and both tour positions of every node.
void dfs(int u = 0, int d = 0, int prev = -1) {
    depth[u] = d;
    parent[0][u] = prev;
    ST[u] = cur_time++;
    vec[ST[u]] = u;
    for (int v : adj[u]) {
        if (v == prev) continue;
        dfs(v, d + 1, u);
    }
    EN[u] = cur_time++;
    vec[EN[u]] = u;
}

// Lowest common ancestor by binary lifting over parent[][].
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    int diff = depth[u] - depth[v];
    for (int i = 0; i < LN; i++) {
        if ((diff >> i) & 1) {
            u = parent[i][u];
        }
    }
    if (u == v) return u;
    for (int i = LN - 1; i >= 0; i--) {
        if (parent[i][u] != parent[i][v]) {
            u = parent[i][u];
            v = parent[i][v];
        }
    }
    return parent[0][u];
}

vector<int> primes[N];           // distinct primes of each node value
vector<pair<int, int>> upd[N];   // (subset size - 1, compressed subset id) per node
int pr[MaxVal + 1], S;           // pr[x] = smallest prime factor of composite x, 0 otherwise
ll ans[N];
bool used[N];                    // node currently counted in the window
int vp[3][4 * N];                // vp[k][id] = active nodes divisible by that prime subset

int main() {
    memset(parent, -1, sizeof parent);

    // Sieve: store the smallest prime factor of every composite number.
    for (int i = 2; i <= MaxVal; i++) {
        if (!pr[i]) {
            for (int j = i + i; j <= MaxVal; j += i) {
                if (!pr[j]) pr[j] = i;
            }
        }
    }

    int n, q;
    scanf("%d %d", &n, &q);
    // Bucket width for the query ordering; clamped so it can never be zero.
    S = max(1, (int)sqrt(2.0 * n));

    // Coordinate-compress every single prime, prime pair and prime triple.
    map<int, int> p1;
    map<tuple<int, int>, int> p2;
    map<tuple<int, int, int>, int> p3;
    for (int i = 0; i < n; i++) {
        int id, x;
        scanf("%d", &x);
        auto& v = primes[i];
        while (pr[x]) {
            v.push_back(pr[x]);
            x /= pr[x];
        }
        if (x > 1) {
            v.push_back(x);
        }
        assert(is_sorted(begin(v), end(v)));
        v.resize(unique(begin(v), end(v)) - begin(v));

        const int cnt = (int)v.size();
        for (int k = 0; k < cnt; k++) {
            id = p1.size();
            if (p1.count(v[k])) id = p1[v[k]];
            else p1[v[k]] = id;
            upd[i].emplace_back(0, id);
            for (int j = k + 1; j < cnt; j++) {
                auto tmp = make_tuple(v[k], v[j]);
                id = p2.size();
                if (p2.count(tmp)) id = p2[tmp];
                else p2[tmp] = id;
                upd[i].emplace_back(1, id);
            }
        }
        if (cnt == 3) {
            auto tmp = make_tuple(v[0], v[1], v[2]);
            id = p3.size();
            if (p3.count(tmp)) id = p3[tmp];
            else p3[tmp] = id;
            upd[i].emplace_back(2, id);
        }
    }

    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs();
    for (int i = 1; i < LN; i++) {
        for (int j = 0; j < n; j++) {
            if (parent[i - 1][j] != -1) {
                parent[i][j] = parent[i - 1][parent[i - 1][j]];
            }
        }
    }

    // Each query becomes (left, right, index, extra node); the extra node is
    // the LCA when it is not covered by the tour range, otherwise -1.
    using qt = tuple<int, int, int, int>;
    vector<qt> queries;
    for (int i = 0; i < q; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--, v--;
        if (ST[u] > ST[v]) swap(u, v);
        int p = lca(u, v);
        if (p == u) queries.emplace_back(ST[u], ST[v], i, -1);
        else queries.emplace_back(EN[u], ST[v], i, p);
    }
    sort(begin(queries), end(queries), [](const qt& a, const qt& b) -> bool {
        int la = get<0>(a) / S, lb = get<0>(b) / S;
        if (la != lb) return la < lb;
        return get<1>(a) > get<1>(b);
    });

    int active = 0;  // number of nodes currently counted
    ll cur = 0;      // coprime pairs among the counted nodes

    // Number of counted nodes coprime with node u (u itself must not be counted):
    // all counted nodes, minus those sharing one prime, plus those sharing two,
    // minus those sharing three.
    auto coprime_with = [&](int u) {
        int total = active;
        for (auto& p : upd[u]) {
            if (p.first & 1) total += vp[p.first][p.second];
            else total -= vp[p.first][p.second];
        }
        return total;
    };

    // Flips a node between counted and not counted. A node seen twice in the
    // window (entry and exit) is off the path, so the second visit undoes the first.
    auto toggle = [&](int u) {
        if (used[u]) {
            used[u] = false;
            active--;
            for (auto& p : upd[u]) vp[p.first][p.second]--;
            cur -= coprime_with(u);
        } else {
            cur += coprime_with(u);
            used[u] = true;
            active++;
            for (auto& p : upd[u]) vp[p.first][p.second]++;
        }
    };

    int L = 0, R = -1;
    for (auto& t : queries) {
        int l, r, i, p;
        tie(l, r, i, p) = t;
        while (R < r) toggle(vec[++R]);
        while (R > r) toggle(vec[R--]);
        while (L < l) toggle(vec[L++]);
        while (L > l) toggle(vec[--L]);
        ans[i] = cur;
        if (p != -1) {
            // Add the pairs formed by the LCA without inserting it into the window.
            ans[i] += coprime_with(p);
        }
    }

    for (int i = 0; i < q; i++) {
        printf("%lld\n", ans[i]);
    }
    return 0;
}