Editorial for Problem B

The Packet Game

Author : tnaveen2308

Required Knowledge : Graphs, Game Theory, Probability

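Time Complexity : $O(n + m + 512^3)$ — linear to read the graph and compute the Grundy values, plus cubic Gaussian elimination over the 512 xor states.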

Editorialist : tnaveen2308

Approach:

This problem involves a game played on a directed acyclic graph (DAG) where data packets are positioned at nodes, and players move them along edges.

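The winner of the game is determined by the xor of the Grundy values of all nodes containing data packets. A node with Grundy value $g$ has at least $g$ outgoing edges, and it needs successors with Grundy values $g-1, g-2, \dots, 0$, which in turn need at least $g-1, g-2, \dots, 0$ outgoing edges. So a Grundy value $g$ requires at least $g(g+1)/2 \le m$ edges, i.e. $g < \sqrt{2m}$ (where $m$ is the number of edges). With the problem's limit on $m$, every Grundy value is below $512 = 2^9$, and the xor of numbers below $2^9$ is also below $2^9$, so the xor never exceeds 511. Let $P_v$ be the probability of Alice's victory if the current xor is $v$.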

We can express $P_v$ using the recurrence relation: $P_v = \sum_{to} P_{to} \cdot prob(v \rightarrow to) + [v \neq 0] \cdot \frac{1}{n+1}$

In the second term, when $x = n+1$ is chosen (with probability $\frac{1}{n+1}$), the simulation ends. If the current xor value $v$ is non-zero at this point, Alice wins; otherwise, she loses.

We have $prob(v \rightarrow to) = \frac{cnt[v \oplus to]}{n+1}$, where $cnt[x]$ is the number of nodes with Grundy value equal to $x$. This is the probability that the new packet is added to a node with Grundy value $v \oplus to$, which is exactly what changes the xor from $v$ to $to$.

This gives a system of 512 linear equations in the variables $P_v$, which we solve with Gaussian elimination modulo 998244353. The answer is $P_0$, the probability that Alice wins when the current xor is 0 (the game starts from xor 0).

The answer can be represented as a fraction $\frac{P}{Q}$, where $P$ and $Q$ are coprime integers and $Q \not\equiv 0 \pmod{998244353}$. We need to output $P \cdot Q^{-1} \bmod 998244353$.

There is also an alternative solution using the Hadamard transform.

Code:

#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for all modular arithmetic.
using ll = long long;

// Prime modulus in which the answer is reported.
constexpr int mod2 = 998244353;
// Capacity of the per-node arrays; nodes are read as 1-based indices (assumes n < N).
constexpr int N = 100000 + 5;
// Number of xor states in the linear system (Grundy values are assumed to be < 512).
constexpr int M = 512;

// g[u]: successors of node u (adjacency list of the DAG).
vector<int> g[N];
// d: nodes in DFS post-order, filled by dfs().
vector<int> d;
// c[u]: whether dfs() has already visited node u.
bool c[N];
// w[u]: Grundy value of node u.
int w[N];
// t[x]: number of nodes whose Grundy value equals x.
int t[N];

// Power function for modular exponentiation
ll power(ll a, ll b, int m) {
    ll x = 1;
    if (a >= m) a %= m;
    while (b) {
        if (b & 1) x = x * a % m;
        a = a * a % m;
        b >>= 1;
    }
    return x;
}

// Modular subtraction: returns (a - b) mod m, always in [0, m). `m` must be positive.
// Operands of any sign or size are normalized into [0, m) first. Operands that are already
// reduced take a fast path with no division; this matters because gauss() calls this in its
// innermost elimination loop.
long long sub(long long a, long long b, int m) {
    if (a < 0 || a >= m) a = (a % m + m) % m;
    if (b < 0 || b >= m) b = (b % m + m) % m;
    const long long diff = a - b;   // in (-m, m)
    return diff < 0 ? diff + m : diff;
}

// Minimum excludant (MEX) of a multiset of Grundy values.
// The vector is taken by value, then sorted and deduplicated. The smallest integer is then
// found by walking the distinct values with a running candidate 0, 1, 2, ... and stopping at
// the first mismatch. The result is the number of distinct values if 0..k-1 are all present.
// If a negative value is present it sorts first and the result is 0.
int mex(vector<int> a) {
    sort(a.begin(), a.end());
    a.erase(unique(a.begin(), a.end()), a.end());
    int candidate = 0;
    for (const int value : a) {
        if (value != candidate) break;
        ++candidate;
    }
    return candidate;
}

// Gaussian elimination
vector<ll> gauss(vector<vector<ll>> a) {
    int n = a.size();
    for (int j = 0, i = 0; j < n && i < n; j++, i++) {
        for(int k = 0; k < n; k++) {
            if (k != i) {
                ll v = a[k][j] * power(a[i][j], mod2 - 2, mod2) % mod2;
                for(int l = j; l <= n; l++) 
                    a[k][l] = sub(a[k][l], v * a[i][l] % mod2, mod2);
            }
        }
    }
    vector<ll> s(n);
    for(int i = 0; i < n; i++) 
        s[i] = a[i][n] * power(a[i][i], mod2 - 2, mod2) % mod2;
    return s;
}

// Post-order DFS from u over the adjacency lists g[].
// Marks every reached node in c[] and appends each node to d only after all of its
// successors. As a result, d ends up in reverse topological order of the DAG.
// An explicit stack replaces recursion: a single path can be about N nodes long, and one
// call frame per node could overflow the call stack. The visiting order, and therefore
// the order of d, is the same as the recursive formulation.
void dfs(int u) {
    // Each entry: (node, index of the next outgoing edge of that node to examine).
    vector<pair<int, size_t>> stk;
    c[u] = 1;
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        size_t& edge = stk.back().second;
        if (edge < g[x].size()) {
            const int v = g[x][edge++];
            if (!c[v]) {
                c[v] = 1;                 // mark on entry, as the recursive version does
                stk.emplace_back(v, 0);   // may invalidate `edge`; it is not used afterwards
            }
        } else {
            d.push_back(x);               // all successors finished: emit in post-order
            stk.pop_back();
        }
    }
}

int main() {
    ios_base::sync_with_stdio(0); 
    cin.tie(NULL); 
    cout.tie(NULL);
    
    int n, m; 
    cin >> n >> m;
    
    while (m--) {
        int u, v; 
        cin >> u >> v;
        g[u].push_back(v);
    }
    
    // Topological sort
    for(int i = 1; i <= n; i++) {
        if (!c[i]) dfs(i);
    }
    
    // Calculate Grundy values
    for(auto u : d) {
        vector<int> a;
        for(auto v : g[u]) a.push_back(w[v]);
        w[u] = mex(a); 
        t[w[u]]++;
    }
    
    // Set up system of linear equations
    vector<vector<ll>> a(M, vector<ll>(M + 1));
    ll inv = power(n + 1, mod2 - 2, mod2);
    
    for(int i = 0; i < M; i++) {
        a[i][i] = 1;
        if (i) a[i][M] = inv;
        for(int j = 0; j < M; j++) 
            a[i][j] = sub(a[i][j], t[i^j] * inv % mod2, mod2);
    }
    
    // Solve the system
    vector<ll> s = gauss(a);
    cout << s[0];
    
    return 0;
}