HackerRank Tripartite Matching problem solution

In this HackerRank Tripartite Matching problem, you are given three unweighted, undirected graphs G1, G2, and G3, each with n vertices numbered from 1 through n, where the k-th graph has m_k edges. Find the number of ordered triples (a, b, c) with 1 <= a, b, c <= n and a != b, b != c, c != a, such that there is an edge (a, b) in G1, an edge (b, c) in G2, and an edge (c, a) in G3.

HackerRank Tripartite Matching problem solution

Problem solution in Python.

#!/bin/python3

import os
import sys

#
# Complete the tripartiteMatching function below.
#
def tripartiteMatching(vertices, edgeLists):
    """Count ordered triples (a, b, c) closing a G1-G2-G3 triangle.

    A triple is counted when (a, b) is an edge of the first graph, (b, c) an
    edge of the second graph and (c, a) an edge of the third graph. All three
    graphs are undirected and share the vertex ids 1..vertices.

    Args:
        vertices: number of vertices in every graph.
        edgeLists: exactly three iterables (G1, G2, G3); each yields
            (u, v) pairs of 1-based vertex ids.

    Returns:
        The number of matching ordered triples.
    """
    adjacency = []
    for edges in edgeLists:
        # Slot 0 is unused (ids are 1-based). Sets are created lazily so
        # vertices without edges cost nothing beyond a None entry.
        neighbours = [None] * (vertices + 1)
        for u, v in edges:
            if neighbours[u] is None:
                neighbours[u] = set()
            neighbours[u].add(v)
            if neighbours[v] is None:
                neighbours[v] = set()
            neighbours[v].add(u)
        adjacency.append(neighbours)
    g1, g2, g3 = adjacency

    count = 0
    for b in range(1, vertices + 1):
        candidatesA = g1[b]
        if candidatesA is None or g2[b] is None:
            continue
        for c in g2[b]:
            closingA = g3[c]
            if closingA is not None:
                # Every a adjacent to b in G1 and to c in G3 closes a triple.
                count += len(candidatesA & closingA)
    return count


def readEdgeList():
    """Read one graph from stdin: an edge count, then one 'u v' line per edge."""
    edgeCount = int(input())
    edges = []
    for _ in range(edgeCount):
        # split() without an argument tolerates repeated or trailing blanks,
        # which split(" ") turned into empty tokens that int() rejects.
        fields = input().split()
        edges.append((int(fields[0]), int(fields[1])))
    return edges


if __name__ == "__main__":
    vertices = int(input())
    edgeLists = [readEdgeList() for _ in range(3)]
    print(tripartiteMatching(vertices, edgeLists))

{“mode”:”full”,”isActive”:false}

Problem solution in Java.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Reads a vertex count followed by three undirected edge lists and prints the number of
 * ordered triples (a, b, c) where a-b is an edge of the first graph, b-c an edge of the
 * second and c-a an edge of the third.
 *
 * <p>For every first-graph edge the neighbourhoods of b (second graph) and a (third graph)
 * are intersected. Neighbourhoods at or above a fixed size are additionally stored as bit
 * sets so the intersection can use word-wise AND or single-bit probes.
 */
public class D {
    InputStream is;
    PrintWriter out;
    // When non-empty this text is used as the input instead of System.in.
    String INPUT = "";

    /** Reads an edge count and that many 1-based edges; returns 0-based adjacency arrays. */
    int[][] read(int n)
    {
        int edgeCount = ni();
        int[] from = new int[edgeCount];
        int[] to = new int[edgeCount];
        int idx = 0;
        while(idx < edgeCount){
            from[idx] = ni() - 1;
            to[idx] = ni() - 1;
            idx++;
        }
        return packU(n, from, to);
    }

    /** Reads the three graphs, counts the matching triples and prints the total. */
    void solve()
    {
        int n = ni();
        int[][] ga = read(n);
        int[][] gb = read(n);
        int[][] gc = read(n);

        // Neighbourhoods of at least this size also get a bit set.
        // NOTE(review): 100000 is presumably the largest edge count allowed by the task - confirm.
        int heavy = (int)Math.sqrt(100000);
        int words = (n >>> 6) + 1;
        long[][] bitsB = new long[n][];
        long[][] bitsC = new long[n][];
        for(int v = 0; v < n; v++){
            if(gb[v].length >= heavy){
                bitsB[v] = toBits(gb[v], words);
            }
            if(gc[v].length >= heavy){
                bitsC[v] = toBits(gc[v], words);
            }
            // Sorted lists are what the two-pointer merge below relies on.
            Arrays.sort(gb[v]);
            Arrays.sort(gc[v]);
        }

        long total = 0;
        for(int a = 0; a < ga.length; a++){
            long[] maskC = bitsC[a];
            int[] listC = gc[a];
            for(int b : ga[a]){
                long[] maskB = bitsB[b];
                if(maskB != null && maskC != null){
                    for(int w = 0; w < words; w++){
                        total += Long.bitCount(maskB[w] & maskC[w]);
                    }
                }else if(maskB != null){
                    total += countHits(maskB, listC);
                }else if(maskC != null){
                    total += countHits(maskC, gb[b]);
                }else{
                    total += countCommon(gb[b], listC);
                }
            }
        }
        out.println(total);
    }

