Editorial for Problem C

The Postcard Collector

Author : Speedster1010

Required Knowledge : Zobrist Hashing

Time Complexity : $\mathcal{O}(N \log N)$

Editorialist : Speedster1010

Approach:

The collection of a town is the set of distinct postcard designs that appear on the path from the capital (town 1) to that town, and we need to count the pairs of towns whose collections are identical. Comparing entire sets of postcards for every pair of towns is too slow, so we need a compact fingerprint for each set that can be compared quickly.
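For reference, the naive approach can be sketched in Python: build each town's collection as an explicit set by extending its parent's set, then compare every pair of towns. This is a minimal sketch, assuming towns are numbered 1 to N with town 1 as the capital, `a[i]` holding the design of town i + 1, and edges given as (u, v) pairs; `brute_force_pairs` is a hypothetical name. Copying sets and comparing all pairs makes it quadratic or worse, so it is only useful as a checker on small inputs.

```python
from collections import deque
from itertools import combinations


def brute_force_pairs(n, a, edges):
    """Count pairs of towns whose collections (sets of designs on the path
    from town 1) are identical, by building every set explicitly."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    collection = [None] * (n + 1)
    collection[1] = frozenset([a[0]])
    queue = deque([1])
    # BFS from the capital: a child's set is its parent's set plus its own design.
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if collection[v] is None:
                collection[v] = collection[u] | {a[v - 1]}
                queue.append(v)

    # Compare every pair of towns directly: O(N^2) set comparisons.
    return sum(1 for x, y in combinations(range(1, n + 1), 2)
               if collection[x] == collection[y])
```

For the tree with edges 1-2, 1-3, 2-4 and designs [1, 2, 1, 2], towns 1 and 3 share {1} and towns 2 and 4 share {1, 2}, so it returns 2.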

XOR hashing (often called Zobrist hashing) is a technique built for exactly this purpose: once the hash of each set has been computed, two sets can be compared in constant time. To implement it, assign an independent, uniformly random 64-bit integer to every distinct postcard design $A_i$. The hash of a set is the XOR of the values assigned to its elements.
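A minimal Python sketch of the value assignment and the set hash, using the standard `random` module (an assumption for illustration; the setter's code uses C++'s `mt19937_64`). `assign_random_values` and `set_hash` are hypothetical helper names. The last assertion shows why each design must be XORed in only once: XORing the same value twice cancels it out.

```python
import random


def assign_random_values(designs, seed=None):
    """Map every distinct design to an independent random 64-bit integer."""
    rng = random.Random(seed)
    return {d: rng.getrandbits(64) for d in sorted(set(designs))}


def set_hash(values, designs):
    """XOR together the random values of the given designs."""
    h = 0
    for d in designs:
        h ^= values[d]
    return h


R = assign_random_values([3, 7, 3, 9], seed=1)

# XOR is commutative and associative, so the order of the elements is irrelevant.
assert set_hash(R, [3, 7, 9]) == set_hash(R, [9, 3, 7])

# Pitfall: a design XORed twice cancels out, so duplicates must be skipped.
assert set_hash(R, [3, 3]) == 0
```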

Two equal sets always receive the same hash, and two different sets receive the same hash only with negligible probability, so matching hashes can be treated as matching collections, which makes counting matching pairs easy. The approach is randomized, so in theory there is a non-zero probability that it fails; with 64-bit values, however, that probability is so small that the solution comfortably passes all test cases.
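Why a collision is so unlikely can be checked directly: the XOR of the hashes of two sets $S$ and $T$ equals the XOR over their symmetric difference, because elements in both sets cancel. If $S \neq T$ the symmetric difference is non-empty and, assuming the random values are independent and uniform, the XOR of one or more of them is itself uniform, so the two hashes coincide with probability exactly $2^{-64}$. A small Python check of the identity (the helper name `set_hash` and the fixed seed are illustrative):

```python
import random
from functools import reduce
from operator import xor


def set_hash(values, designs):
    """XOR of the random values of a set of distinct designs."""
    return reduce(xor, (values[d] for d in designs), 0)


rng = random.Random(2024)
values = {d: rng.getrandbits(64) for d in range(1, 11)}

S = {1, 2, 3, 5}
T = {2, 3, 7}
# Elements in both sets cancel, so the two hashes differ by exactly the
# XOR over the symmetric difference S ^ T = {1, 5, 7}.
assert set_hash(values, S) ^ set_hash(values, T) == set_hash(values, S ^ T)
```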

Now, let's talk about the implementation. We can build these hashes along all root-to-town paths using a single depth-first search (DFS) from town 1. To make sure each design's value is XORed in only once per path, we maintain a frequency count of the designs on the current active path:

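  1. Going down: When you visit town $u$, check the frequency of $A_u$. If it is 0, $A_u$ is a new design on this path, so XOR its random value into the running hash. Then increment its frequency; the running hash is now the hash of town $u$.
  2. Backtracking: When leaving town $u$, decrement the frequency of $A_u$ and restore the running hash to its value at the parent. In a recursive DFS that passes the hash by value (as in the setter's code) the restore happens automatically; in an iterative DFS (as in the tester's code), XOR the value of $A_u$ back out if its frequency drops to 0.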
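The two steps above can be sketched as an iterative DFS in Python, which sidesteps recursion-depth limits on long paths. This is a minimal sketch rather than the setter's or tester's code: it assumes towns numbered 1 to N, `a[i]` holding the design of town i + 1, and a precomputed dict `values` from design to random value; `path_hashes` is a hypothetical name. Each town is pushed twice, once to enter and once to exit, so the backtracking step for a town runs only after its whole subtree has been processed.

```python
def path_hashes(n, a, edges, values):
    """Return h where h[u] is the XOR of values[d] over the distinct designs d
    on the path from town 1 to town u (towns are 1-indexed, a is 0-indexed)."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    freq = {}          # design -> occurrences on the current root-to-node path
    h = [0] * (n + 1)
    cur = 0            # running hash of the current path
    stack = [(1, 0, True)]
    while stack:
        u, parent, entering = stack.pop()
        d = a[u - 1]
        if entering:
            # Going down: XOR the design in only on its first occurrence.
            if freq.get(d, 0) == 0:
                cur ^= values[d]
            freq[d] = freq.get(d, 0) + 1
            h[u] = cur
            stack.append((u, parent, False))  # exit event runs after all children
            for v in adj[u]:
                if v != parent:
                    stack.append((v, u, True))
        else:
            # Backtracking: undo the visit; XOR the design out if it no longer
            # appears on the path.
            freq[d] -= 1
            if freq[d] == 0:
                cur ^= values[d]
    return h
```

With `values = {1: 1, 2: 2}` (tiny values for readability), designs [1, 2, 1, 2] and edges 1-2, 1-3, 2-4, the hashes of towns 1 to 4 come out as 1, 3, 1, 3.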

To compute the final answer, save the hash computed for each town and count how often each hash value occurs across the tree. If a specific hash appears $C$ times, it contributes $C \times (C - 1) / 2$ valid pairs to the final answer.
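The counting step is a frequency count followed by the pair formula. A minimal Python sketch using `collections.Counter` (the function name `count_equal_pairs` is hypothetical):

```python
from collections import Counter


def count_equal_pairs(hashes):
    """Number of unordered pairs (i, j), i < j, with hashes[i] == hashes[j]."""
    # A hash shared by C towns contributes C * (C - 1) / 2 pairs.
    return sum(c * (c - 1) // 2 for c in Counter(hashes).values())


# Three towns share hash 5 and two share hash 9: 3 + 1 = 4 pairs.
assert count_equal_pairs([5, 9, 5, 5, 9, 2]) == 4
```

Sorting the hashes and counting runs of equal values, as the tester's code does, gives the same result without a hash map.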

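Note: The range from which the random numbers are picked must be large enough to make collisions unlikely. The birthday paradox explains why: among $k$ values drawn uniformly at random from $M$ possibilities, a collision becomes likely once $k$ is on the order of $\sqrt{M}$. For 32-bit values that is only about 65,000 values, while for 64-bit values it is about 4 billion, which keeps the collision chance practically at zero.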
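To put a number on "practically zero", a union bound over all pairs of towns gives the following estimate, assuming the random values $R_d$ are independent and uniform over 64-bit integers, and writing $h(S)$ for the hash of a collection $S$ and $N$ for the number of towns:

```latex
\Pr\bigl[h(S) = h(T)\bigr]
  = \Pr\Bigl[\textstyle\bigoplus_{d \in S \triangle T} R_d = 0\Bigr]
  = 2^{-64} \qquad (S \neq T)

\Pr\bigl[\text{any two towns with different collections collide}\bigr]
  \le \binom{N}{2} \cdot 2^{-64} < \frac{N^2}{2^{65}}
```

For instance, with $N = 10^6$ towns the bound is below $3 \times 10^{-8}$. The same bound with 32-bit values, $N^2 / 2^{33}$, already exceeds 1 at $N = 10^5$, so it no longer rules out collisions.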

Setter's Code:

#include<bits/stdc++.h>
using namespace std;

// 64-bit Mersenne Twister seeded with the current time, so the random values differ between runs.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

void solve() {
    int n;
    cin >> n;
    vector<int> a(n);
    map<int, uint64_t> R; // random 64-bit value for each distinct design
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        if(!R.count(a[i])) {
            R[a[i]] = rng();
        }
    }

    vector<vector<int>> adj(n);
    for(int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

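    map<int, int> freq;          // occurrences of each design on the current root-to-u path
    map<uint64_t, int> hashes;   // how many towns have each collection hash
    // curr_hash is passed by value, so returning from a call automatically restores
    // the parent's hash; only freq has to be undone by hand.
    // Note: the recursion depth equals the tree height, which can reach n on a path-shaped tree.
    function<void(int, int, uint64_t)> dfs = [&](int u, int p, uint64_t curr_hash) {
        // Going down: XOR in A_u only if it is not already on the path.
        if(freq[a[u]] == 0) {
            curr_hash ^= R[a[u]];
        }
        
        freq[a[u]]++;
        hashes[curr_hash]++;   // record this town's collection hash

        for(int v : adj[u]) {
            if(v != p) {
                dfs(v, u, curr_hash);
            }
        }
        
        // Backtracking: A_u leaves the active path.
        freq[a[u]]--;
    };
    dfs(0, -1, 0);  // root is town 1 (index 0); -1 means "no parent"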

    long long ans = 0;
    for (auto [h, cnt] : hashes) {
        ans += cnt * (cnt - 1LL) / 2LL;
    }
    cout << ans << "\n";
}

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    
    int t = 1;
    cin >> t;
    while(t--) {
        solve();
    }
}

Tester's Code:

import sys
import random

def solve():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    
    t = int(input_data[0])
    idx = 1
    out = []
    
    for _ in range(t):
        n = int(input_data[idx])
        idx += 1
        
        a = []
        for _ in range(n):
            a.append(int(input_data[idx]))
            idx += 1
            
        # Coordinate compression
        unique_vals = list(set(a))
        val_to_id = {val: i for i, val in enumerate(unique_vals)}
        m = len(unique_vals)
        
        # Generate random hashes for each unique postcard
        H = [random.getrandbits(64) for _ in range(m)]
        a_ids = [val_to_id[val] for val in a]
        
        adj = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            u = int(input_data[idx])
            v = int(input_data[idx+1])
            idx += 2
            adj[u].append(v)
            adj[v].append(u)
            
        freq = [0] * m
        cur_hash = 0
        all_hashes = []
        
        # Iterative DFS: stack stores (current_node, parent_node, edge_index)
        stack = [(1, 0, 0)]
        
        # Enter root
        root_id = a_ids[0]
        if freq[root_id] == 0:
            cur_hash ^= H[root_id]
        freq[root_id] += 1
        all_hashes.append(cur_hash)
        
        while stack:
            u, p, e_idx = stack.pop()
            
            if e_idx < len(adj[u]):
                v = adj[u][e_idx]
                # Push back current node with the next edge index
                stack.append((u, p, e_idx + 1))
                
                if v != p:
                    # Enter v
                    v_id = a_ids[v - 1]
                    if freq[v_id] == 0:
                        cur_hash ^= H[v_id]
                    freq[v_id] += 1
                    all_hashes.append(cur_hash)
                    
                    # Schedule exploration of v's children
                    stack.append((v, u, 0))
            else:
                # Finished exploring u: backtrack, XORing A_u's value out if it no longer appears on the path
                if u != 1:
                    u_id = a_ids[u - 1]
                    freq[u_id] -= 1
                    if freq[u_id] == 0:
                        cur_hash ^= H[u_id]
        
        # Count pairs of towns whose hashes match
        all_hashes.sort()
        ans = 0
        count = 1
        for i in range(1, len(all_hashes)):
            if all_hashes[i] == all_hashes[i-1]:
                count += 1
            else:
                ans += count * (count - 1) // 2
                count = 1
        ans += count * (count - 1) // 2
        
        out.append(str(ans))
        
    sys.stdout.write('\n'.join(out) + '\n')

if __name__ == '__main__':
    solve()
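A common way to gain confidence in a randomized solution is to stress-test it against a brute force on many small random trees. The Python sketch below is one hypothetical harness (none of its functions come from the setter or tester): `brute` builds every collection explicitly, `fast` applies the hashing DFS with a frequency count, and `random_tree` attaches each town to a random earlier town, which produces varied but not uniformly distributed trees.

```python
import random
from collections import Counter
from itertools import combinations


def build_adj(n, edges):
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    return adj


def brute(n, a, edges):
    """Build every collection explicitly and compare all pairs of towns."""
    adj = build_adj(n, edges)
    col = {1: frozenset([a[0]])}
    order = [1]
    for u in order:                  # BFS: the loop also visits towns appended below
        for v in adj[u]:
            if v not in col:
                col[v] = col[u] | {a[v - 1]}
                order.append(v)
    return sum(col[x] == col[y] for x, y in combinations(range(1, n + 1), 2))


def fast(n, a, edges, rng):
    """Zobrist hashing with a frequency count along an iterative DFS."""
    adj = build_adj(n, edges)
    R = {d: rng.getrandbits(64) for d in set(a)}
    freq = Counter()
    cur, hashes = 0, []
    stack = [(1, 0, True)]
    while stack:
        u, p, entering = stack.pop()
        d = a[u - 1]
        if entering:
            if freq[d] == 0:         # first occurrence of d on the path
                cur ^= R[d]
            freq[d] += 1
            hashes.append(cur)
            stack.append((u, p, False))
            stack.extend((v, u, True) for v in adj[u] if v != p)
        else:
            freq[d] -= 1
            if freq[d] == 0:         # d no longer on the path
                cur ^= R[d]
    return sum(c * (c - 1) // 2 for c in Counter(hashes).values())


def random_tree(n, rng):
    """Each town i > 1 attaches to a random earlier town."""
    return [(rng.randint(1, i - 1), i) for i in range(2, n + 1)]


def stress(trials=500, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 8)
        a = [rng.randint(1, 3) for _ in range(n)]
        edges = random_tree(n, rng)
        assert brute(n, a, edges) == fast(n, a, edges, rng), (n, a, edges)
    return True
```

Keeping the number of distinct designs small (1 to 3 here) makes equal collections common, so the pair-counting logic is actually exercised.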