Editorial for Problem B

Aizen's Plan

Author: naveentummala033

Required Knowledge : Strings, Trie

Time Complexity : $O(\text{Sum of all string lengths})$

Editorialist : naveentummala033

Approach:

According to the given operation, two strings can be made equal if and only if one string is a prefix or a suffix of the other, because the operation deletes any number of characters from one end of a string. Thus, for each string, we walk over its characters from the front and check whether the prefix read so far is one of the given strings; if it is, we add $1$. Only proper prefixes are considered, so a string is never matched with itself. This covers prefixes. For suffixes, we reverse all strings and apply the same approach.

If we do this using a map of strings, we get TLE because the strings are long. Every prefix would have to be built and hashed separately, which is quadratic in the string length. So we use a trie instead. Since the given strings are distinct, each trie node keeps a flag that is $1$ if one of the given strings ends at that node. The same is done with a second trie built from the reversed strings.

One edge case is that the same string can be both a prefix and a suffix; for example, "ab" is both a prefix and a suffix of "abab". In that case we must add $1$ instead of $2$. This can be done in a single loop once both tries have been built. At step $j$ we take the $j$-th character of the string and advance in the forward trie. We also take the $j$-th character of the reversed string and advance in the reverse trie. If both nodes have flag $1$, we compare the prefix read so far with the reverse of the characters read from the back, i.e. with the suffix of the same length. If they are equal we add $1$; otherwise we add $2$.

Setter's Code:

// Author : Naveen

// Program Start
// Libraries and Namespace Start
#include <bits/stdc++.h>
using namespace std;
// Libraries and Namespace End

//----------------------------------------------------------------

// Important Shortcuts Start

// Declarations Start
// Short aliases used throughout the solution.
using ll = long long int;
using ull = unsigned long long int;
using str = string;
// Declarations End

// Constants Start
constexpr char spc = ' ';    // single-space separator
constexpr char newl = '\n';  // newline without std::endl's flush
// Constants End

//----------------------------------------------------------------
// Node of a 26-ary trie over lowercase letters 'a'..'z'.
// child[c] is the node reached by appending letter 'a' + c (nullptr if none).
// count is incremented by Solution for every inserted word ending here.
// Both members have default initialisers, so even a default-initialised node
// (`TrieNode t;` or `new TrieNode`) starts with null children and a zero count
// instead of indeterminate pointers.
class TrieNode {
public:
    TrieNode* child[26] = {};
    long long count = 0;  // same type as the file's `ll` alias
};
// Solution Class Start
class Solution {
public:
    void solve(ull index) {
        //----------------------------------------------------------------

        ll n, i, j, k;
        cin >> n;
        vector<str> arr(n);
        for (auto& s : arr) {
            cin >> s;
        }
        vector<str> rarr(arr);
        for (auto& s : rarr) {
            reverse(s.begin(), s.end());
        }
        TrieNode* root = new TrieNode();
        for (auto s : arr) {
            TrieNode* node = root;
            for (char c : s) {
                if (!(node->child[c - 'a'])) {
                    node->child[c - 'a'] = new TrieNode();
                }
                node = node->child[c - 'a'];
            }
            node->count++;
        }
        TrieNode* rroot = new TrieNode();
        for (auto s : rarr) {
            TrieNode* node = rroot;
            for (char c : s) {
                if (!(node->child[c - 'a'])) {
                    node->child[c - 'a'] = new TrieNode();
                }
                node = node->child[c - 'a'];
            }
            node->count++;
        }
        ll ans = 0;
        for (i = 0;i < n;i++) {
            TrieNode* node = root;
            TrieNode* rnode = rroot;
            ll cou = 0;
            str s = "", rs = "";
            for (j = 0;j < arr[i].size() - 1;j++) {
                s += arr[i][j];
                rs += rarr[i][j];
                node = node->child[arr[i][j] - 'a'];
                rnode = rnode->child[rarr[i][j] - 'a'];
                cou += node->count + rnode->count;
                if (node->count && rnode->count) {
                    ll flag = true;
                    for (k = 0;k < s.size();k++) {
                        if (s[k] != rs[s.size() - 1 - k]) {
                            flag = false;
                            break;
                        }
                    }
                    if (flag) {
                        cou--;
                    }
                }
            }
            node = node->child[arr[i][arr[i].size() - 1] - 'a'];
            ans += cou;
        }
        cout << ans << newl;

        // cout << "Case #" << index << ": " << ans << newl;

        //----------------------------------------------------------------
    }

    bool test_cases = false;
};
// Solution Class End

