HackerRank Palindromic Border problem solution

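In this HackerRank Palindromic Border problem solution, we are given a string s consisting only of the first 8 lowercase letters of the English alphabet. A border of a string is a proper prefix that is also a suffix of it, and a palindromic border is a border that is a palindrome. For every non-empty substring of s we count its palindromic borders, and we need to print the sum of these counts modulo 10^9 + 7.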

HackerRank Palindromic Border problem solution

Problem solution in Python.

def is_palin(s):
    """Return True when ``s`` reads the same forwards and backwards.

    Only the first half is inspected: position ``k`` is compared with its
    mirror ``-1 - k``. Empty and single-character inputs are palindromes.
    """
    half = len(s) // 2
    return all(s[k] == s[-1 - k] for k in range(half))

def calc_palin_borders(palin_dict):
    """Sum the number of unordered occurrence pairs over all palindromes.

    ``palin_dict`` maps a palindrome to the number of times it appears.
    A palindrome seen ``t`` times contributes ``t * (t - 1) // 2`` pairs.
    The result is not reduced by any modulus.
    """
    return sum(t * (t - 1) // 2 for t in palin_dict.values())

def mono_str(s):
    """Return True when every character of ``s`` equals its first character.

    An empty string is treated as (vacuously) uniform instead of raising
    ``IndexError``: ``s[0]` is only evaluated once the generator yields a
    character, which never happens for empty input.
    """
    return all(c == s[0] for c in s)

def mono_str_result(s):
    """Return the sum of ``i * (i - 1) // 2`` for ``i`` in ``2..len(s)``, modulo 1000000007.

    Only the length of ``s`` is used. Strings shorter than two characters
    give 0.
    """
    modulus = 1000000007
    size = len(s)
    return sum(i * (i - 1) // 2 for i in range(2, size + 1)) % modulus

def pb(s):
    """Return the palindromic-border sum of ``s``.

    Strings made of one repeated character are delegated to
    ``mono_str_result``. Otherwise palindromic substrings are enumerated by
    increasing length: every palindrome of length ``length`` is obtained by
    extending a palindrome of length ``length - 2`` by one matching character
    on each side. For each length the occurrences of each distinct palindrome
    are tallied and ``calc_palin_borders`` converts the tallies into a count
    that is added to the total.

    The total is reduced modulo 1000000007 after each length from 3 upwards.
    """
    if mono_str(s):
        return mono_str_result(s)

    size = len(s)
    modulus = 1000000007

    # Length 1: every position starts a palindrome; tally by character.
    single_counts = {}
    for ch in s:
        single_counts[ch] = single_counts.get(ch, 0) + 1
    total = calc_palin_borders(single_counts)

    # Length 2: positions where two neighbouring characters are equal.
    pair_starts = [i for i in range(size - 1) if s[i] == s[i + 1]]
    pair_counts = {}
    for i in pair_starts:
        pair = s[i:i + 2]
        pair_counts[pair] = pair_counts.get(pair, 0) + 1
    total += calc_palin_borders(pair_counts)

    # Start positions of the most recent palindromes, keyed by length parity
    # (1 for odd lengths, 0 for even lengths).
    starts_by_parity = {1: list(range(size)), 0: pair_starts}

    # The full length len(s) is not visited: range() stops one short of it.
    for length in range(3, size):
        parity = length % 2
        grown_starts = []
        grown_counts = {}
        for start in starts_by_parity[parity]:
            left = start - 1
            right = start + length - 2
            # Skip when the extension falls outside s or the ends differ.
            if left < 0 or right >= size or s[left] != s[right]:
                continue
            grown_starts.append(left)
            word = s[left:left + length]
            grown_counts[word] = grown_counts.get(word, 0) + 1

        total += calc_palin_borders(grown_counts)
        total %= modulus
        starts_by_parity[parity] = grown_starts
    return total

if __name__ == '__main__':
    # Read one line from standard input and print its palindromic-border sum.
    text = input()
    print(pb(text))

{“mode”:”full”,”isActive”:false}

Problem solution in Java.

import java.io.*;
import java.util.Arrays;

/**
 * Counts, over all distinct palindromic substrings of the input, the number of
 * unordered pairs of occurrences, using a palindromic tree.
 *
 * <p>The tree has two roots (lengths 0 and -1); every other node is one
 * distinct palindromic substring. Node data lives in parallel static arrays
 * that are re-created on every call to {@link #solve(Input, PrintWriter)}.
 */
public class timus2040 {

    /** Modulus applied to the accumulated answer. */
    static final int MOD = 1000000007;

    /** Number of supported letters ('a'..'h'); sizes the transition table. */
    static final int ALPHABET = 8;

    /** {@code es[c][v]}: node reached from node {@code v} by adding letter {@code c} on both sides, or -1. */
    static int[][] es;

    /**
     * Per-node data: suffix link, palindrome length, and occurrence counter.
     * {@code cnt[v]} first counts how often {@code v} was the node selected for
     * a position; after the final propagation pass it also includes the counts
     * of all nodes whose suffix-link chain passes through {@code v}.
     */
    static int[] slink, len, cnt;

    /** Index of the next unused node slot. */
    static int free;

    /** Allocates a node for a palindrome of length {@code l} and returns its index. */
    static int newNode(int l) {
        len[free] = l;
        return free++;
    }

    /** Returns the child of node {@code i} along letter {@code c}, or -1 if absent. */
    static int get(int i, char c) {
        return es[c - 'a'][i];
    }

    /** Sets the child of node {@code i} along letter {@code c} to node {@code n}. */
    static void set(int i, char c, int n) {
        es[c - 'a'][i] = n;
    }

    /**
     * Reads one token from {@code in} and prints the answer for it to {@code out}.
     *
     * @param in  token source; the first token is the string to analyse
     * @param out destination for the single result line
     * @throws IOException              if reading the token fails
     * @throws IllegalArgumentException if the token contains a character outside 'a'..'h'
     */
    public static void solve(Input in, PrintWriter out) throws IOException {
        String token = in.next();
        if (token == null) {
            // No token at all: there is nothing to count.
            out.println(0);
            return;
        }
        char[] s = token.toCharArray();
        int n = s.length;
        for (char ch : s) {
            if (ch < 'a' || ch >= 'a' + ALPHABET) {
                throw new IllegalArgumentException(
                        "unsupported character (code " + (int) ch + "); expected 'a'..'h'");
            }
        }

        // Restart node allocation so that solve() can be called more than once.
        free = 0;
        es = new int[ALPHABET][n + 2];
        for (int[] ar : es) {
            Arrays.fill(ar, -1);
        }
        len = new int[n + 2];
        slink = new int[n + 2];
        cnt = new int[n + 2];
        int root0 = newNode(0);
        int rootm1 = newNode(-1);
        slink[root0] = slink[rootm1] = rootm1;

        int cur = root0;
        for (int i = 0; i < n; ++i) {
            // Walk suffix links until the palindrome at cur can be extended by s[i]
            // on both sides (i - len[cur] == 0 means there is no character to its left).
            while (i - len[cur] == 0 || s[i] != s[i - len[cur] - 1]) {
                cur = slink[cur];
            }
            if (get(cur, s[i]) == -1) {
                set(cur, s[i], newNode(len[cur] + 2));
                if (cur == rootm1) {
                    // A single letter: its suffix link is the empty palindrome.
                    slink[get(cur, s[i])] = root0;
                } else {
                    int cur1 = slink[cur];
                    while (s[i] != s[i - len[cur1] - 1]) {
                        cur1 = slink[cur1];
                    }
                    slink[get(cur, s[i])] = get(cur1, s[i]);
                }
            }
            cur = get(cur, s[i]);
            cnt[cur]++;
        }

        long ans = 0;
        // Newest node first: push each counter along the suffix link, then add
        // the number of occurrence pairs for every real (positive-length) palindrome.
        for (int i = free - 1; i >= 0; --i) {
            cnt[slink[i]] += cnt[i];
            if (len[i] > 0) {
                ans = (ans + 1L * cnt[i] * (cnt[i] - 1) / 2) % MOD;
            }
        }
        out.println(ans);
    }

    /** Reads the string from standard input and writes the answer to standard output. */
    public static void main(String[] args) throws IOException {
        // try-with-resources flushes and closes the writer even if solve() throws.
        try (PrintWriter out = new PrintWriter(System.out)) {
            solve(new Input(new BufferedReader(new InputStreamReader(System.in))), out);
        }
    }

    /** Minimal whitespace-delimited token reader. */
    static class Input {
        /** Characters that separate tokens: space, newline, carriage return, tab. */
        private static final String WHITESPACE = " \n\r\t";

        BufferedReader in;
        StringBuilder sb = new StringBuilder();

        /** Wraps an existing reader. */
        public Input(BufferedReader in) {
            this.in = in;
        }

        /** Reads tokens from the given string. */
        public Input(String s) {
            this.in = new BufferedReader(new StringReader(s));
        }

        /**
         * Returns the next whitespace-delimited token, or {@code null} when the
         * end of input is reached before any non-whitespace character.
         */
        public String next() throws IOException {
            sb.setLength(0);
            // Skip leading separators; the first other character starts the token.
            while (true) {
                int c = in.read();
                if (c == -1) {
                    return null;
                }
                if (WHITESPACE.indexOf(c) == -1) {
                    sb.append((char) c);
                    break;
                }
            }
            // Collect characters up to the next separator or end of input.
            while (true) {
                int c = in.read();
                if (c == -1 || WHITESPACE.indexOf(c) != -1) {
                    break;
                }
                sb.append((char) c);
            }
            return sb.toString();
        }

        /** Parses the next token as an {@code int}. */
        public int nextInt() throws IOException {
            return Integer.parseInt(next());
        }

        /** Parses the next token as a {@code long}. */
        public long nextLong() throws IOException {
            return Long.parseLong(next());
        }

        /** Parses the next token as a {@code double}. */
        public double nextDouble() throws IOException {
            return Double.parseDouble(next());
        }
    }
}

{“mode”:”full”,”isActive”:false}

Problem solution in C++.

#include <stdio.h>
#include <set>
#include <algorithm>
#include <cstring>
using namespace std;
#define ll long long
#define mod 1000000007
#define L 5000011
// sa[r]: start index of the suffix with rank r (filled by get_suffix_array).
int sa[L];
// sai[i]: rank of the suffix starting at i (inverse of sa once get_suffix_array returns).
int sai[L];
// lcp[r]: common-prefix length of the suffixes of rank r and r-1 (filled by get_lcp).
int lcp[L];
// v: permutation of suffix start indices that gets sorted while building the suffix array.
int v[L];
// s: the input string, read by main.
char s[L];
// ts[i]: combined sort key (rank at i, rank at i+p) for one doubling round.
ll ts[L];
// p[i]: palindrome radius around t[i], filled by manacher.
int p[L<<1];
// t: s with '#' separators interleaved and '^' / '$' sentinels at the ends (built in main).
char t[L<<1];
// m: length of t; n: length of s.
int m, n;
// found: keys of the distinct palindromes that main has already counted.
set<ll> found;
// Orders two start positions by the single character of s found at each.
bool scomp(int i, int j) {
    const char lhs = s[i];
    const char rhs = s[j];
    return lhs < rhs;
}
bool tscomp(int i, int j) {
    return ts[i] < ts[j];
}
// Builds the suffix array of s[0..n) by prefix doubling.
// On return sai[i] holds the 0-based rank of the suffix starting at i and
// sa is filled as its inverse (sa[sai[i]] == i).
void get_suffix_array() {
    for (int i = 0; i < n; i++) v[i] = i;
    // Round 0: rank suffixes by their first character only. Equal characters
    // share a rank; ranks are 1-based here so that 0 can mean "past the end".
    sort(v, v + n, scomp);
    sai[v[0]] = 1;
    for (int i = 1; i < n; i++) {
        if (s[v[i]] == s[v[i - 1]]) {
            sai[v[i]] = sai[v[i - 1]];
        } else {
            sai[v[i]] = i+1;
        }
    }
    // Doubling rounds. NOTE: the loop variable p shadows the global array p.
    for (int p = 1; p <= n; p <<= 1) {
        // Key = (rank at i, rank at i+p) packed into one number; ranks are at
        // most n, so base n+1 keeps the two parts separate.
        for (int i = 0; i < n-p; i++) ts[i] = sai[i] * (n+1LL) + sai[i+p];
        // No position i+p exists for the last p suffixes: second part is 0.
        for (int i = n-p; i < n; i++) ts[i] = sai[i] * (n+1LL);
        sort(v, v + n, tscomp);
        // Re-rank from the sorted order; equal keys keep a shared rank.
        sai[v[0]] = 1;
        for (int i = 1; i < n; i++) {
            if (ts[v[i]] == ts[v[i - 1]]) {
                sai[v[i]] = sai[v[i - 1]];
            } else {
                sai[v[i]] = i+1;
            }
        }
    }
    // Convert ranks to 0-based and invert them into sa.
    for (int i = 0; i < n; i++) sai[i]--;
    for (int i = 0; i < n; i++) sa[sai[i]] = i;
}
// Fills lcp[k] with the length of the common prefix of the suffix of rank k
// and the suffix of rank k-1 (rank 0 is paired with rank n-1).
// Requires sa and sai from get_suffix_array.
void get_lcp() {
    for (int i = 0; i < n; i++) lcp[i] = 0;
    // l is carried over between start positions: after handling position i
    // it is reduced by one and used as the starting match length for i+1.
    int l = 0;
    // The last start position (n-1) is not visited, so its entry keeps the
    // 0 written above.
    for (int i = 0; i < n-1; i++) {
        int k = sai[i];
        int j = k ? sa[k-1] : sa[n-1];
        // Only j+l is bounds-checked; s[i+l] may read the terminating '\0'
        // written by scanf, which cannot match a character inside the string.
        while (j + l < n and s[i + l] == s[j + l]) {
            l++;
        }
        lcp[k] = l;
        if (l > 0) {
            l--;
        }
    }
}
// Computes p[i] for every i in [1, m): the number of positions the palindrome
// centred at t[i] extends to each side within t. Expects t and m to be set up
// by the caller; p[0] is not written.
void manacher() {
    // from wikipedia
    // t has been processed
    // center/end describe the palindrome found so far that reaches furthest
    // to the right: it is centred at `center` and its last position is `end`.
    int center = 0, end = 0, left = 0, right = 0;
    for (int i = 1; i < m; i++) {
        if (i > end) {
            // Outside every known palindrome: start expanding from scratch.
            p[i] = 0;
            left = i - 1;
            right = i + 1;
        } else {
            int j = 2*center - i; // index on the other side
            if (p[j] < end - i) { // whole palindrome is inside
                p[i] = p[j];
                left = -1; // so we don't enter the loop below
            } else { 
                // Reaches at least to `end`; continue expanding just past it.
                p[i] = end - i;
                right = end + 1;
                left = 2*i - right;
            }
        }
        // Extend while the characters on both sides match.
        while (left >= 0 and right < m and t[left] == t[right]) {
            p[i]++;
            left--;
            right++;
        }
        // Remember this palindrome if it reaches further right than any before.
        if (i + p[i] > end) {
            center = i;
            end = i + p[i];
        }
    }
}
// Node of a binary tree over the index range [i, j) of the global lcp array.
// v is the minimum of lcp over that range; p is the parent (NULL for the
// root), l and r are the children (both NULL for a single-element leaf).
struct Node {
    int i, j, v;
    Node *p, *l, *r;
    // Builds the whole subtree for [i, j) recursively, splitting at the midpoint.
    Node(int i, int j, Node *p = NULL): i(i), j(j), p(p) {
        const bool is_leaf = (j - i == 1);
        if (!is_leaf) {
            const int mid = (i + j) >> 1;
            l = new Node(i, mid, this);
            r = new Node(mid, j, this);
            v = (r->v < l->v) ? r->v : l->v;
            return;
        }
        l = NULL;
        r = NULL;
        v = lcp[i];
    }
};
// Searches to the left of index i (inclusive) for the nearest element of the
// array stored in the tree that is smaller than v.
// `node` must be the root of the tree. Returns that index, or 0 when no such
// element exists.
int node_minocc(Node *node, int v, int i) {
    // find maximum j, 0 <= j <= i such that a[j] < v
    // Step 1: descend from the root to the leaf that holds index i.
    while (node->l) {
        if (i < node->l->j) {
            node = node->l;
        } else {
            node = node->r;
        }
    }
    // now node->i = i < node->j = i+1
    if (node->v < v) {
        return node->i;
    }
    // Step 2: climb until some left sibling subtree contains a value < v.
    while (true) {
        // Skip upwards while we are a left child: nothing lies to our left there.
        while (node->p and node->p->l == node) {
            node = node->p;
        }
        if (!node->p) {
            // Reached the root without finding a smaller element.
            return 0;
        }
        node = node->p;
        if (node->l->v < v) {
            node = node->l;
            break;
        }
    }
    // Step 3: descend to the rightmost leaf of that subtree with value < v.
    while (node->l) {
        if (node->r->v < v) {
            node = node->r;
        } else {
            node = node->l;
        }
    }
    return node->i;
}
// Searches to the right of index i (inclusive) for the nearest element of the
// array stored in the tree that is smaller than v.
// `node` must be the root of the tree. Returns that index, or n when no such
// element exists (including the case i == n).
int node_maxocc(Node *node, int v, int i) {
    // find minimum j, i <= j < n such that a[j] < v; n if there is none
    // (equivalently: every a[i..j-1] is >= v)
    if (i == n) {
        return n;
    }
    // Step 1: descend from the root to the leaf that holds index i.
    while (node->l) {
        if (i < node->l->j) {
            node = node->l;
        } else {
            node = node->r;
        }
    }
    // now node->i = i < node->j = i+1
    if (node->v < v) {
        return node->i;
    }
    // Step 2: climb until some right sibling subtree contains a value < v.
    while (true) {
        // Skip upwards while we are a right child: nothing lies to our right there.
        while (node->p and node->p->r == node) {
            node = node->p;
        }
        if (!node->p) {
            // Reached the root without finding a smaller element.
            return n;
        }
        node = node->p;
        if (node->r->v < v) {
            node = node->r;
            break;
        }
    }
    // Step 3: descend to the leftmost leaf of that subtree with value < v.
    while (node->l) {
        if (node->l->v < v) {
            node = node->l;
        } else {
            node = node->r;
        }
    }
    return node->i;
}
// Reads the string, then for every distinct palindromic substring adds
// c*(c-1)/2 to the answer, where c is its number of occurrences, and prints
// the total modulo `mod`.
//
// Palindromes are enumerated with manacher(); occurrences are counted as the
// width of the suffix-array interval of suffixes sharing the palindrome as a
// prefix, located with node_minocc / node_maxocc on the lcp tree.
int main() {
    // Without a token there is nothing to count. This also avoids building a
    // tree over an empty range: Node(0, 0) would recurse without end.
    if (scanf("%s", s) != 1) {
        printf("0\n");
        return 0;
    }
    n = strlen(s);
    // suffix tree
    get_suffix_array();
    get_lcp();
    Node *root = new Node(0, n);
    // Build t = ^ s0 # s1 # ... # s(n-1) $  (letters sit at odd indices).
    m = 0;
    for (int i = 0; i < n; i++) {
        t[m++] = '#';
        t[m++] = s[i];
    }
    t[m++] = '$';
    t[0] = '^';
    manacher();
    ll ans = 0;
    for (int i = 1; i < m-1; i++) {
        int k = p[i];
        // Shrink so that the palindrome starts on a letter, not a separator.
        if (t[i-k] == '#') k--;
        // t[i-k .. i+k] now holds k+1 letters; peel two letters per step to
        // visit the shorter palindromes with the same centre.
        for(; k >= 0; k -= 2) {
            int start = sai[(i - k) >> 1];
            int mino = node_minocc(root, k + 1, start);
            // (length, first rank of its suffix-array interval) identifies the substring.
            ll hsh = k*(n+1LL)+mino;
            if (found.find(hsh) != found.end()) {
                // Already counted; the shorter palindromes around this centre
                // were handled when this one was first seen.
                break;
            }
            found.insert(hsh);
            int maxo = node_maxocc(root, k + 1, start + 1);
            ll c = maxo - mino;
            ans += c*(c-1);
            ans %= mod;
        }
    }
    // Halve modulo `mod` by multiplying with (mod+1)/2, the inverse of 2.
    ans = ans * ((mod + 1) >> 1) % mod;
    printf("%lld\n", ans);
    return 0;
}

{“mode”:”full”,”isActive”:false}