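In the HackerRank Self-Driving Bus problem, you are given a tree whose vertices are labeled 1 to n, and you must count the segments [l, r] for which the vertices l, l+1, …, r form a connected subtree.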
Problem solution in Java.
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Random;

/**
 * HackerRank "Self-Driving Bus": counts the label segments [l, r] whose vertices form a connected
 * subtree of the input tree.
 *
 * <p>Input: n, then n-1 lines "a b" with 1-based vertex labels. Output: the count, one line.
 */
public class G {
  InputStream is;
  PrintWriter out;
  String INPUT = "";

  /** Reads the tree, counts the connected label segments and prints the total. */
  void solve() {
    int n = ni();
    int[] from = new int[n - 1];
    int[] to = new int[n - 1];
    for (int i = 0; i < n - 1; i++) {
      from[i] = ni() - 1;
      to[i] = ni() - 1;
    }
    int[][] g = packU(n, from, to);
    int[][] pars = parents3(g, 0);
    int[] par = pars[0], ord = pars[1];

    // For every vertex cur, find the maximal run of consecutive labels [left[cur], right[cur]]
    // around cur that lies entirely inside cur's subtree. Label sets (treaps) are merged
    // small-to-large, visiting vertices in reverse BFS order so children finish first.
    Node[] nodes = new Node[n];
    int[] left = new int[n];
    int[] right = new int[n];
    for (int i = n - 1; i >= 0; i--) {
      int cur = ord[i];
      // Number of labels already in the set (children's subtrees) that are smaller than cur.
      int curind = -search(nodes[cur], cur) - 1;
      assert curind >= 0;
      {
        // Smallest position h such that positions h..curind-1 hold exactly cur-(curind-h)..cur-1.
        int low = -1, high = curind;
        while (high - low > 1) {
          int h = high + low >>> 1;
          if (cur - get(nodes[cur], h).v == curind - h) {
            high = h;
          } else {
            low = h;
          }
        }
        left[cur] = high + cur - curind;
      }
      {
        // Largest position h such that positions curind..h hold exactly cur+1..cur+(h-curind+1).
        int low = curind - 1, high = count(nodes[cur]);
        while (high - low > 1) {
          int h = high + low >>> 1;
          if (get(nodes[cur], h).v - cur == h - (curind - 1)) {
            low = h;
          } else {
            high = h;
          }
        }
        right[cur] = low + cur - curind + 1;
      }
      nodes[cur] = insertb(nodes[cur], new Node(cur));
      if (par[cur] != -1) {
        // Keep the larger set at the parent, then drain the smaller one into it.
        if (count(nodes[cur]) > count(nodes[par[cur]])) {
          Node d = nodes[cur];
          nodes[cur] = nodes[par[cur]];
          nodes[par[cur]] = d;
        }
        while (nodes[cur] != null) {
          Node first = get(nodes[cur], 0);
          nodes[cur] = erase(nodes[cur], 0);
          nodes[par[cur]] = insertb(nodes[par[cur]], first);
        }
      }
    }

    // Triples whose left side (i - left) is the shorter one go to rs; the rest go to rs2 and are
    // processed after mirroring labels, so go() always enumerates the shorter side explicitly.
    int[][] rs = new int[n][];
    int[][] rs2 = new int[n][];
    int q = 0, q2 = 0;
    for (int i = 0; i < n; i++) {
      if (right[i] - i >= i - left[i]) {
        rs[q++] = new int[] {left[i], i, right[i]};
      } else {
        rs2[q2++] = new int[] {right[i], i, left[i]};
      }
    }
    long ret = 0;
    ret += go(Arrays.copyOf(rs, q), n, par);

    // Mirror every label v -> n-1-v (the root's parent -1 becomes the sentinel n).
    for (int i = 0; i < q2; i++) {
      rs2[i][0] = n - 1 - rs2[i][0];
      rs2[i][1] = n - 1 - rs2[i][1];
      rs2[i][2] = n - 1 - rs2[i][2];
    }
    for (int i = 0; i < n; i++) {
      par[i] = n - 1 - par[i];
    }
    for (int i = 0, j = n - 1; i < j; i++, j--) {
      int d = par[i];
      par[i] = par[j];
      par[j] = d;
    }
    // Restore increasing center order, which go() relies on.
    for (int i = 0, j = q2 - 1; i < j; i++, j--) {
      int[] d = rs2[i];
      rs2[i] = rs2[j];
      rs2[j] = d;
    }
    ret += go(Arrays.copyOf(rs2, q2), n, par);
    out.println(ret);
  }

  /**
   * Sums, over triples {li, i, ri} (built in solve() as {left, center, right}), a count of
   * segments [l, r] with li <= l <= i and r <= ri.
   *
   * <p>Triples are visited with decreasing centers (rs must be sorted by increasing center). The
   * left end l = j is enumerated explicitly. Right ends are counted with a persistent monotonic
   * stack over labels > i (stack/has) combined with a per-query stack over j..i (lstack/lhas).
   *
   * @param rs triples {li, i, ri}
   * @param n number of vertices
   * @param par parent label of each vertex (the root uses a sentinel outside [0, n))
   * @return the number of counted segments
   */
  long go(int[][] rs, int n, int[] par) {
    int m = rs.length;
    SegmentTreeRMQ stmin = new SegmentTreeRMQ(par);
    int[] stack = new int[n]; // vertex indices, descending
    long[] has = new long[n + 1]; // prefix sums of per-entry contributions of `stack`
    long[] lhas = new long[n + 1]; // same for `lstack`
    int sp = 0;
    int pre = n - 1; // labels > pre have been pushed onto `stack`
    int[] lstack = new int[n]; // vertex indices, descending
    int[] lvals = new int[n]; // parent value associated with each lstack entry
    Arrays.fill(stack, -1);
    Arrays.fill(lstack, -1);
    long ret = 0;
    for (int z = m - 1; z >= 0; z--) {
      int i = rs[z][1];
      int li = rs[z][0];
      int ri = rs[z][2];
      // Extend the persistent stack with every label in (i, pre].
      while (pre > i) {
        while (sp > 0 && par[pre] >= par[stack[sp - 1]]) sp--;
        int ll = Math.max(pre, par[pre]);
        int rr = sp >= 1 ? stack[sp - 1] : n;
        has[sp + 1] = Math.max(0, rr - ll) + has[sp];
        stack[sp++] = pre;
        pre--;
      }
      int lsp = 0;
      // Working copy of the persistent stack pointer. The left sweep may pop entries of `stack`
      // for this query only, without disturbing the stack that later queries extend.
      int tsp = sp;
      int lmin = i;
      for (int j = i; j >= li; j--) {
        int pj = j == i ? j : par[j];
        while (lsp > 0 && pj >= lvals[lsp - 1]) lsp--;
        if (lsp == 0) {
          while (tsp > 0 && pj >= par[stack[tsp - 1]]) tsp--;
          lvals[lsp] = pj;
          int ll = Math.max(pre, pj);
          int rr = tsp >= 1 ? stack[tsp - 1] : n;
          lhas[lsp + 1] = Math.max(0, rr - ll) + lhas[lsp];
          lstack[lsp++] = j;
        } else {
          int ll = Math.max(pre, pj);
          int rr = lsp >= 1 ? lstack[lsp - 1] : tsp >= 1 ? stack[tsp - 1] : n;
          lvals[lsp] = pj;
          lhas[lsp + 1] = Math.max(0, rr - ll) + lhas[lsp];
          lstack[lsp++] = j;
        }
        lmin = Math.min(lmin, pj);
        if (lmin >= j) {
          // First label right of i whose parent is < j bounds how far right the segment may go.
          int fl = stmin.firstle(i + 1, j - 1);
          if (fl == -1) {
            fl = ri + 1;
          }
          int lright = Math.min(ri, fl - 1);
          if (tsp - 1 >= 0 && lright >= stack[tsp - 1]) {
            int ub = upperBoundR(stack, 0, tsp, lright);
            int ll = Math.max(stack[ub], par[stack[ub]]);
            int rr = lright + 1;
            long valid = lhas[lsp] + has[tsp] - has[ub + 1] + Math.max(0, rr - ll);
            ret += valid;
          } else {
            int ub = upperBoundR(lstack, 0, lsp, lright);
            int ll = Math.max(lstack[ub], lvals[ub]);
            int rr = lright + 1;
            long valid = lhas[lsp] - lhas[ub + 1] + Math.max(0, rr - ll);
            ret += valid;
          }
        }
      }
    }
    return ret;
  }

  /**
   * Returns the first index h in [l, r) with a[h] <= v, or r if there is none. Assumes a[l..r) is
   * non-increasing.
   */
  public static int upperBoundR(int[] a, int l, int r, int v) {
    int low = l - 1, high = r;
    while (high - low > 1) {
      int h = high + low >>> 1;
      if (a[h] <= v) {
        high = h;
      } else {
        low = h;
      }
    }
    return high;
  }

  /** Iterative segment tree over ints answering range-minimum and "first/last <= v" queries. */
  public static class SegmentTreeRMQ {
    public int M, H, N;
    public int[] st;

    public SegmentTreeRMQ(int n) {
      N = n;
      M = Integer.highestOneBit(Math.max(N - 1, 1)) << 2;
      H = M >>> 1;
      st = new int[M];
      Arrays.fill(st, 0, M, Integer.MAX_VALUE);
    }

    public SegmentTreeRMQ(int[] a) {
      N = a.length;
      M = Integer.highestOneBit(Math.max(N - 1, 1)) << 2;
      H = M >>> 1;
      st = new int[M];
      for (int i = 0; i < N; i++) {
        st[H + i] = a[i];
      }
      Arrays.fill(st, H + N, M, Integer.MAX_VALUE);
      for (int i = H - 1; i >= 1; i--) {
        propagate(i);
      }
    }

    public void update(int pos, int x) {
      st[H + pos] = x;
      for (int i = (H + pos) >>> 1; i >= 1; i >>>= 1) {
        propagate(i);
      }
    }

    private void propagate(int i) {
      st[i] = Math.min(st[2 * i], st[2 * i + 1]);
    }

    /** Minimum over [l, r), or 0 for an empty range. */
    public int minx(int l, int r) {
      if (l >= r) return 0;
      int min = Integer.MAX_VALUE;
      while (l != 0) {
        int f = l & -l;
        if (l + f > r) break;
        int v = st[(H + l) / f];
        if (v < min) min = v;
        l += f;
      }
      while (l < r) {
        int f = r & -r;
        int v = st[(H + r) / f - 1];
        if (v < min) min = v;
        r -= f;
      }
      return min;
    }

    /** Minimum over [l, r), or 0 for an empty range. */
    public int min(int l, int r) {
      return l >= r ? 0 : min(l, r, 0, H, 1);
    }

    private int min(int l, int r, int cl, int cr, int cur) {
      if (l <= cl && cr <= r) {
        return st[cur];
      } else {
        int mid = cl + cr >>> 1;
        int ret = Integer.MAX_VALUE;
        if (cl < r && l < mid) {
          ret = Math.min(ret, min(l, r, cl, mid, 2 * cur));
        }
        if (mid < r && l < cr) {
          ret = Math.min(ret, min(l, r, mid, cr, 2 * cur + 1));
        }
        return ret;
      }
    }

    /** First index >= l whose value is <= v, or -1 if there is none. */
    public int firstle(int l, int v) {
      if (l >= N) return -1;
      int cur = H + l;
      while (true) {
        if (st[cur] <= v) {
          if (cur < H) {
            cur = 2 * cur;
          } else {
            return cur - H;
          }
        } else {
          cur++;
          if ((cur & cur - 1) == 0) return -1;
          if ((cur & 1) == 0) cur >>>= 1;
        }
      }
    }

    /** Last index <= l whose value is <= v, or -1 if there is none. */
    public int lastle(int l, int v) {
      int cur = H + l;
      while (true) {
        if (st[cur] <= v) {
          if (cur < H) {
            cur = 2 * cur + 1;
          } else {
            return cur - H;
          }
        } else {
          if ((cur & cur - 1) == 0) return -1;
          cur--;
          if ((cur & 1) == 1) cur >>>= 1;
        }
      }
    }
  }

  public static Random gen = new Random(0);

  /** Treap node ordered by {@code v}; {@code count} is the subtree size. */
  public static class Node {
    public int v; // value
    public long priority;
    public Node left, right, parent;
    public int count;

    public Node(int v) {
      this.v = v;
      priority = gen.nextLong();
      update(this);
    }

    @Override
    public String toString() {
      StringBuilder builder = new StringBuilder();
      builder.append("Node [v=");
      builder.append(v);
      builder.append(", count=");
      builder.append(count);
      builder.append(", parent=");
      builder.append(parent != null ? parent.v : "null");
      builder.append("]");
      return builder.toString();
    }
  }

  /** Recomputes a's subtree size from its children; returns a (null-safe). */
  public static Node update(Node a) {
    if (a == null) return null;
    a.count = 1;
    if (a.left != null) a.count += a.left.count;
    if (a.right != null) a.count += a.right.count;
    return a;
  }

  /** Recomputes sizes from x up to the root. */
  public static void propagate(Node x) {
    for (; x != null; x = x.parent) {
      update(x);
    }
  }

  /** Detaches a from all links and resets it to a singleton. */
  public static Node disconnect(Node a) {
    if (a == null) return null;
    a.left = a.right = a.parent = null;
    return update(a);
  }

  public static Node root(Node x) {
    if (x == null) return null;
    while (x.parent != null) {
      x = x.parent;
    }
    return x;
  }

  public static int count(Node a) {
    return a == null ? 0 : a.count;
  }

  public static void setParent(Node a, Node par) {
    if (a != null) a.parent = par;
  }

  public static Node merge(Node a, Node b, Node... c) {
    Node x = merge(a, b);
    for (Node n : c) {
      x = merge(x, n);
    }
    return x;
  }

  /** Concatenates treaps a and b (all of a before all of b). */
  public static Node merge(Node a, Node b) {
    if (b == null) return a;
    if (a == null) return b;
    if (a.priority > b.priority) {
      setParent(a.right, null);
      setParent(b, null);
      a.right = merge(a.right, b);
      setParent(a.right, a);
      return update(a);
    } else {
      setParent(a, null);
      setParent(b.left, null);
      b.left = merge(a, b.left);
      setParent(b.left, b);
      return update(b);
    }
  }

  /** Splits x's treap into (everything before x, x and everything after), walking up parents. */
  public static Node[] split(Node x) {
    if (x == null) return new Node[] {null, null};
    if (x.left != null) x.left.parent = null;
    Node[] sp = new Node[] {x.left, x};
    x.left = null;
    update(x);
    while (x.parent != null) {
      Node p = x.parent;
      x.parent = null;
      if (x == p.left) {
        p.left = sp[1];
        if (sp[1] != null) sp[1].parent = p;
        sp[1] = p;
      } else {
        p.right = sp[0];
        if (sp[0] != null) sp[0].parent = p;
        sp[0] = p;
      }
      update(p);
      x = p;
    }
    return sp;
  }

  /** Splits a at each position in ks (non-decreasing) into ks.length + 1 pieces. */
  public static Node[] split(Node a, int... ks) {
    int n = ks.length;
    if (n == 0) return new Node[] {a};
    for (int i = 0; i < n - 1; i++) {
      if (ks[i] > ks[i + 1]) throw new IllegalArgumentException(Arrays.toString(ks));
    }
    Node[] ns = new Node[n + 1];
    Node cur = a;
    for (int i = n - 1; i >= 0; i--) {
      Node[] sp = split(cur, ks[i]);
      cur = sp[0];
      ns[i] = sp[0];
      ns[i + 1] = sp[1];
    }
    return ns;
  }

  /** Splits a into positions [0, K) and [K, N). */
  public static Node[] split(Node a, int K) {
    if (a == null) return new Node[] {null, null};
    if (K <= count(a.left)) {
      setParent(a.left, null);
      Node[] s = split(a.left, K);
      a.left = s[1];
      setParent(a.left, a);
      s[1] = update(a);
      return s;
    } else {
      setParent(a.right, null);
      Node[] s = split(a.right, K - count(a.left) - 1);
      a.right = s[0];
      setParent(a.right, a);
      s[0] = update(a);
      return s;
    }
  }

  /** Inserts x at its sorted position by value. */
  public static Node insertb(Node root, Node x) {
    int ind = search(root, x.v);
    if (ind < 0) ind = -ind - 1;
    return insert(root, ind, x);
  }

  /** Inserts b so that it ends up at position K. */
  public static Node insert(Node a, int K, Node b) {
    if (a == null) return b;
    if (b.priority < a.priority) {
      if (K <= count(a.left)) {
        a.left = insert(a.left, K, b);
        setParent(a.left, a);
      } else {
        a.right = insert(a.right, K - count(a.left) - 1, b);
        setParent(a.right, a);
      }
      return update(a);
    } else {
      Node[] ch = split(a, K);
      b.left = ch[0];
      b.right = ch[1];
      setParent(b.left, b);
      setParent(b.right, b);
      return update(b);
    }
  }

  /** Deletes the K-th node (0-based); the removed node is disconnected. */
  public static Node erase(Node a, int K) {
    if (a == null) return null;
    if (K < count(a.left)) {
      a.left = erase(a.left, K);
      setParent(a.left, a);
      return update(a);
    } else if (K == count(a.left)) {
      setParent(a.left, null);
      setParent(a.right, null);
      Node aa = merge(a.left, a.right);
      disconnect(a);
      return aa;
    } else {
      a.right = erase(a.right, K - count(a.left) - 1);
      setParent(a.right, a);
      return update(a);
    }
  }

  /** Returns the K-th node (0-based), or null if out of range. */
  public static Node get(Node a, int K) {
    while (a != null) {
      if (K < count(a.left)) {
        a = a.left;
      } else if (K == count(a.left)) {
        break;
      } else {
        K = K - count(a.left) - 1;
        a = a.right;
      }
    }
    return a;
  }

  /** Position of a within its treap, or -1 for null. */
  public static int index(Node a) {
    if (a == null) return -1;
    int ind = count(a.left);
    while (a != null) {
      Node par = a.parent;
      if (par != null && par.right == a) {
        ind += count(par.left) + 1;
      }
      a = par;
    }
    return ind;
  }

  /**
   * Returns the position of value q if present. Otherwise returns -(insertion point + 1), as
   * {@link Arrays#binarySearch} does.
   */
  public static int search(Node a, int q) {
    int lcount = 0;
    while (a != null) {
      if (a.v == q) {
        lcount += count(a.left);
        break;
      }
      if (q < a.v) {
        a = a.left;
      } else {
        lcount += count(a.left) + 1;
        a = a.right;
      }
    }
    return a == null ? -(lcount + 1) : lcount;
  }

  public static Node next(Node x) {
    if (x == null) return null;
    if (x.right != null) {
      x = x.right;
      while (x.left != null) {
        x = x.left;
      }
      return x;
    } else {
      while (true) {
        Node p = x.parent;
        if (p == null) return null;
        if (p.left == x) return p;
        x = p;
      }
    }
  }

  public static Node prev(Node x) {
    if (x == null) return null;
    if (x.left != null) {
      x = x.left;
      while (x.right != null) {
        x = x.right;
      }
      return x;
    } else {
      while (true) {
        Node p = x.parent;
        if (p == null) return null;
        if (p.right == x) return p;
        x = p;
      }
    }
  }

  /** In-order array of a's nodes. */
  public static Node[] nodes(Node a) {
    return nodes(a, new Node[a.count], 0, a.count);
  }

  public static Node[] nodes(Node a, Node[] ns, int L, int R) {
    if (a == null) return ns;
    nodes(a.left, ns, L, L + count(a.left));
    ns[L + count(a.left)] = a;
    nodes(a.right, ns, R - count(a.right), R);
    return ns;
  }

  /** Debug dump of the treap, one node per line, indented by depth. */
  public static String toString(Node a, String indent) {
    if (a == null) return "";
    StringBuilder sb = new StringBuilder();
    sb.append(toString(a.left, indent + " "));
    sb.append(indent).append(a).append("\n");
    sb.append(toString(a.right, indent + " "));
    return sb.toString();
  }

  /** BFS from root; returns {parent (root: -1), BFS order, depth}. */
  public static int[][] parents3(int[][] g, int root) {
    int n = g.length;
    int[] par = new int[n];
    Arrays.fill(par, -1);
    int[] depth = new int[n];
    depth[0] = 0;
    int[] q = new int[n];
    q[0] = root;
    for (int p = 0, r = 1; p < r; p++) {
      int cur = q[p];
      for (int nex : g[cur]) {
        if (par[cur] != nex) {
          q[r++] = nex;
          par[nex] = cur;
          depth[nex] = depth[cur] + 1;
        }
      }
    }
    return new int[][] {par, q, depth};
  }

  /** Builds an undirected adjacency list from parallel edge arrays. */
  static int[][] packU(int n, int[] from, int[] to) {
    int[][] g = new int[n][];
    int[] p = new int[n];
    for (int f : from) p[f]++;
    for (int t : to) p[t]++;
    for (int i = 0; i < n; i++) g[i] = new int[p[i]];
    for (int i = 0; i < from.length; i++) {
      g[from[i]][--p[from[i]]] = to[i];
      g[to[i]][--p[to[i]]] = from[i];
    }
    return g;
  }

  void run() throws Exception {
    is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
    out = new PrintWriter(System.out);

    long s = System.currentTimeMillis();
    solve();
    out.flush();
    if (!INPUT.isEmpty()) tr(System.currentTimeMillis() - s + "ms");
  }

  public static void main(String[] args) throws Exception {
    new G().run();
  }

  private byte[] inbuf = new byte[1024];
  private int lenbuf = 0, ptrbuf = 0;

  private int readByte() {
    if (lenbuf == -1) throw new InputMismatchException();
    if (ptrbuf >= lenbuf) {
      ptrbuf = 0;
      try {
        lenbuf = is.read(inbuf);
      } catch (IOException e) {
        throw new InputMismatchException();
      }
      if (lenbuf <= 0) return -1;
    }
    return inbuf[ptrbuf++];
  }

  private boolean isSpaceChar(int c) {
    return !(c >= 33 && c <= 126);
  }

  private int skip() {
    int b;
    while ((b = readByte()) != -1 && isSpaceChar(b)) {}
    return b;
  }

  private double nd() {
    return Double.parseDouble(ns());
  }

  private char nc() {
    return (char) skip();
  }

  private String ns() {
    int b = skip();
    StringBuilder sb = new StringBuilder();
    while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != ' ')
      sb.appendCodePoint(b);
      b = readByte();
    }
    return sb.toString();
  }

  private char[] ns(int n) {
    char[] buf = new char[n];
    int b = skip(), p = 0;
    while (p < n && !(isSpaceChar(b))) {
      buf[p++] = (char) b;
      b = readByte();
    }
    return n == p ? buf : Arrays.copyOf(buf, p);
  }

  private char[][] nm(int n, int m) {
    char[][] map = new char[n][];
    for (int i = 0; i < n; i++) map[i] = ns(m);
    return map;
  }

  private int[] na(int n) {
    int[] a = new int[n];
    for (int i = 0; i < n; i++) a[i] = ni();
    return a;
  }

  private int ni() {
    int num = 0, b;
    boolean minus = false;
    while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {}
    if (b == '-') {
      minus = true;
      b = readByte();
    }

    while (true) {
      if (b >= '0' && b <= '9') {
        num = num * 10 + (b - '0');
      } else {
        return minus ? -num : num;
      }
      b = readByte();
    }
  }

  private long nl() {
    long num = 0;
    int b;
    boolean minus = false;
    while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-')) {}
    if (b == '-') {
      minus = true;
      b = readByte();
    }

    while (true) {
      if (b >= '0' && b <= '9') {
        num = num * 10 + (b - '0');
      } else {
        return minus ? -num : num;
      }
      b = readByte();
    }
  }

  private static void tr(Object... o) {
    System.out.println(Arrays.deepToString(o));
  }
}
Problem solution in C++.
// Shared headers, helper macros and global state for the centroid-decomposition solution.
#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <vector>
#include <string>
#include <cstring>
#include <unordered_map>
#include <cassert>
#include <cmath>

#define dri(X) int (X); scanf("%d", &X)
#define drii(X, Y) int X, Y; scanf("%d%d", &X, &Y)
#define driii(X, Y, Z) int X, Y, Z; scanf("%d%d%d", &X, &Y, &Z)
#define pb push_back
#define mp make_pair
#define rep(i, s, t) for (int i = (s); i < (t); i++)
#define fill(x, v) memset(x, v, sizeof(x))
#define all(x) (x).begin(), (x).end()
// Debug helpers (stderr only).
#define why(d) cerr << (d) << "!\n"
#define whisp(X, Y) cerr << (X) << " " << (Y) << "#\n"
#define exclam cerr << "!!\n"
// Segment-tree child indices (1-based heap layout). The argument is parenthesized so that
// expressions such as left(p + 1) expand correctly.
#define left(p) ((p) << 1)
#define right(p) (((p) << 1) + 1)
// Midpoint of the current segment; requires locals named l and r at the use site.
#define mid ((l + r) >> 1)

typedef long long ll;
using namespace std;
typedef pair<int, int> pii;

const ll inf = (ll)1e9 + 70;
const ll mod = 1e9 + 7;
const int maxn = 2e5 + 1000;

vector<int> adj[maxn];  // adjacency lists, 1-based vertex labels
bool used[maxn];        // vertices already removed as centroids
int sz[maxn];           // subtree sizes
int mark[maxn];         // stamp `tt` of the last dfs that visited the vertex
int tt = 0;
int maax[maxn];         // max label on the path from the dfs root
int miin[maxn];         // min label on the path from the dfs root
int goright[maxn];
int goleft[maxn];
int val[maxn];          // leaf values for buildtree
int ST[4 * maxn];       // an all-purpose ST: min, max, and sum!!
int query(int p, int l, int r, int i, int j){//sum query if (l > j || r < i) return 0; if (i <= l && r <= j){ return ST[p]; } return query(left(p), l, mid, i, j) + query(right(p), mid + 1, r, i, j); } void update(int p, int l, int r, int i, int delta){ if (l > i || r < i) return; if (l == r){ ST[p] += delta; return; } update(left(p), l, mid, i, delta); update(right(p), mid + 1, r, i, delta); ST[p] = ST[left(p)] + ST[right(p)]; } void buildtree(int p, int l, int r, bool m){ if (l == r){ ST[p] = val[l]; return; } buildtree(left(p), l, mid, m); buildtree(right(p), mid + 1, r, m); if (m) ST[p] = min(ST[left(p)], ST[right(p)]); else ST[p] = max(ST[left(p)], ST[right(p)]); } vector<pair<int, pii>> blocks; void decompose(int p, int l, int r, int i){ if (r < i) return; if (l >= i){ blocks.push_back(mp(p, pii(l, r))); return; } decompose(left(p), l, mid, i); decompose(right(p), mid + 1, r, i); } void decompose2(int p, int l, int r, int i){ if (l > i) return; if (r <= i){ blocks.push_back(mp(p, pii(l, r))); return; } decompose2(left(p), l, mid, i); decompose2(right(p), mid + 1, r, i); } int find(int p, int l, int r, int x){ assert(ST[p] < x); if (l == r) return l; if (ST[left(p)] >= x){ return find(right(p), mid + 1, r, x); } return find(left(p), l, mid, x); } int find2(int p, int l, int r, int x){ assert(ST[p] > x); if (l == r) return l; if (ST[right(p)] <= x){ return find2(left(p), l, mid, x); } return find2(right(p), mid + 1, r, x); } void dfs(int v, int p){ mark[v] = tt; sz[v] = 1; if (p == -1){ maax[v] = v; miin[v] = v; } else{ maax[v] = max(maax[p], v); miin[v] = min(miin[p], v); } for (int u : adj[v]){ if (u == p || used[u]) continue; dfs(u, v); sz[v] += sz[u]; } } ll perform(int v, int n){ if (n == 1){ return 1; } //first, FIND the centroid. dfs(v, -1); int g = v; int p = -1; while (true){ int w = -1; for (int h : adj[g]){ if (h == p || used[h]) continue; if (w == -1 || sz[h] > sz[w]) w = h; } assert(w != -1);//g should NOT be a leaf. 
if (2 * sz[w] <= n){ break; } p = g; g = w; } //g is the centroid. tt++; dfs(g, -1); //here comes the HEART OF THE ALGORITHM. int m = -800; for (int l = g; l > 0; l--){ if (mark[l] != tt) break; m = l; } int M = -800; for (int r = g; r < maxn; r++){ if (mark[r] != tt) break; M = r; } //Our working interval is m <= i <= M. rep(i, m, M + 1){ val[i] = miin[i]; //cout << miin[i] << " "; }//cout << endl; buildtree(1, m, M, true); rep(i, m, M + 1){ if (miin[i] < i){ goright[i] = -inf; continue; } blocks.clear(); decompose(1, m, M, i); reverse(blocks.begin(), blocks.end()); while (!blocks.empty() && ST[blocks.back().first] >= i) blocks.pop_back(); if (blocks.empty()){ goright[i] = M; } else{ auto ee = blocks.back(); goright[i] = find(ee.first, ee.second.first, ee.second.second, i) - 1; } } //rep(i, m, M + 1)cout << goright[i] << " "; cout << endl; //now, goleft! rep(i, m, M + 1) val[i] = maax[i]; //rep(i, m, M + 1) cout << maax[i] << " "; cout << endl; buildtree(1, m, M, false); rep(i, m, M + 1){ if (maax[i] > i){ goleft[i] = inf; continue; } blocks.clear(); decompose2(1, m, M, i); while (!blocks.empty() && ST[blocks.back().first] <= i) blocks.pop_back(); if (blocks.empty()){ goleft[i] = m; } else{ auto ee = blocks.back(); goleft[i] = find2(ee.first, ee.second.first, ee.second.second, i) + 1; } } //rep(i, m, M + 1) cout << goleft[i] << " "; cout << endl; vector<pii> rs; rep(r, m, M + 1){ if (goleft[r] != inf) rs.push_back(pii(goleft[r], r)); } sort(all(rs)); reverse(all(rs)); rep(i, m, M + 1) val[i] = 0; buildtree(1, m, M, true);//basically: just clear it. 
ll ans = 0; for (int l = m; l <= M; l++){ //whisp(l, goright[l]); while (!rs.empty() && rs.back().first == l){ update(1, m, M, rs.back().second, 1); //cout << "update " << rs.back().second << "n"; rs.pop_back(); } //cout << query(1, m, M, l, goright[l]) << "n"; ans += query(1, m, M, l, goright[l]); } used[g] = true; vector<pii> ls; for (int u : adj[g]){ if (used[u]) continue; ls.push_back(pii(u, sz[u])); } for (pii x : ls){ ans += perform(x.first, x.second); } return ans; } int main(){ if (fopen("input.txt", "r")) freopen("input.txt", "r", stdin); dri(n); rep(i, 1, n){ drii(a, b); adj[a].push_back(b); adj[b].push_back(a); } cout << perform(1, n) << "n"; return 0; }
Problem solution in C.
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>

/* Debug tracing is compiled out: every fprintf(...) in this file expands to nothing. */
#define fprintf(...)

/*
 * Union-find vertex. A NULL parent marks a vertex that has never been unioned (its own root).
 * A count of 0 on a root means the component has a single vertex; vunion fixes it up lazily.
 */
struct vertex {
    struct vertex* parent;
    int rank;
    int count;
};

/* Returns the representative of v's set, compressing the path on the way. */
struct vertex* vfind(struct vertex *v) {
    if (v->parent == NULL) return v; // this is a disconnected one.
    if (v->parent != v) {
        v->parent = vfind(v->parent);
    }
    return v->parent;
}

/*
 * Unites the sets containing x and y (union by rank) and returns the new root. Its count is the
 * size of the merged component. If x and y are already together, returns their root unchanged.
 */
struct vertex* vunion(struct vertex *x, struct vertex* y) {
    struct vertex *xroot = vfind(x);
    struct vertex *yroot = vfind(y);
    if (xroot == yroot) return yroot;
    // fix any uninitialized counts.
    if (xroot->count == 0) xroot->count++;
    if (yroot->count == 0) yroot->count++;
    if (xroot->rank > yroot->rank) {
        struct vertex* tmp = xroot;
        xroot = yroot;
        yroot = tmp;
    }
    // xroot is now the smaller tree if they're not the same.
    if (xroot->rank == yroot->rank) {
        yroot->rank++;
    }
    xroot->parent = yroot;
    yroot->count += xroot->count;
    return yroot;
}

/* Undirected edge stored with a < b. */
struct edge {
    int a, b;
};

/* qsort comparator: orders edges by (b, a) ascending. */
int ecmp(const void*a_in, const void*b_in) {
    const struct edge* a = a_in;
    const struct edge* b = b_in;
    if (a->b < b->b) return -1;
    if (a->b > b->b) return 1;
    if (a->a < b->a) return -1;
    if (a->a > b->a) return 1;
    return 0;
}

/* qsort comparator: orders edges by (a, b) ascending. */
int ecmp_a(const void*a_in, const void*b_in) {
    const struct edge* a = a_in;
    const struct edge* b = b_in;
    if (a->a < b->a) return -1;
    if (a->a > b->a) return 1;
    if (a->b < b->b) return -1;
    if (a->b > b->b) return 1;
    return 0;
}

// n * ack-1(n) algorithm; needs to be run n times for n^2 ack-1(n). Not the best, but gets 50%.
int count_components1(int start, struct edge* edges, int ne, int n) { if (ne == 0) return 1; fprintf(stderr, "start: %d, ne %d, n %dn", start, ne, n); int max_components = n - start + 1; struct vertex v[max_components]; memset(v, 0, sizeof(v)); int components = 1; struct edge* le = edges + ne; for (int maxv = start + 1; maxv <= n; maxv++) { struct vertex* join = NULL; while (edges < le && edges->b <= maxv) { if (edges->a >= start) { join = vunion(&v[edges->a - start], &v[edges->b - start]); fprintf(stderr, "Join: %d to %d, new count %dn", edges->a, edges->b, join->count); } edges++; } if (join && join->count == maxv - start + 1) components++; } return components; } int count_components(int start, struct edge* edges, int ne, int n) { if (ne == 0) return 1; fprintf(stderr, "start: %d, ne %d, n %dn", start, ne, n); int max_components = n - start + 1; struct vertex v[max_components]; memset(v, 0, sizeof(v)); int components = 1; struct edge* le = edges + ne; for (int maxv = start + 1; maxv <= n; maxv++) { struct vertex* join = NULL; while (edges < le && edges->b <= maxv) { if (edges->a >= start) { join = vunion(&v[edges->a - start], &v[edges->b - start]); fprintf(stderr, "Join: %d to %d, new count %dn", edges->a, edges->b, join->count); } edges++; } if (join && join->count == maxv - start + 1) components++; } return components; } int old_main() { int n; scanf("%dn", &n); struct edge edges[n-1]; memset(edges, 0, sizeof(edges)); for (int i = 0; i < n-1; i++) { int e1, e2; scanf("%d %dn", &e1, &e2); if (e1 < e2) { edges[i].a = e1; edges[i].b = e2; } else { edges[i].a = e2; edges[i].b = e1; } } qsort(edges, n-1, sizeof(struct edge), ecmp); for (int i = 0; i < n-1; i++) { fprintf(stderr, "Edge: %d %dn", edges[i].a, edges[i].b); } int result = 0; struct edge *ep = edges; struct edge *lp = edges + n - 1; for (int i = 1; i <= n; i++) { while(ep < lp && ep->a < i) ep++; int cc = count_components(i, ep, lp - ep, n); fprintf(stderr, "i: %d cc: %dn", i, cc); result += cc; } 
printf("%dn", result); return 0; } struct node { int nn; // indexes of forward edges in the node. // Edges always belong to the low node. int first_edge; int n_edges; }; struct line { int start_node; int end_node; }; struct segment_node { int lazy; int max_v; // maximum value of any node below int num_v; // number of nodes with that maximum value }; #define C1(i) ((i)*2+1) #define C2(i) ((i)*2+2) void propagate(struct segment_node* tree, int index, int start, int end) { if (!tree[index].lazy) return; if (start == end) { // leaf, nothing to do; tree[index].lazy = 0; return; } fprintf(stderr, "Prop: %d v: %dn", index, tree[index].lazy); tree[C1(index)].lazy += tree[index].lazy; tree[C2(index)].lazy += tree[index].lazy; tree[C1(index)].max_v += tree[index].lazy; tree[C2(index)].max_v += tree[index].lazy; tree[index].lazy = 0; } // ns, ne = node start/end = recursion counter // rs, re = input range start/end // adds "v" to all nodes between rs and re. // the segment tree is implicitly "complete", i.e. contains all integers in [ns, ne] int treelim; void update(struct segment_node* tree, int index, int ns, int ne, int rs, int re, int v) { if (index >= treelim) exit(-1); fprintf(stderr, "upd: i: %d ns,ne (%d %d) rs, re (%d %d), v %dn", index, ns, ne, rs, re, v); fprintf(stderr, " prev max_v %d num_v %dn", tree[index].max_v, tree[index].num_v); propagate(tree, index, ns, ne); if (ns == rs && ne == re) { tree[index].max_v += v; if (ns == ne) tree[index].num_v = 1; tree[index].lazy += v; return; } int mid = (ns + ne) / 2; if (re <= mid) update(tree, C1(index), ns, mid, rs, re, v); else if (rs > mid) update(tree, C2(index), mid + 1, ne, rs, re, v); else { update(tree, C1(index), ns, mid, rs, mid, v); update(tree, C2(index), mid + 1, ne, mid + 1, re, v); } // now up-propagate. 
if (tree[C1(index)].max_v > tree[C2(index)].max_v) { fprintf(stderr, "C1n"); tree[index].max_v = tree[C1(index)].max_v; tree[index].num_v = tree[C1(index)].num_v; } else if (tree[C1(index)].max_v < tree[C2(index)].max_v) { fprintf(stderr, "C2n"); tree[index].max_v = tree[C2(index)].max_v; tree[index].num_v = tree[C2(index)].num_v; } else { fprintf(stderr, "BBn"); tree[index].max_v = tree[C1(index)].max_v; tree[index].num_v = tree[C1(index)].num_v + tree[C2(index)].num_v; } fprintf(stderr, "upd done: %d max_v %d num_v %dn", index, tree[index].max_v, tree[index].num_v); } int main() { int n; scanf("%dn", &n); struct node nodes[n]; struct edge edges[n-1]; int nl = 0; memset(nodes, 0, sizeof(nodes)); memset(edges, 0, sizeof(edges)); for (int i = 0; i < n-1; i++) { int e1, e2; scanf("%d %dn", &e1, &e2); if (e1 < e2) { edges[i].a = e1; edges[i].b = e2; } else { edges[i].a = e2; edges[i].b = e1; } } qsort(edges, n-1, sizeof(struct edge), ecmp); for (int i = 0; i < n; i++) { nodes[i].nn = i+1; } int cur_node = -1; for (int i = 0; i < n-1; i++) { fprintf(stderr, "Edge %d: %d %dn", i, edges[i].a, edges[i].b); if (edges[i].b - 1 > cur_node) { for (int j = cur_node + 1; j < edges[i].b - 1; j++) { // Make the zero-edge nodes have a "first edge" that makes sense nodes[j].first_edge = i; } if (cur_node >= 0) { nodes[cur_node].n_edges = i - nodes[cur_node].first_edge; } cur_node = edges[i].b - 1; nodes[cur_node].first_edge = i; } } fprintf(stderr, "n:%d, cur_node %d %dn", n, cur_node, nodes[cur_node].nn); nodes[cur_node].n_edges = n - 1 - nodes[cur_node].first_edge; while (++cur_node < n) { nodes[cur_node].first_edge = n - 1; } for (int i = 0; i < n; i++) { fprintf(stderr, "Node: %d edges start at %d nedges %dn", nodes[i].nn, nodes[i].first_edge, nodes[i].n_edges); } long result = 0; treelim = 1<<((int)ceil(log2(n)) + 1); struct segment_node stree[treelim]; memset(stree, 0, sizeof(stree)); for (int i = 0; i < n; i++) { for (int j = nodes[i].first_edge; j < nodes[i].first_edge + 
nodes[i].n_edges; j++) { update(stree, 0, 1, n, 1, edges[j].a, 1); } // Adds the current vertex. update(stree, 0, 1, n, nodes[i].nn, nodes[i].nn, nodes[i].nn); fprintf(stderr, "Node %d max_v %d num_v %dn", nodes[i].nn, stree[0].max_v, stree[0].num_v); if (stree[0].max_v == nodes[i].nn) result += stree[0].num_v; } printf("%ldn", result); return 0; }