// Main Function Start
int main() {
    Solution sol;
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    ull test_cases = 1;
    if (sol.test_cases) {
        cin >> test_cases;
    }
    for (ull test_case = 1; test_case <= test_cases; ++test_case) {
        sol.solve(test_case);
    }
    return 0;
}
// Main Function End
// Program End
//----------------------------------------------------------------

Tester's Code:

// Author: Sreekar Vyas
#include <bits/stdc++.h>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Ordered set of unique T backed by a GNU pb_ds red-black tree. The
// order-statistics node update adds find_by_order / order_of_key queries.
template <typename T>
using pbds = __gnu_pbds::tree<T,                      // key type
                              __gnu_pbds::null_type,  // no mapped value: a set
                              std::less<T>,           // ascending order
                              __gnu_pbds::rb_tree_tag,
                              __gnu_pbds::tree_order_statistics_node_update>;
// #define cerr if(false)cerr
// Competitive-programming shorthand macros. These are plain token
// substitutions: every later occurrence of the bare identifier is rewritten.

// Every `int` after this line is really `long long`, so the arithmetic below
// is 64-bit. This is also why the entry point is declared `signed main()`.
#define int long long
#define pb push_back
// F / S: pair-member shorthands (p.F is p.first, p.S is p.second).
#define F first
#define S second
// Print "Yes" / "No" followed by a newline to stdout.
#define yes cout << "Yes\n"
#define no cout << "No\n"
// NOTE(review): neither the argument nor the expansion is parenthesised, so
// use this only as a full statement with a simple condition, e.g. `yn(ok);`.
// An assignment argument or an enclosing expression would bind unexpectedly.
#define yn(x) x ? yes : no
// f(i, s, e): loop i over the half-open range [s, e); i is long long (see int).
#define f(i, s, e) for (int i = s; i < e; i++)
// vi / pii expand through the redefined int, so they hold long long values.
#define vi vector<int>
#define vb vector<bool>
#define pii pair<int, int>
#define vpi vector<pii>
#define all(x) x.begin(), x.end()
#define minele(x) *min_element(all(x))
#define maxele(x) *max_element(all(x))
// endl becomes a plain '\n', so output is not flushed on every line.
#define endl '\n'
// in(v, x): membership test with binary_search; v must already be sorted.
#define in(v, x) binary_search(all(v), x)

// Numeric constants. They are spelled as long long explicitly; the `int`
// macro already makes them long long, so the type is unchanged. inf and minf
// hold the full 64-bit extremes.
constexpr long long N = 200000;        // 2e5
constexpr long long MOD = 1000000007;  // 1e9 + 7, prime modulus
constexpr long long inf = LLONG_MAX;
constexpr long long minf = LLONG_MIN;

