HackerRank Max Transform problem solution YASH PAL, 31 July 2024 In this HackerRank Max Transform problem solution, we are given an array A and need to find the sum of the elements of S(S(A)), the max transform of the max transform of the array. Since the answer can be very large, we report it modulo 10^9 + 7. Problem solution in Python programming. #!/bin/python3 import math import os import random import re import sys # Complete the solve function below. import math import os import random import re import sys sys.setrecursionlimit(9999999) from decimal import Decimal def t1(n): return Decimal(n * (n + 1) / 2) def t2(n): return Decimal(n * (n + 1) * (n + 2) / 6) def u2(n): return Decimal(n * (n + 2) * (2 * n + 5) / 24) def countzip(a, b): return u2(a + b) - u2(abs(a - b)) + t2(abs(a - b)) def countends(x, n, ex): return countzip(n, ex) - countzip(x, ex) - countzip(n - 1 - x, 0) def countsplit(x, n): return t1(t1(n)) - t1(x) - countzip(n - x - 1, x - 1) K = 20 lg = [0] * (1 << K) for i in range(K): lg[1 << i] = i for i in range(1, 1 << K): lg[i] = max(lg[i], lg[i - 1]) def make_rangemax(A): n = len(A) assert 1 << K > n key = lambda x: A[x] mxk = [] mxk.append(range(n)) for k in range(K - 1): mxk.append(list(mxk[-1])) for i in range(n - (1 << k)): mxk[k + 1][i] = max( mxk[k][i], mxk[k][i + (1 << k)], key=key) def rangemax(i, j): k = lg[j - i] return max(mxk[k][i], mxk[k][j - (1 << k)], key=key) return rangemax def brutesolo(A): rangemax = make_rangemax(A) stack = [(0, len(A))] ans = 0 while stack: i, j = stack.pop() if i != j: x = rangemax(i, j) stack.append((i, x)) stack.append((x + 1, j)) ans += A[x] * (x - i + 1) * (j - x) return ans def make_brute(A): rangemax = make_rangemax(A) def brute(i, j): stack = [(i, j)] ans = 0 while stack: i, j = stack.pop() if i != j: x = rangemax(i, j) stack.append((i, x)) stack.append((x + 1, j)) ans += A[x] * countends(x - i, j - i, 0) return ans return brute, rangemax def ends(A, B): brutea, rangemaxa =
// Remainder of the Python listing that begins on the previous line (it shares this source line with the Java listing and is kept verbatim):
// make_brute(A) bruteb, rangemaxb = make_brute(B) stack = [(len(A), len(B))] ans = 0 while stack: i, j = stack.pop() if i == 0: ans += bruteb(0, j) elif j == 0: ans += brutea(0, i) else: x = rangemaxa(0, i) y = rangemaxb(0, j) if A[x] < B[y]: ans += bruteb(y + 1, j) ans += B[y] * countends(y, j, i) stack.append((i, y)) else: ans += brutea(x + 1, i) ans += A[x] * countends(x, i, j) stack.append((x, j)) return ans def maxpairs(a): return [max(x, y) for x, y in zip(a, a[1:])] def solve(A): n = len(A) x = max(range(n), key=lambda x: A[x]) return (int((brutesolo(A[:x]) + ends(A[x + 1:][::-1], maxpairs(A[:x])) + A[x] * countsplit(x, n))%(10**9+7))) if __name__ == '__main__': fptr = open(os.environ['OUTPUT_PATH'], 'w') n = int(input()) A = list(map(int, input().rstrip().split())) result = solve(A) fptr.write(str(result) + 'n') fptr.close()
//
// Problem solution in Java Programming.
import java.io.*;
import java.math.*;
import java.security.*;
import java.text.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.regex.*;

/**
 * Computes the sum of the elements of S(S(A)) modulo 10^9+7.
 *
 * <p>The array is treated as circular (indices are wrapped by {@link #normalize}) and is
 * partitioned into "plateaus": maximal index ranges that currently share one value. Starting
 * from the plateau at index 0, the algorithm repeatedly merges equal neighbours, raises a
 * plateau that is lower than both neighbours to the smaller neighbour value, or moves on to a
 * lower neighbour. Every raise adds a correction to {@code subtract}; when a single plateau is
 * left the answer is {@code value * totalCount(n) + subtract}.
 */
