Editorial for Problem F

Iron Man and the Energy Array

Author : Udynamo

Required Knowledge : Prefix XOR Sums

Time Complexity: $\mathcal{O}(N)$

Editorialist : Udynamo

Approach:

First, calculate the initial energy output. Using 1-based positions, compute the XOR of the elements at odd positions ($a_1 \oplus a_3 \oplus \dots$) and at even positions ($a_2 \oplus a_4 \oplus \dots$). If the XOR selected by $m$ ($m = 1$: odd positions, $m = 0$: even positions) already equals $x$, print $\text{YES}$.

If an insertion is needed, our goal is to find the smallest non-negative $val$ and, among those, the smallest position $pos$.

A key property of XOR is that for any $a$ and $c$ there is a unique $b$ such that $a \oplus b = c$ (namely $b = a \oplus c$). This means that if we insert $val$ at a position that contributes to the final XOR sum, we can always find a $val$ that makes the result equal to $x$. The challenge is to find the smallest such $val$ by checking all possible positions.

A naive $O(n^2)$ approach (trying all $n+1$ positions and recomputing the XOR each time) is too slow. We can optimize it to $O(n)$ using prefix XORs.

The key observation is that inserting at position $\text{pos}$ shifts every element at or after $\text{pos}$ one place to the right, flipping its parity: an element $a_i$ with $i \ge \text{pos}$ that was at an even position moves to an odd one, and vice versa.

We can precompute two prefix XOR arrays (one for even indices, one for odd indices) in $O(n)$. We use 0-based indexing ($a[0], a[1], \dots$) for these arrays, so 0-based even indices correspond to 1-based odd positions and vice versa.

$\text{evenprefix}[i]$: XOR of $a[0], a[2], \dots$ over indices up to $i-1$.

$\text{oddprefix}[i]$: XOR of $a[1], a[3], \dots$ over indices up to $i-1$.

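Now iterate $\text{pos}$ from $1$ to $n+1$ and let $j = \text{pos} - 1$ (the number of original elements before the insertion point). Let $\text{totalevenxor}$ and $\text{totaloddxor}$ denote the XOR of all elements at 0-based even and odd indices. Then, for each $\text{pos}$, we can compute in $O(1)$ the XOR of the original elements after the parity shift: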

New odd-position XOR (for $m=1$): $\text{evenprefix}[j] \oplus (\text{totaloddxor} \oplus \text{oddprefix}[j])$

New even-position XOR (for $m=0$): $\text{oddprefix}[j] \oplus (\text{totalevenxor} \oplus \text{evenprefix}[j])$

Call this value $\text{current-xor}$. There are two cases:

  • The position has the wrong parity ($m=0$ and $\text{pos}$ is odd, or $m=1$ and $\text{pos}$ is even): the inserted $val$ is not part of the sum. If $\text{current-xor} = x$, this is a valid solution using the smallest possible value, $val = 0$.

  • The position has the correct parity ($m=0$ and $\text{pos}$ is even, or $m=1$ and $\text{pos}$ is odd): the inserted $val$ is part of the sum. We need $\text{current-xor} \oplus val = x$, which is always achievable with $val = \text{current-xor} \oplus x$.

After checking all $n+1$ positions, take the valid solution with the smallest $val$; if several positions give that value, take the one with the smallest $\text{pos}$. Print $\text{YES}$ followed by this $val$ and $\text{pos}$.

The overall time complexity is $O(n)$ if we track the best solution while iterating, or $O(n \log n)$ if we store all $O(n)$ candidate solutions and sort them.

Setter's Code:

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
// Order-statistics tree (GNU pb_ds). This typedef appears before
// `#define int long long`, so its key type is still plain `int`.
// It is not referenced anywhere in the solution below.
typedef tree<int, null_type, std::less<int>, rb_tree_tag, 
tree_order_statistics_node_update> pbds;
//find_by_order(k), order_of_key(k)
// From here on every `int` token is textually replaced by `long long`
// (locals, the `vi`/`vpi`/`pi` aliases below, and the global constants),
// which is why the entry point has to be spelled `signed main()`.
#define int long long
// Object-like shorthands: ANY later identifier spelled `ff`, `ss` or `pb`
// is rewritten, so those names cannot be used for variables.
#define ff first
#define ss second
#define pb push_back
// Container aliases; the `int` inside them also expands to `long long`.
// (`pi` likewise hijacks any identifier named `pi`.)
#define vi vector<int>
#define vpi vector<pair<int, int>>
#define pi pair<int, int>
// Begin/end iterator pairs for passing a whole container to an algorithm.
#define all(x) x.begin(),x.end()
#define rall(x) x.rbegin(),x.rend()
// Whole-container max / min element value, and sum accumulated as long long.
#define maxEle(a) *max_element(a.begin(), a.end())
#define minEle(a) *min_element(a.begin(), a.end())
#define sumAll(a) accumulate(a.begin(), a.end(), 0LL)

using namespace std;

// Input overload: fills an already-sized vector by reading its elements from
// `is` in index order, then returns the stream so reads can be chained.
template<typename T>
istream& operator>>(istream &is, vector<T> &v) {
    for (auto it = v.begin(); it != v.end(); ++it) {
        is >> *it;
    }
    return is;
}

// When true, main() reads a test-case count first; otherwise it runs once.
const bool test_cases = true;
// Conventional modulus / "infinity" constants; neither is referenced by
// UNsolve() or main() in this solution.
const int mod = 1e9 + 7;
const int INF = 1e18;

// Decides one test case: `a` is the array (0-based), `m` selects which
// positions are XORed and `x` is the target value.
//
// m == 1 selects the 1-based odd positions (0-based even indices); any other
// value selects the 1-based even positions (0-based odd indices). The same
// rule is used both for the "already satisfied" check and for the scan.
//
// Returns "YES" if the selected XOR already equals x. Otherwise it considers
// inserting one value `val` at every 1-based position pos = 1..n+1 and
// returns "YES val pos" for the valid insertion with the smallest val (ties:
// smallest pos), or "NO" when no insertion reaches x (this happens only for an
// empty array with m != 1, where no insertion slot has the selected parity).
static std::string energy_verdict(long long m, long long x,
                                  const std::vector<long long> &a) {
    const std::size_t n = a.size();

    // total[p] = XOR of all a[i] with i % 2 == p.
    std::array<long long, 2> total = {0, 0};
    for (std::size_t i = 0; i < n; ++i) {
        total[i % 2] ^= a[i];
    }

    const std::size_t target = (m == 1) ? 0 : 1;  // 0-based parity that is XORed
    const std::size_t other = 1 - target;
    if (total[target] == x) {
        return "YES";
    }

    // prefix[p] = XOR of a[i] with i < j and i % 2 == p, where j is the
    // 0-based index at which the new value is inserted.
    std::array<long long, 2> prefix = {0, 0};
    bool found = false;
    long long best_val = 0;
    std::size_t best_pos = 0;
    for (std::size_t j = 0; j <= n; ++j) {
        // Elements before j keep their index; elements from j on move up by
        // one, so those that had parity `other` now land on `target`.
        const long long current =
            prefix[target] ^ (total[other] ^ prefix[other]);
        // The inserted value itself sits at 0-based index j.
        const bool counted = (j % 2 == target);
        if (counted || current == x) {
            // Counted: val is forced by current ^ val == x.
            // Not counted: the sum is already x, so the smallest val, 0, works.
            const long long val = counted ? (current ^ x) : 0;
            // Positions are visited in increasing order, so a strict '<'
            // keeps the smallest position among equal values.
            if (!found || val < best_val) {
                found = true;
                best_val = val;
                best_pos = j + 1;  // report a 1-based position
            }
        }
        if (j < n) {
            prefix[j % 2] ^= a[j];
        }
    }

    if (!found) {
        return "NO";
    }
    return "YES " + std::to_string(best_val) + " " + std::to_string(best_pos);
}

// Reads one test case ("n m x" followed by n array values) from standard
// input and prints its verdict line.
void UNsolve() {
    long long n, m, x;
    std::cin >> n >> m >> x;
    std::vector<long long> a(n);
    for (long long &value : a) {
        std::cin >> value;
    }
    std::cout << energy_verdict(m, x, a) << '\n';
}

// Entry point (spelled `signed main()` because `int` is macro-expanded to
// `long long` earlier in the file). Unsyncs/unties the C++ streams for speed,
// reads the number of test cases when `test_cases` is true (otherwise runs a
// single case), and calls UNsolve() once per case.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long cases = 1;
    if (test_cases) {
        std::cin >> cases;
    }
    for (; cases != 0; --cases) {
        UNsolve();
    }
    return 0;
}

Tester's Code:

import sys
input = sys.stdin.readline


def solve(n, m, x, a):
    """Return the verdict line for one test case.

    Parity convention (same as the setter): m == 1 XORs the 1-based odd
    positions (0-based even indices); otherwise the 1-based even positions
    (0-based odd indices) are XORed.

    Returns "YES" when the selected XOR of `a` already equals x. Otherwise
    returns "YES val pos" for the single insertion (value `val` at 1-based
    position `pos`) with the smallest val, ties broken by the smallest pos,
    or "NO" when no insertion works.
    """
    t = 0 if m == 1 else 1  # 0-based parity of the XORed positions

    # Current energy, computed directly so the "already YES" test is exact
    # for every n (including an empty selection, whose XOR is 0).
    cur = 0
    for i in range(t, n, 2):
        cur ^= a[i]
    if cur == x:
        return "YES"

    # pre[i] = a[i] ^ a[i + 2] ^ ...; padded with two zeros so that
    # pre[n] and pre[n + 1] are valid (empty suffixes).
    pre = [0] * (n + 2)
    for i in range(n - 1, -1, -1):
        pre[i] = a[i] ^ pre[i + 2]

    ans = None  # smallest inserted value found so far
    ind = -1    # 0-based insertion index achieving `ans`
    c = 0       # XOR of a[j] for j < i with j % 2 == t (unshifted prefix)
    for i in range(n + 1):  # insert before a[i]; i == n appends
        if i % 2 == t:
            # The new element is counted; originals of the other parity at
            # indices >= i + 1 shift onto counted slots.
            k = c ^ pre[i + 1] ^ x
        elif c ^ pre[i] == x:
            # The new element is not counted; originals of the other parity
            # at indices >= i shift onto counted slots and already give x.
            k = 0
        else:
            k = None
        # Strict '<' keeps the earliest position among equal values.
        if k is not None and (ans is None or k < ans):
            ans, ind = k, i
        if i < n and i % 2 == t:
            c ^= a[i]

    if ans is None:
        return "NO"
    return f"YES {ans} {ind + 1}"


def main():
    """Read all test cases from stdin and print one verdict line per case."""
    t = int(input())
    for _ in range(t):
        n, m, x = map(int, input().split())
        a = list(map(int, input().split()))
        print(solve(n, m, x, a))


if __name__ == "__main__":
    main()