// Debug-printing macros. Off the judge, debug(x) / dbg(a, b, ...) write
// "<source text> is <value(s)>" to stderr via _print (defined later in this
// file). When ONLINE_JUDGE is defined they expand to nothing.
// The two statements are wrapped in do { ... } while (0) so the macro acts as
// a single statement, e.g. inside an unbraced `if`.
#ifndef ONLINE_JUDGE
#define debug(x)                \
    do {                        \
        cerr << (#x) << " is "; \
        _print(x);              \
    } while (0)
#define dbg(x...)               \
    do {                        \
        cerr << (#x) << " is "; \
        _print(x);              \
    } while (0)
#else
// Exactly one definition of each macro. The old branch defined dbg twice with
// different parameter lists, which is a macro redefinition.
#define debug(x)
#define dbg(x...)
#endif

// Debug printer for one value: writes it to stderr with operator<< and no
// trailing newline. Overloads for pairs and standard containers are declared
// later in this file.
template <typename T>
void _print(T a) {
    cerr << a;
}

// Writes ", value" for each remaining argument, with no newline. This is a
// helper for the variadic _print below.
inline void _print_tail() {}
template <typename T1, typename... T2>
void _print_tail(T1 t1, T2... t2) {
    cerr << ", " << t1;
    _print_tail(t2...);
}

// Debug printer for two or more values: writes them comma-separated to stderr
// and ends the line exactly once. The newline is emitted only at this
// outermost level; when each recursion level added its own, k arguments
// produced k - 1 newlines.
template <typename T1, typename... T2>
void _print(T1 t1, T2... t2) {
    cerr << t1;
    _print_tail(t2...);
    cerr << endl;
}
// Writes `a` followed by a single space to stdout (no newline).
template <typename T>
void print(T a) {
    cout << a;
    cout << ' ';
}
// Writes `a` followed by a newline to stdout.
template <typename T>
void println(T a) {
    cout << a;
    cout << endl;
}
// Reads a.size() whitespace-separated values into a. The vector must already
// have the desired size. Returns the stream for chaining.
template <class T>
istream &operator>>(istream &is, vector<T> &a) {
    for (size_t idx = 0; idx < a.size(); ++idx) {
        is >> a[idx];
    }
    return is;
}
// Writes every element followed by one space, so there is a trailing space
// and no newline. Returns the stream for chaining.
template <class T>
ostream &operator<<(ostream &os, const vector<T> &a) {
    for (auto it = a.begin(); it != a.end(); ++it) {
        os << *it << ' ';
    }
    return os;
}

// _print overloads for pairs and standard containers, used by debug()/dbg().
// All of them are declared before any is defined. Inside a template, an
// unqualified call only sees overloads declared before that template, plus
// ADL, which here only searches namespace std. Nested containers such as
// vector<pair<...>> or map<K, unordered_map<...>> therefore need these
// declarations up front. The unordered_map overload was previously missing
// from this list.
// Containers are taken by const reference so printing does not copy them.
template <class T, class V>
void _print(const pair<T, V> &p);
template <class T>
void _print(const vector<T> &v);
template <class T>
void _print(const set<T> &v);
template <class T, class V>
void _print(const map<T, V> &v);
template <class T>
void _print(const multiset<T> &v);
template <class T, class V>
void _print(const unordered_map<T, V> &v);

// Prints any iterable as "[ e1 e2 ... ]" followed by a newline. Each element
// is printed with the matching _print overload.
template <class Range>
void _print_range(const Range &r) {
    cerr << "[ ";
    for (const auto &e : r) {
        _print(e);
        cerr << " ";
    }
    cerr << "]";
    cerr << endl;
}

// Prints a pair as "{first,second} ".
template <class T, class V>
void _print(const pair<T, V> &p) {
    cerr << "{";
    _print(p.first);
    cerr << ",";
    _print(p.second);
    cerr << "} ";
}
template <class T>
void _print(const vector<T> &v) {
    _print_range(v);
}
template <class T>
void _print(const set<T> &v) {
    _print_range(v);
}
template <class T>
void _print(const multiset<T> &v) {
    _print_range(v);
}
template <class T, class V>
void _print(const map<T, V> &v) {
    _print_range(v);
}
template <class T, class V>
void _print(const unordered_map<T, V> &v) {
    _print_range(v);
}

// Speeds up iostream I/O. It stops synchronising the C++ standard streams with
// C stdio and unties cin and cout from any other stream. Must be called before
// any input/output is performed.
void fast() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
}

// Modular arithmetic helpers modulo the prime MOD. Here `int` is `long long`
// via `#define int long long`, so the product of two residues (< MOD^2, about
// 1e18) fits without overflow.
//
// Every helper first reduces its inputs to the canonical range [0, MOD). C++
// `%` keeps the sign of the dividend, so the old code returned negative
// residues for negative arguments. Results are now always in [0, MOD).
// Results are unchanged for non-negative arguments.

// Returns x reduced into [0, MOD).
int mnorm(int x) {
    x %= MOD;
    return x < 0 ? x + MOD : x;
}

// Returns a^b mod MOD by binary exponentiation. Precondition: b >= 0.
// A negative b never reaches zero under `b >>= 1`.
int binpow(int a, int b) {
    int ans = 1;
    a = mnorm(a);
    while (b) {
        if (b & 1) ans = ans * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return ans;
}
// Returns (a + b) mod MOD.
int add(int a, int b) { return (mnorm(a) + mnorm(b)) % MOD; }
// Returns (a - b) mod MOD.
int sub(int a, int b) { return (mnorm(a) - mnorm(b) + MOD) % MOD; }
// Returns (a * b) mod MOD.
int mul(int a, int b) { return mnorm(a) * mnorm(b) % MOD; }
// Returns a / b mod MOD as a * b^(MOD-2), by Fermat's little theorem since
// MOD is prime. b must not be divisible by MOD.
int mdiv(int a, int b) { return mul(a, binpow(b, MOD - 2)); }

// Polynomial rolling hash of a lowercase string. Letters map to 1..26, the
// base is p = 31, and everything is taken modulo MOD:
//   hash[i] = sum_{k <= i} (s[k] - 'a' + 1) * p^k   (mod MOD)
// hashval(L, R) divides out p^L, so equal substrings at different positions
// get equal values.
class PolyHash {
   public:
    vector<int> hash, powers, mmi;  // prefix hashes, p^i, (p^i)^{-1} mod MOD
    int n;
    string s;
    const int p = 31;

    // The const reference also accepts temporaries. For an empty string every
    // table stays empty instead of writing hash[0] out of bounds.
    PolyHash(const string &s) {
        this->n = s.length();
        this->s = s;
        hash.assign(n, 0);
        powers.assign(n, 1);
        mmi.assign(n, 1);
        if (n == 0) return;

        // One modular inverse instead of one exponentiation per position:
        // (p^i)^{-1} = (p^{-1})^i (mod MOD).
        const int inv_p = mdiv(1, p);
        for (int i = 1; i < n; i++) {
            powers[i] = mul(powers[i - 1], p);
            mmi[i] = mul(mmi[i - 1], inv_p);
        }

        hash[0] = s[0] - 'a' + 1;
        for (int i = 1; i < n; i++) {
            hash[i] = add(hash[i - 1], mul(s[i] - 'a' + 1, powers[i]));
        }
    }

    // returns hash value of substring s[L..R], inclusive on both ends
    // (requires 0 <= L <= R < n)
    int hashval(int L, int R) {
        if (L == 0) return hash[R];
        int ans = sub(hash[R], hash[L - 1]);
        ans = mul(ans, mmi[L]);
        return ans;
    }
};

// Maps each lowercase letter to its alphabet index ('a' -> 0, ..., 'z' -> 25).
// It is filled by cmp() and used to choose trie child slots.
unordered_map<char, int> mp;
void cmp() {
    for (int idx = 0; idx < 26; ++idx) {
        mp[static_cast<char>('a' + idx)] = idx;
    }
}

// 26-ary trie over lowercase letters. children[k] is the node reached by
// appending the letter whose index in the global `mp` is k, or nullptr.
// cnt is the number of inserted words that end exactly at this node.
//
// A Trie owns every node below it and frees them in its destructor, so tries
// declared as locals (as in solve()) release their nodes. Copying is
// disabled because two copies would own, and free, the same nodes.
class Trie {
   public:
    Trie *children[26] = {};
    int cnt = 0;

    Trie() = default;
    Trie(const Trie &) = delete;
    Trie &operator=(const Trie &) = delete;

    // Frees all descendants with an explicit stack, not recursion, so a deep
    // trie built from a long string cannot overflow the call stack.
    ~Trie() {
        vector<Trie *> pending;
        for (Trie *&child : children) {
            if (child) pending.push_back(child);
            child = nullptr;
        }
        while (!pending.empty()) {
            Trie *node = pending.back();
            pending.pop_back();
            // Detach the grandchildren first, so `delete node` below finds
            // no children and does no further work.
            for (Trie *&child : node->children) {
                if (child) pending.push_back(child);
                child = nullptr;
            }
            delete node;
        }
    }

    // Inserts s as one word, creating missing nodes along its path.
    void insert(const string &s) {
        Trie *node = this;
        for (char c : s) {
            Trie *&slot = node->children[mp[c]];
            if (slot == nullptr) slot = new Trie();
            node = slot;
        }
        node->cnt++;
    }
};

void solve() {
    int n;
    cin >> n;
    vector<string> v(n);
    for (auto &i : v) cin >> i;

    Trie ftrie, btrie;
    for (auto i : v) {
        ftrie.insert(i);
        string nw = i;
        reverse(all(nw));
        btrie.insert(nw);
    }
    // cerr << "hello\n";
    int ans = 0;

    for (auto i : v) {
        int L = 0, R = i.length() - 1;
        string fwd = "", bwd = "";

        Trie *froot = &ftrie;
        Trie *broot = &btrie;

        for (int j = 0; j < i.length() - 1; j++) {
            fwd += i[L++];
            bwd += i[R--];
            int fmask = mp[i[L - 1]];
            int bmask = mp[i[R + 1]];
            froot = froot->children[fmask];
            broot = broot->children[bmask];

            ans += froot->cnt + broot->cnt;
            if (froot->cnt && broot->cnt) {
                string nw = bwd;
                reverse(all(nw));
                if (nw == fwd) {
                    ans--;
                }
            }
        }
    }
    cout << ans << endl;
}

// Entry point. Sets up fast I/O and builds the letter -> index table used by
// the tries, then solves a single test case. Multi-test input is disabled; see
// the commented-out read of t. It is `signed main` because `int` is a macro
// for `long long`.
signed main() {
    fast();
    cmp();
    int t = 1;
    // cin >> t;
    for (; t > 0; --t) {
        solve();
    }
    return 0;
}