public class Solution {

    /** Modulus applied to all results (10^9 + 7). */
    static final int SUM_DIV = 1000000007;

    /** Immutable index range {@code [start, end]} (possibly wrapping) whose cells share value {@code v}. */
    static class Plateau {
        final int start;
        final int end;
        final int v;

        Plateau(int start, int end, int v) {
            this.start = start;
            this.end = end;
            this.v = v;
        }

        @Override
        public String toString() {
            return new StringJoiner(", ", "[", "]")
                    .add("start=" + start)
                    .add("end=" + end)
                    .add("v=" + v)
                    .toString();
        }
    }

    /**
     * Returns the sum of S(S(A)) modulo 10^9+7.
     *
     * @param input the array A; an empty array yields 0
     * @return the sum reduced into {@code [0, SUM_DIV)}
     */
    static int solve(int[] input) {
        if (input.length == 0) {
            // Nothing to transform; without this guard the plateau lookup below would be null.
            return 0;
        }

        // Plateaus indexed by their first and by their last position, so that the neighbour
        // before cur.start and the neighbour after cur.end can each be found with one lookup.
        final Map<Integer, Plateau> mapStart = new HashMap<>(input.length * 2);
        final Map<Integer, Plateau> mapEnd = new HashMap<>(input.length * 2);
        for (int i = 0; i < input.length; ++i) {
            Plateau p = new Plateau(i, i, input[i]);
            mapStart.put(i, p);
            mapEnd.put(i, p);
        }

        long subtract = 0; // accumulated (non-positive) correction, kept reduced modulo SUM_DIV
        Plateau cur = mapStart.remove(0); // the plateau being processed is held outside both maps
        mapEnd.remove(0);

        for (;;) {
            if (mapStart.isEmpty()) {
                // Only the current plateau is left: it spans the whole array.
                long total = totalCount(input.length);
                long result = ((((long) cur.v) * total + subtract) + SUM_DIV) % SUM_DIV;
                return (int) result;
            }

            Plateau prev = mapEnd.get(normalize(cur.start - 1, input));
            if (prev.v == cur.v) {
                // Same height on the left: absorb the neighbour into the current plateau.
                cur = new Plateau(prev.start, cur.end, cur.v);
                mapStart.remove(prev.start);
                mapEnd.remove(prev.end);
                continue;
            }

            Plateau next = mapStart.get(normalize(cur.end + 1, input));
            if (next.v == cur.v) {
                // Same height on the right: absorb the neighbour into the current plateau.
                cur = new Plateau(cur.start, next.end, cur.v);
                mapStart.remove(next.start);
                mapEnd.remove(next.end);
                continue;
            }

            if (next.v > cur.v && prev.v > cur.v) {
                // Local minimum: raise it to the lower neighbour and record the correction.
                int nextV = Math.min(next.v, prev.v);
                long delta = (long) (nextV - cur.v);
                if (cur.end >= cur.start) {
                    delta *= calculateCounts(normalize(cur.end - cur.start + 1, input));
                } else {
                    // The plateau wraps past the last index back to the front of the array.
                    delta *= countInverse(
                            input.length - cur.start,
                            normalize(cur.end + 1 - cur.start, input));
                }
                subtract -= delta;
                subtract %= SUM_DIV;
                cur = new Plateau(cur.start, cur.end, nextV);
                continue;
            }

            // At least one neighbour is lower: park the current plateau and continue from the
            // lower neighbour (the previous one is preferred when it is lower).
            Plateau successor = prev.v < cur.v ? prev : next;
            mapStart.remove(successor.start);
            mapEnd.remove(successor.end);
            mapStart.put(cur.start, cur);
            mapEnd.put(cur.end, cur);
            cur = successor;
        }
    }

    /** Wraps {@code idx} into {@code [0, input.length)} for indices no lower than {@code -input.length}. */
    private static int normalize(int idx, int[] input) {
        return (idx + input.length) % input.length;
    }