    /** Builds a bit set of the given word length with one bit per listed vertex. */
    private static long[] toBits(int[] members, int words)
    {
        long[] bits = new long[words];
        for(int e : members){
            // Java masks a long shift distance to its low six bits, i.e. e % 64.
            bits[e >>> 6] |= 1L << e;
        }
        return bits;
    }

    /** Counts how many entries of {@code members} have their bit set in {@code mask}. */
    private static long countHits(long[] mask, int[] members)
    {
        long hits = 0;
        for(int e : members){
            if(((mask[e >>> 6] >>> e) & 1L) != 0){
                hits++;
            }
        }
        return hits;
    }

    /** Two-pointer merge over two ascending arrays; counts positions where both agree. */
    private static long countCommon(int[] left, int[] right)
    {
        long common = 0;
        int i = 0, j = 0;
        while(i < left.length && j < right.length){
            int diff = Integer.compare(left[i], right[j]);
            if(diff == 0){
                common++;
                i++;
                j++;
            }else if(diff < 0){
                i++;
            }else{
                j++;
            }
        }
        return common;
    }

    /** Packs an undirected edge list into per-vertex neighbour arrays (each edge stored twice). */
    static int[][] packU(int n, int[] from, int[] to) {
        int[] degree = new int[n];
        for (int i = 0; i < from.length; i++) {
            degree[from[i]]++;
        }
        for (int i = 0; i < to.length; i++) {
            degree[to[i]]++;
        }
        int[][] adj = new int[n][];
        for (int v = 0; v < n; v++) {
            adj[v] = new int[degree[v]];
        }
        // degree[] now doubles as the next free slot, filled from the back.
        for (int i = 0; i < from.length; i++) {
            int u = from[i];
            int w = to[i];
            adj[u][--degree[u]] = w;
            adj[w][--degree[w]] = u;
        }
        return adj;
    }

    /** Selects the input source, runs {@link #solve()} and flushes the output. */
    void run() throws Exception
    {
        if(INPUT.isEmpty()){
            is = System.in;
        }else{
            is = new ByteArrayInputStream(INPUT.getBytes());
        }
        out = new PrintWriter(System.out);

        long startedAt = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty()){
            tr(System.currentTimeMillis() - startedAt + "ms");
        }
    }

    public static void main(String[] args) throws Exception {
        D solver = new D();
        solver.run();
    }

    private byte[] inbuf = new byte[1024];
    private int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, or -1 once the stream is exhausted. */
    private int readByte()
    {
        if(lenbuf == -1){
            throw new InputMismatchException();
        }
        if(ptrbuf < lenbuf){
            return inbuf[ptrbuf++];
        }
        ptrbuf = 0;
        try {
            lenbuf = is.read(inbuf);
        } catch (IOException e) {
            throw new InputMismatchException();
        }
        return lenbuf <= 0 ? -1 : inbuf[ptrbuf++];
    }

    /** Anything outside the printable ASCII range 33..126 counts as a separator. */
    private boolean isSpaceChar(int c) {
        return c < 33 || c > 126;
    }

    /** Skips separators and returns the first other byte (or -1 at end of input). */
    private int skip() {
        int b = readByte();
        while(b != -1 && isSpaceChar(b)){
            b = readByte();
        }
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char)skip();
    }

    /** Reads one separator-delimited token. */
    private String ns()
    {
        StringBuilder sb = new StringBuilder();
        for(int b = skip(); !isSpaceChar(b); b = readByte()){
            sb.appendCodePoint(b);
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token. */
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int filled = 0;
        for(int b = skip(); filled < n && !isSpaceChar(b); b = readByte()){
            buf[filled++] = (char)b;
        }
        return filled == n ? buf : Arrays.copyOf(buf, filled);
    }

    /** Reads {@code n} tokens of up to {@code m} characters each. */
    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        int row = 0;
        while(row < n){
            map[row++] = ns(m);
        }
        return map;
    }

    /** Reads {@code n} ints. */
    private int[] na(int n)
    {
        int[] values = new int[n];
        int idx = 0;
        while(idx < n){
            values[idx++] = ni();
        }
        return values;
    }

    /** Reads the next (optionally negative) decimal int, skipping any other bytes before it. */
    private int ni()
    {
        int b = readByte();
        while(b != -1 && b != '-' && (b < '0' || b > '9')){
            b = readByte();
        }
        boolean negative = b == '-';
        if(negative){
            b = readByte();
        }

        int value = 0;
        while(b >= '0' && b <= '9'){
            value = value * 10 + (b - '0');
            b = readByte();
        }
        return negative ? -value : value;
    }

    /** Reads the next (optionally negative) decimal long, skipping any other bytes before it. */
    private long nl()
    {
        int b = readByte();
        while(b != -1 && b != '-' && (b < '0' || b > '9')){
            b = readByte();
        }
        boolean negative = b == '-';
        if(negative){
            b = readByte();
        }

        long value = 0;
        while(b >= '0' && b <= '9'){
            value = value * 10 + (b - '0');
            b = readByte();
        }
        return negative ? -value : value;
    }

    private static void tr(Object... o) {
        String rendered = Arrays.deepToString(o);
        System.out.println(rendered);
    }
}

