Author : Speedster1010
Required Knowledge : Zobrist Hashing
Time Complexity : O(N log N) per test case
Editorialist : Speedster1010
The collection of a city is the set of distinct postcard designs that appear on the path from the Capital (town 1) to the corresponding city. Comparing entire sets of postcards for every pair of towns is too slow, so we need a faster way to uniquely identify them.
XOR hashing (often called Zobrist hashing) is a technique designed for exactly this purpose: it lets us compare sets of distinct numbers in constant time. To implement it, assign a large, uniformly random 64-bit integer to every distinct postcard design.
The XOR sum of the random numbers assigned to a set of distinct elements uniquely identifies the set (with high probability), making it easy to count matching pairs. This approach is randomized, so there is a theoretically non-zero probability of two different sets producing the same hash. Using a 64-bit value space reduces that probability to something so small that the solution comfortably passes all test cases.
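As a quick illustration of this idea (with made-up design IDs 10, 20, 30, 40), the XOR of the random values assigned to a set depends only on which distinct designs are present, not on their order or multiplicity:

```python
import random

# Demo with hypothetical design IDs: assign a random 64-bit value to each
# design, then hash a set of designs as the XOR of the assigned values.
random.seed(42)
R = {design: random.getrandbits(64) for design in [10, 20, 30, 40]}

def set_hash(designs):
    h = 0
    for d in set(designs):  # XOR each *distinct* design exactly once
        h ^= R[d]
    return h

# Order and multiplicity do not matter -- only the set of distinct designs.
assert set_hash([10, 20, 30]) == set_hash([30, 10, 20, 10])
# Different sets give different hashes (with overwhelming probability).
assert set_hash([10, 20]) != set_hash([10, 20, 30])
```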
Now, let's talk about the implementation. We can build these hashes along the paths efficiently using a single Depth First Search from town 1. To ensure we only XOR a postcard's value once per distinct design, we maintain a frequency count of postcards on the current active path:
- On entering a town u, check the frequency of its design a[u]. If it is 0, this is a new design on the path: XOR its assigned random value into the running hash. Then increment its frequency.
- On leaving town u (backtracking), decrement the frequency of a[u].
To calculate the final answer, save the final computed hash for each town and count the overall frequencies of these hashes across the tree. If a specific hash appears k times, it contributes k * (k - 1) / 2 valid pairs to the final answer.
Note: The range from which the random numbers are picked must be large enough to keep collisions negligible; the Birthday Paradox shows that collisions become likely once the number of hashes approaches the square root of the range. 64-bit integers are a great fit for this, keeping the collision chance practically at zero.
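As a rough sanity check (taking a hypothetical bound of n = 2 * 10^5 towns in a test), the birthday bound says the expected number of colliding pairs among n uniformly random 64-bit values is about C(n, 2) / 2^64, which is far below anything that could affect a verdict:

```python
# Birthday-bound estimate: expected number of colliding pairs among
# n uniformly random 64-bit values is roughly n * (n - 1) / 2 / 2**64.
n = 2 * 10**5                      # hypothetical number of towns
pairs = n * (n - 1) // 2
p_collision = pairs / 2.0**64
print(f"{p_collision:.2e}")        # on the order of 1e-9
```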
Solution (C++):

#include <bits/stdc++.h>
using namespace std;

// 64-bit Mersenne Twister seeded from the clock; assigns a random
// value to every distinct postcard design.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

void solve() {
    int n;
    cin >> n;
    vector<int> a(n);
    map<int, uint64_t> R;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        if (!R.count(a[i])) {
            R[a[i]] = rng();
        }
    }
    vector<vector<int>> adj(n);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    map<int, int> freq;         // frequency of each design on the current path
    map<uint64_t, int> hashes;  // how many towns share each path hash
    function<void(int, int, uint64_t)> dfs = [&](int u, int p, uint64_t curr_hash) {
        if (freq[a[u]] == 0) {
            // First occurrence of this design on the path: fold it into the hash.
            curr_hash ^= R[a[u]];
        }
        freq[a[u]]++;
        hashes[curr_hash]++;
        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, curr_hash);
            }
        }
        freq[a[u]]--;  // backtrack: remove u's design from the active path
    };
    dfs(0, -1, 0);
    long long ans = 0;
    for (auto [h, cnt] : hashes) {
        // A hash seen cnt times contributes cnt * (cnt - 1) / 2 pairs.
        ans += cnt * (cnt - 1LL) / 2LL;
    }
    cout << ans << "\n";
}

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    cin >> t;
    while (t--) {
        solve();
    }
}
Solution (Python):

import sys
import random

def solve():
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    t = int(input_data[0])
    idx = 1
    out = []
    for _ in range(t):
        n = int(input_data[idx])
        idx += 1
        a = []
        for _ in range(n):
            a.append(int(input_data[idx]))
            idx += 1
        # Coordinate compression
        unique_vals = list(set(a))
        val_to_id = {val: i for i, val in enumerate(unique_vals)}
        m = len(unique_vals)
        # Generate random hashes for each unique postcard
        H = [random.getrandbits(64) for _ in range(m)]
        a_ids = [val_to_id[val] for val in a]
        adj = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            u = int(input_data[idx])
            v = int(input_data[idx + 1])
            idx += 2
            adj[u].append(v)
            adj[v].append(u)
        freq = [0] * m
        cur_hash = 0
        all_hashes = []
        # Iterative DFS (avoids Python's recursion limit):
        # stack stores (current_node, parent_node, edge_index)
        stack = [(1, 0, 0)]
        # Enter root
        root_id = a_ids[0]
        if freq[root_id] == 0:
            cur_hash ^= H[root_id]
        freq[root_id] += 1
        all_hashes.append(cur_hash)
        while stack:
            u, p, e_idx = stack.pop()
            if e_idx < len(adj[u]):
                v = adj[u][e_idx]
                # Push back current node with the next edge index
                stack.append((u, p, e_idx + 1))
                if v != p:
                    # Enter v: fold its design into the hash if it is
                    # new on the current path
                    v_id = a_ids[v - 1]
                    if freq[v_id] == 0:
                        cur_hash ^= H[v_id]
                    freq[v_id] += 1
                    all_hashes.append(cur_hash)
                    # Schedule exploration of v's children
                    stack.append((v, u, 0))
            else:
                # Finished exploring u, backtrack: undo the XOR only if
                # this was the last copy of u's design on the path
                if u != 1:
                    u_id = a_ids[u - 1]
                    freq[u_id] -= 1
                    if freq[u_id] == 0:
                        cur_hash ^= H[u_id]
        # Count matching pairs: k equal hashes give k * (k - 1) // 2 pairs
        all_hashes.sort()
        ans = 0
        count = 1
        for i in range(1, len(all_hashes)):
            if all_hashes[i] == all_hashes[i - 1]:
                count += 1
            else:
                ans += count * (count - 1) // 2
                count = 1
        ans += count * (count - 1) // 2
        out.append(str(ans))
    sys.stdout.write('\n'.join(out) + '\n')

if __name__ == '__main__':
    solve()