    /**
     * Returns T(T(n)) modulo SUM_DIV, where T(k) = k(k+1)/2.
     *
     * <p>Halving after the first reduction is still exact: the product of two consecutive
     * integers is even and SUM_DIV is odd.
     */
    private static long totalCount(long n) {
        long s1Size = n * (n + 1) / 2 % SUM_DIV;
        long s2Size = (s1Size * (s1Size + 1) / 2) % SUM_DIV;
        return s2Size;
    }

    /** Returns n(n+1)(n+2)/6 modulo SUM_DIV, the weight applied to a non-wrapping plateau of length n. */
    private static long calculateCounts(long n) {
        return (n * n * n + 3 * n * n + 2 * n) / 6 % SUM_DIV;
    }

    /**
     * Weight applied to a plateau that wraps around the end of the array.
     *
     * <p>NOTE(review): the closed form is taken as-is from the original listing, which gives no
     * derivation; confirm against a brute force before changing it.
     *
     * @param c1 number of cells from the plateau start to the end of the array
     * @param l total length of the wrapping plateau
     */
    private static long countInverse(long c1, long l) {
        if (c1 <= l / 2) {
            return (-4 * c1 * c1 * c1 + l * l * l + 6 * c1 * c1 * l - 3 * c1 * l * l
                    - 3 * c1 * l + 3 * l * l - 2 * c1 + 2 * l) / 6 % SUM_DIV;
        } else {
            // Reuse the first branch for the mirrored split and adjust by two triangular numbers.
            return ((countInverse(l - c1 - 1, l) - temp(c1 + 1) % SUM_DIV
                    + temp(l - c1) % SUM_DIV) + SUM_DIV) % SUM_DIV;
        }
    }

    /** Returns the triangular number n(n+1)/2. */
    private static long temp(long n) {
        return (n * n + n) / 2;
    }

    private static final Scanner scanner = new Scanner(System.in);

    /**
     * Reads {@code n} and then {@code n} integers from standard input and writes the answer to
     * the file named by the {@code OUTPUT_PATH} environment variable.
     *
     * <p>Values are read token by token, so the result does not depend on how line terminators
     * or repeated spaces appear in the input.
     */
    public static void main(String[] args) throws IOException {
        try (BufferedWriter bufferedWriter =
                new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
            int n = scanner.nextInt();
            int[] A = new int[n];
            for (int i = 0; i < n; i++) {
                A[i] = scanner.nextInt();
            }

            int result = solve(A);

            bufferedWriter.write(String.valueOf(result));
            bufferedWriter.newLine();
        } finally {
            scanner.close();
        }
    }
}
// Problem solution in C++ programming.
#include <cstdio>
#include <iostream>
#include <sstream>
#include <deque>
#include <queue>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <cstdlib>
#include <ctime>
using namespace std;

#define P 1000000007  // modulus for every reported value
#define N 1100000     // capacity of the position- and value-indexed arrays

// used[i]: position i has been activated by add().
// fa[i] / sum[i]: union-find parent and, at a representative, the size counter of its group.
// f[k]: per-group contribution for a group of size k, filled in main().
// now: sum of f[size] over all current groups; cc: value derived from it by add().
// T: T(T(n)) mod P with T(k) = k(k+1)/2; ans: running answer.
int used[N], fa[N], sum[N], f[N], now, ans, T, cc;
vector <int> V[N];  // V[v]: positions whose value equals v
int n;
int a[N];           // input values, 1-indexed

// Returns the representative of x's group, pointing fa[x] straight at it on the way back.
int gf(int x) {
    if (fa[x] != x) fa[x] = gf(fa[x]);
    return fa[x];
}

// Attaches y's group below x's group and adds its size counter to x's representative.
void merge(int x, int y) {
    x = gf(x);
    y = gf(y);
    sum[x] += sum[y];
    fa[y] = x;
}