{“mode”:”full”,”isActive”:false}

Problem solution in C++.

#include <iostream>
#include <cstdio>
#include <string.h>
#include <algorithm>
#include <vector>
#include <string>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <sstream>
#include <cmath>

// Short aliases used throughout the file.
typedef long long ll;
typedef unsigned int uint;

// forn(i, n): loop i over 0 .. n-1 (n is cast to int).
#define forn(i, n) for (int i = 0; i < (int)(n); i++)
// forv(i, v): loop i over the indices of container v.
#define forv(i, v) forn(i, v.size())
// all(v): the begin/end iterator pair of v.
#define all(v) v.begin(), v.end()
#define pb push_back

using namespace std;

// Degree threshold used by buildAdjSets: a vertex with fewer than MAGIC
// neighbours is stored as a set, any other vertex as a bitmask.
const int MAGIC = 32;

// Number of vertices, read from stdin in main.
int n;
// Number of 32-bit words per bitmask; main sets it to (n + 31) / 32.
int bSize;

// g[k][v] lists the 0-based neighbours of v in the k-th input graph;
// m[k] is the edge count read for that graph.
vector< vector<int> > g[3];
int m[3];

struct AdjSet {
    set<int>* s;
    uint* b;
    AdjSet() {
        s = NULL;
        b = NULL;
    }
};

vector<AdjSet> adj[2];

// Fills as[i] for every vertex i in 0 .. n-1 from the adjacency lists in g.
//
// Vertices with fewer than MAGIC neighbours get an ordered set; all others
// get a bitmask of bSize 32-bit words in which bit (j & 31) of word (j >> 5)
// marks neighbour j. `as` must already hold n default-constructed entries.
// The allocated sets and arrays are never released by this file.
void buildAdjSets(vector<AdjSet>& as, const vector< vector<int> >& g) {
    forn(i, n) {
        const vector<int>& neighbours = g[i];
        if ((int)neighbours.size() < MAGIC) {
            as[i].s = new set<int>(neighbours.begin(), neighbours.end());
        } else {
            // The trailing () value-initialises the array, so every word
            // starts at zero without relying on sizeof(uint) being 4.
            as[i].b = new uint[bSize]();
            for (int j : neighbours) {
                // 1u keeps the shift unsigned; a signed 1 << 31 would
                // overflow int.
                as[i].b[j >> 5] |= 1u << (j & 31);
            }
        }
    }
}

// ones[x] is the number of set bits in the 16-bit value x; main fills the
// table before any call to calcIntersectionSize.
int ones[1 << 16];

// Returns the number of vertices present in both neighbourhoods.
//
// Each AdjSet is either a set (s != NULL) or a bitmask (b != NULL), so three
// combinations are handled: set/set by lookup, set/bitmask by probing one bit
// per set element, and bitmask/bitmask by AND-ing words and counting bits.
int calcIntersectionSize(const AdjSet& a1, const AdjSet& a2) {
    if (a1.s != NULL && a2.s != NULL) {
        int res = 0;
        for (int v : *a1.s) {
            if (a2.s->count(v)) res++;
        }
        return res;
    }

    if (a1.s != NULL || a2.s != NULL) {
        // Normalise so that a1 is the set and a2 the bitmask.
        if (a2.s != NULL) return calcIntersectionSize(a2, a1);
        int res = 0;
        for (int v : *a1.s) {
            // Unsigned literal: shifting a signed 1 by 31 would overflow int.
            if (a2.b[v >> 5] & (1u << (v & 31))) res++;
        }
        return res;
    }

    int res = 0;
    forn(i, bSize) {
        uint x = a1.b[i] & a2.b[i];
        // Count the high and the low 16 bits through the lookup table.
        res += ones[x >> 16] + ones[x & 65535];
    }
    return res;
}