// Activates position x, joins it with activated neighbours, and refreshes `now` and `cc`.
void add(int x) {
    used[x] = 1;
    sum[x] = 1;
    if (used[x - 1]) {
        // The neighbouring group stops existing on its own: retract its contribution first.
        now = (now - f[sum[gf(x - 1)]] + P) % P;
        merge(x, x - 1);
    }
    if (used[x + 1]) {
        now = (now - f[sum[gf(x + 1)]] + P) % P;
        merge(x, x + 1);
    }
    now = (now + f[sum[gf(x)]]) % P;

    // Size counters stored at the representatives of positions 1 and n; each is 0 while that
    // position has not been activated, because the arrays are zero-initialised globals.
    int L = sum[gf(1)], R = sum[gf(n)];
    x = min(R, L - 1);
    if (x <= 0) {
        cc = now;
        return;
    }
    // Correction term combining the groups that touch both ends of the array.
    // NOTE(review): the listing gives no derivation for this closed form.
    cc = now;
    cc = (cc + 1LL * x * L * (R + 1)) % P;
    cc = (cc - 1LL * x * (x + 1) / 2 % P * (L + R + 1)) % P;
    cc = (cc + 1LL * x * (x + 1) * (2 * x + 1) / 6) % P;
    cc = (cc + P) % P;  // the subtraction above may have left cc negative
    return;
}

// Reads n and a[1..n], then sums T - cc over the thresholds 0 .. max(a) - 1 and prints the
// result modulo P followed by a newline.
int main() {
    scanf("%d", &n);
    int ma = 0;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        V[a[i]].push_back(i);
        ma = max(ma, a[i]);
    }

    T = 1LL * n * (n + 1) / 2 % P;
    T = 1LL * T * (T + 1) / 2 % P;

    // f[i] = (sum of k^2 + sum of k) / 2 for k = 1..i, i.e. the sum of k(k+1)/2.
    for (int i = 1; i <= n; i++)
        f[i] = (1LL * i * (i + 1) * (2 * i + 1) / 6 + 1LL * i * (i + 1) / 2) / 2 % P;
    for (int i = 1; i <= n; i++) fa[i] = i;

    now = 0;
    for (int i = 0; i < ma; i++) {
        // Activate every position whose value is exactly i before accounting for threshold i.
        // presumably cc then counts the entries of S(S(A)) that do not exceed i - TODO confirm
        for (int j = 0; j < (int) V[i].size(); j++) add(V[i][j]);
        ans = (ans + T - cc) % P;
    }
    ans = (ans + P) % P;

    // The format string must end in a real newline escape; "%dn" printed a literal 'n'.
    printf("%d\n", ans);
    return 0;
}
// Problem solution in C programming.
/*
 * Reads N and a[1..N] from standard input and prints one value modulo 1000000007.
 * The value is accumulated in ANS by calc/calcl/calcr as "weight * count"; whatever
 * count remains in CNT afterwards is weighted by the largest input value MX.
 */
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse4")
#include<stdio.h>
#include<string.h>
#include<stdlib.h>

/* mod is the modulus; 2 * _2 = 1000000008 = mod + 1, so _2 is the inverse of 2 modulo mod. */
const int mod = 1000000007, _2 = 500000004;

/*
 * a[]   : input values, 1-indexed          i_1[i]: (1 + 2 + ... + i) % mod
 * st[]  : index stack with top tp          mxl/mxr: per-index range bounds built in main
 * sxl[i]: max of a[1..i]                   sxr[i]: max of a[i..N]
 */
int N, MX = 0, tp, a[200010], i_1[200010], st[200010], mxl[200010], mxr[200010], sxl[200010], sxr[200010];

/* M: initial count; CNT: count not yet attributed to a weight; ANS: running answer. */
long long M, CNT, ANS = 0;

/*
 * Adds w * k to ANS and removes k from CNT, where k is computed from the two lengths
 * x and y after ordering them so that x >= y.
 */
void calc(int w, int x, int y) {
    if( x < y ) { int temp = x; x = y; y = temp; }
    int k;
    if( x == y ) {
        k = ( ( (long long)( x + y ) * i_1[y] % mod - (long long)x * x % mod ) % mod + mod ) % mod;
    } else {
        /* i_1 values are already reduced, so the difference may be negative; "+ mod" below repairs it. */
        k = ( ( (long long)y * ( i_1[x-1] - i_1[y] ) % mod + (long long)( x + y ) * i_1[y] % mod ) % mod + mod ) % mod;
    }
    ANS = ( ANS + (long long)w * k ) % mod;
    CNT -= k;
    if( CNT < 0 ) { CNT += mod; }
}