int main() {
#ifdef NEREVAR_PROJECT
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    forn(i, 1 << 16) {
        forn(j, 16) ones[i] += (i >> j) & 1;
    }
    cin >> n;
    bSize = (n + 31) / 32;
    forn(k, 3) {
        cin >> m[k];
        g[k] = vector< vector<int> >(n);
        forn(i, m[k]) {
            int x, y;
            scanf("%d %d", &x, &y);
            x--, y--;
            g[k][x].push_back(y);
            g[k][y].push_back(x);
        }
    }
    forn(i, 2) {
        adj[i] = vector<AdjSet>(n);
        buildAdjSets(adj[i], g[i + 1]);
    }
    ll ans = 0;
    forn(i, n) {
        for (int j : g[0][i]) {
            ans += calcIntersectionSize(adj[1][i], adj[0][j]);
        }
    }
    cout << ans << endl;
    return 0;
}

{“mode”:”full”,”isActive”:false}

Problem solution in C.

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

char* readline();
char** split_string(char*);

/* One direction of an edge, stored as 0-based endpoint ids. */
struct edge{
    int from;
    int to;
};

/*
 * Strict ordering used for sorting: returns true when e1 comes before e2,
 * comparing `from` first and `to` only when the `from` values are equal.
 */
bool precedge(struct edge e1, struct edge e2){
    if(e1.from != e2.from){
        return e1.from < e2.from;
    }
    return e1.to < e2.to;
}

void setup(int n, int m, int** edges, struct edge *sortedge, int* edgebds){
    for(int i = 0; i < m; i++){
        sortedge[2*i].from = edges[i][0] - 1;
        sortedge[2*i].to = edges[i][1] - 1;
        sortedge[2*i + 1].from = edges[i][1] - 1;
        sortedge[2*i + 1].to = edges[i][0] - 1;
    }

    for(int i = 0; i < 2*m; i++){
        int curr = i;
        while(curr > 0){
            int next = (curr - 1)/2;
            if(precedge(sortedge[next], sortedge[curr])){
                struct edge temp = sortedge[curr];
                sortedge[curr] = sortedge[next];
                sortedge[next] = temp;
                curr = next;
            }
            else{
                break;
            }
        }
    }

    for(int i = 2*m - 1; i >= 0; i--){
        struct edge temp = sortedge[0];
        sortedge[0] = sortedge[i];
        sortedge[i] = temp;

        int curr = 0;
        while(true){
            int next = curr;
            if(2*curr + 1 < i && precedge(sortedge[next], sortedge[2*curr + 1])){
                next = 2*curr + 1;
            }
            if(2*curr + 2 < i && precedge(sortedge[next], sortedge[2*curr + 2])){
                next = 2*curr + 2;
            }
            if(next != curr){
                struct edge temp = sortedge[curr];
                sortedge[curr] = sortedge[next];
                sortedge[next] = temp;
                curr = next;
            }
            else{
                break;
            }
        }
    }

    edgebds[0] = 0;
    for(int i = 0; i < n; i++){
        int index = edgebds[i];
        while(index < 2*m && sortedge[index].from == i){
            index++;
        }
        edgebds[i + 1] = index;
    }
}

int tripartiteMatching(int n, int m1, int** g1, int m2, int** g2, int m3, int** g3) {
    struct edge sortedge1[2*m1], sortedge2[2*m2], sortedge3[2*m3];
    int edgebds1[n + 1], edgebds2[n + 1], edgebds3[n + 1];
    setup(n, m1, g1, sortedge1, edgebds1);
    setup(n, m2, g2, sortedge2, edgebds2);
    setup(n, m3, g3, sortedge3, edgebds3);

    int toreturn = 0;
    for(int i = 0; i < 2*m1; i++){
        int node2 = sortedge1[i].from;
        int node3 = sortedge1[i].to;
        int index2 = edgebds2[node2];
        int index3 = edgebds3[node3];
        while(index2 < edgebds2[node2 + 1] && index3 < edgebds3[node3 + 1]){
            int tonode2 = sortedge2[index2].to;
            int tonode3 = sortedge3[index3].to;
            if(tonode2 == tonode3){
                toreturn++;
                index2++;
                index3++;
            }
            else if(tonode2 < tonode3){
                index2++;
            }
            else{
                index3++;
            }
        }
    }
    return toreturn;

}