/*
 * Same bookkeeping as calc for the pass that walks the prefix maxima; does nothing
 * when x == 1 or y == 0.
 */
void calcl(int w, int x, int y) {
    if( x == 1 || y == 0 ) { return; }
    int k;
    if( y < x ) {
        k = i_1[y];
    } else {
        k = ( i_1[x-1] + (long long)( y - x + 1 ) * ( x - 1 ) ) % mod;
    }
    ANS = ( ANS + (long long)w * k ) % mod;
    CNT -= k;
    if( CNT < 0 ) { CNT += mod; }
}

/*
 * Same bookkeeping as calc for the pass that walks the suffix maxima; does nothing
 * when x == 0 or y == 1.
 */
void calcr(int w, int x, int y) {
    if( x == 0 || y == 1 ) { return; }
    int k;
    if( y + 1 <= x ) {
        k = i_1[y-1];
    } else {
        k = ( i_1[x] + (long long)( y - x - 1 ) * x ) % mod;
    }
    ANS = ( ANS + (long long)w * k ) % mod;
    CNT -= k;
    if( CNT < 0 ) { CNT += mod; }
}

int main() {
    int p;
    scanf("%d", &N);
    for( int i = 1 ; i <= N ; i++ ) {
        scanf("%d", &a[i]);
        MX = MX > a[i] ? MX : a[i];
    }

    /* M = T(T(N)) mod `mod` with T(k) = k(k+1)/2; the second halving uses the inverse _2. */
    M = ( (long long)N * ( N + 1 ) >> 1 ) % mod;
    M = (long long)M * ( M + 1 ) % mod * _2 % mod;
    CNT = M;

    for( int i = 1 ; i <= N ; i++ ) { i_1[i] = ( i_1[i-1] + i ) % mod; }
    for( int i = 1 ; i <= N ; i++ ) { sxl[i] = sxl[i-1] > a[i] ? sxl[i-1] : a[i]; }
    for( int i = N ; i ; i-- ) { sxr[i] = sxr[i+1] > a[i] ? sxr[i+1] : a[i]; }

    /* mxl[i]: one past the nearest index on the left holding a value > a[i], or 1 if none. */
    tp = 0;
    for( int i = 1 ; i <= N ; i++ ) {
        while( tp > 0 && a[st[tp]] <= a[i] ) { tp--; }
        if(tp) { mxl[i] = st[tp] + 1; } else { mxl[i] = 1; }
        st[++tp] = i;
    }

    /* mxr[i]: one before the nearest index on the right holding a value >= a[i], or N if none. */
    tp = 0;
    for( int i = N ; i ; i-- ) {
        while( tp > 0 && a[st[tp]] < a[i] ) { tp--; }
        if(tp) { mxr[i] = st[tp] - 1; } else { mxr[i] = N; }
        st[++tp] = i;
    }

    /* Per-element contribution, weighted by a[i], from the extents to its left and right. */
    for( int i = 1 ; i <= N ; i++ ) { calc(a[i], i-mxl[i]+1, mxr[i]-i+1); }

    /* Prefix-maximum pass: p moves down while sxr[p] < sxl[i] but is never left below i. */
    p = N;
    for( int i = 1 ; i <= N ; i++ ) {
        int g = sxl[i];
        while( p > i && sxr[p] < g ) { p--; }
        while( p < i ) { p++; }
        calcl(g, i, N-p);
    }

    /* Suffix-maximum pass: p moves up while sxl[p] <= sxr[i] but is never left above i. */
    p = 1;
    for( int i = N ; i ; i-- ) {
        int g = sxr[i];
        while( p < i && sxl[p] <= g ) { p++; }
        while( p > i ) { p--; }
        calcr(g, N-i+1, p-1);
    }

    /* Everything still counted in CNT is weighted by the global maximum. */
    CNT = ( CNT % mod + mod ) % mod;
    ANS = ( ANS + (long long)CNT * MX ) % mod;
    printf("%lld", ANS);
    return 0;
}
coding problems data structure