/* Reads one line that must consist of a single base-10 integer; exits on anything else. */
static int read_int_line(void)
{
    char* line = readline();
    if (!line) { exit(EXIT_FAILURE); }

    char* end;
    long value = strtol(line, &end, 10);

    /* The whole line has to be consumed by the number. */
    if (end == line || *end != '\0') { exit(EXIT_FAILURE); }

    free(line);
    return (int) value;
}

/*
 * Reads m lines of the form "u v" and returns them as m rows of two ints.
 * Exits on allocation failure or when a line does not start with two
 * integers, each followed by a space or the end of the line.
 */
static int** read_edge_list(int m)
{
    /* Ask for at least one slot so that NULL always signals failure. */
    int** rows = malloc((m > 0 ? (size_t) m : 1) * sizeof(int*));
    if (!rows) { exit(EXIT_FAILURE); }

    for (int row = 0; row < m; row++) {
        rows[row] = malloc(2 * sizeof(int));
        char* line = readline();
        if (!rows[row] || !line) { exit(EXIT_FAILURE); }

        char* cursor = line;
        for (int col = 0; col < 2; col++) {
            char* end;
            long value = strtol(cursor, &end, 10);

            if (end == cursor || (*end != '\0' && *end != ' ')) { exit(EXIT_FAILURE); }

            rows[row][col] = (int) value;
            cursor = end;
        }

        free(line);
    }

    return rows;
}

/* Releases an edge list produced by read_edge_list. */
static void free_edge_list(int** rows, int m)
{
    for (int row = 0; row < m; row++) {
        free(rows[row]);
    }
    free(rows);
}

/*
 * Reads n followed by three edge lists (a count line, then one edge per line)
 * from stdin and writes the triple count to the file named by OUTPUT_PATH,
 * or to stdout when that variable is not set.
 */
int main()
{
    const char* output_path = getenv("OUTPUT_PATH");
    FILE* fptr = output_path ? fopen(output_path, "w") : stdout;
    if (!fptr) { exit(EXIT_FAILURE); }

    int n = read_int_line();

    int m1 = read_int_line();
    int** g1 = read_edge_list(m1);

    int m2 = read_int_line();
    int** g2 = read_edge_list(m2);

    int m3 = read_int_line();
    int** g3 = read_edge_list(m3);

    int result = tripartiteMatching(n, m1, g1, m2, g2, m3, g3);

    fprintf(fptr, "%d\n", result);

    free_edge_list(g1, m1);
    free_edge_list(g2, m2);
    free_edge_list(g3, m3);

    fclose(fptr);

    return 0;
}

/*
 * Reads one line of any length from stdin.
 *
 * Returns a heap-allocated, NUL-terminated string without the trailing
 * newline; the caller frees it. At end of input the result is an empty
 * string. Returns NULL when memory cannot be allocated.
 */
char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);

    if (!data) { return NULL; }

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, (int) (alloc_length - data_length), stdin);

        if (!line) { break; }

        data_length += strlen(cursor);

        /* Stop once the buffer was not filled completely or the line is complete. */
        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') { break; }

        size_t new_length = alloc_length << 1;
        char* grown = realloc(data, new_length);

        if (!grown) {
            free(data);
            return NULL;
        }

        data = grown;
        alloc_length = new_length;
    }

    /* Drop the newline, if any; an empty read must not look at data[-1]. */
    if (data_length > 0 && data[data_length - 1] == '\n') {
        data_length--;
    }
    data[data_length] = '\0';

    /* Shrink to fit, keeping room for the terminator. */
    char* trimmed = realloc(data, data_length + 1);

    return trimmed ? trimmed : data;
}

/*
 * Splits str in place at spaces.
 *
 * Returns a heap-allocated array with one pointer per token; the pointers
 * refer into str, which strtok modifies. The array carries no terminator or
 * length, so the caller has to know how many tokens to expect. Returns NULL
 * when str contains no token or when memory cannot be allocated.
 */
char** split_string(char* str) {
    char** splits = NULL;
    int count = 0;

    for (char* token = strtok(str, " "); token; token = strtok(NULL, " ")) {
        /* Grow through a temporary so the old array is not lost on failure. */
        char** grown = realloc(splits, sizeof(char*) * (count + 1));
        if (!grown) {
            free(splits);
            return NULL;
        }

        splits = grown;
        splits[count++] = token;
    }

    return splits;
}

{“mode”:”full”,”isActive”:false}