Editorial for Problem G

Good Knight

Author : Speedster1010

Required Knowledge : Binary Lifting, Segment Tree

Time Complexity : $\mathcal{O}((N + Q) \cdot \log N)$

Editorialist : Speedster1010

Approach 1: (Binary Lifting)

To better understand this problem, let's break down each step the hero takes.

1. **The Walk.** Once the knight finishes teleporting and starts walking from some node $u$, their power no longer restricts them: the only rule is that every step must move further from the capital. To maximize the number of visited cities, we simply follow the longest downward path to a leaf. The maximum number of cities visited starting from $u$ is therefore exactly the height of the subtree rooted at $u$, counting $u$ itself, so a leaf has height $1$. We can precompute all heights with a single DFS.

2. **The Teleport.** Before walking, the knight can teleport to an ancestor $v$ of their starting city $c_i$. Which ancestor should we choose? The subtree of an ancestor contains the subtree of every node below it. So teleporting higher up the tree (closer to the capital) never gives a shorter downward path; higher is always at least as good.

To teleport to an ancestor $v$, the path from $c_i$ to $v$ must satisfy $\max(d) - \min(d) \le p_i$. As the path is extended higher up the tree, the maximum defense can only increase and the minimum defense can only decrease. The instability is therefore monotonic: if some ancestor is reachable, then so is every ancestor between it and $c_i$.

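We can search for the highest reachable ancestor with binary lifting. First, precompute the $2^k$-th ancestor of every node, together with the maximum and minimum defense values covered by that jump. Then, for each query, try jumps from the largest $k$ down to $0$, taking a jump whenever the instability stays within $p_i$. Monotonicity makes this greedy correct, and each query is answered in $\mathcal{O}(\log N)$ time.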

#include<bits/stdc++.h>
using namespace std;
// Turn every `int` in this program into a 64-bit integer (which is why the
// entry point below is declared `signed main`).
// NOTE(review): presumably chosen so defense values and their differences
// cannot overflow; confirm against the problem constraints.
#define int int64_t
// Make `endl` a plain newline so the output stream is not flushed per query.
#define endl '\n'
// Sentinel stored in the min/max lifting tables for jumps that leave the tree.
const int inf = 1e18;

void solve() {
    int n, q;
    cin >> n >> q;
    vector<int> d(n);
    for(int i = 0; i < n; ++i) {
        cin >> d[i];
    }
    vector<vector<int>> adj(n);
    for(int u, v, i = 1; i < n; ++i) {
        cin >> u >> v;
        --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    // calculating the height, parent of each city
    vector<int> height(n, 1), parent(n);
    function<void(int, int)> dfs = [&](int u, int p) {
        parent[u] = p;
        for(auto &v : adj[u]) {
            if(v != p) {
                dfs(v, u);
                height[u] = max(height[u], height[v] + 1);
            }
        }
    }; dfs(0, -1);
    
    // precomputing binary lifting table
    int LOG = 0;
    while ((1 << LOG) <= n) ++LOG;
    ++LOG;
    vector<vector<int>> up(LOG, vector<int>(n, -1));
    vector<vector<int>> mn(LOG, vector<int>(n, inf));
    vector<vector<int>> mx(LOG, vector<int>(n, -inf));
    for(int u = 0; u < n; ++u) {
        up[0][u] = parent[u];
        if(parent[u] != -1) {
            mn[0][u] = min(d[u], d[parent[u]]);
            mx[0][u] = max(d[u], d[parent[u]]);
        }
    }
    for(int i = 1; i < LOG; ++i) {
        for(int u = 0; u < n; ++u) {
            int mid = up[i - 1][u];
            if(mid != -1) {
                up[i][u] = up[i - 1][mid];
                mn[i][u] = min(mn[i - 1][u], mn[i - 1][mid]);
                mx[i][u] = max(mx[i - 1][u], mx[i - 1][mid]);
            }
        }
    }
    
    // answering queries
    while(q--) {
        int c, p;
        cin >> c >> p;
        --c;
    
        // maximum and minimum in the current path
        int maxd = d[c], mind = d[c];
        for(int k = LOG - 1; k >= 0; --k) {
            if(up[k][c] == -1) {
                continue;
            }
            if(max(maxd, mx[k][c]) - min(mind, mn[k][c]) <= p) {
                mind = min(mind, mn[k][c]);
                maxd = max(maxd, mx[k][c]);
                c = up[k][c];
            }
        }
        cout << height[c] << endl;
    }
}
 
// Entry point: switch to fast unsynchronized I/O, read the number of test
// cases, and run solve() once for each of them.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int tests = 1;
    cin >> tests;
    for(; tests != 0; --tests) {
        solve();
    }
}

Approach 2: (Offline Queries with Segment Tree)

Instead of using binary lifting, we can process the queries offline. We traverse the tree with a DFS, maintaining the active path from the capital down to the current city. A segment tree built over this path (indexed by depth) tracks the maximum and minimum defense values. When processing a query at city $c_i$, we binary search the length of the valid upward path. For a guessed length, we query the segment tree and check whether the instability $\max(d) - \min(d)$ is $\le p_i$. This finds the highest valid ancestor in $\mathcal{O}(\log^2 N)$ time per query, for $\mathcal{O}(N \log N + Q \log^2 N)$ overall, and offers a more intuitive approach.

#include<bits/stdc++.h>
using namespace std;
// Turn every `int` into a 64-bit integer (so the entry point below must be
// declared `signed main`).
#define int int64_t
// Large sentinel used by e() to build the identity element; -2 * INF still
// fits in a 64-bit integer.
const int INF = 1e18;

// Value kept in each segment-tree node: the smallest and largest defense in
// the node's range, together with their difference (the range's instability).
struct S {
    int mn;    // minimum defense in the range
    int mx;    // maximum defense in the range
    int diff;  // mx - mn
};
// Merges the summaries of two ranges: combined minimum and maximum, with the
// instability recomputed from them (the inputs' diff fields are not read).
S op(S a, S b) {
    S merged;
    merged.mn = min(a.mn, b.mn);
    merged.mx = max(a.mx, b.mx);
    merged.diff = merged.mx - merged.mn;
    return merged;
}
// Identity element for op(): INF never wins a min and -INF never wins a max,
// so merging with it keeps the other side's bounds. Its diff is never read by
// op(), which recomputes diff from mn and mx.
S e() {
    S identity;
    identity.mn = INF;
    identity.mx = -INF;
    identity.diff = -2 * INF;
    return identity;
}

// Recursive point-assign / range-query segment tree over positions [0, N).
// Nodes are merged with op(); positions never assigned hold the identity e().
struct segtree {
    int N;
    vector<S> seg;  // 1-indexed heap layout; 4 * N nodes always suffice

    // explicit: a bare integer must not silently convert into a segtree.
    explicit segtree(int N) : N(N), seg(4 * N, e()) {}

    // Assigns val to position pos inside node ind (covering [start, end]),
    // then recomputes each node on the way back up.
    void set(int ind, int start, int end, int pos, S val) {
        if(start == end) {
            seg[ind] = val;
            return;
        }
        int mid = (start + end) / 2;
        if(pos <= mid) set(ind * 2, start, mid, pos, val);
        else set(ind * 2 + 1, mid + 1, end, pos, val);
        seg[ind] = op(seg[ind * 2], seg[ind * 2 + 1]);
    }
    // a[pos] = val.
    void set(int pos, S val) { set(1, 0, N - 1, pos, val); }

    // op over the part of [l, r] that lies in node ind's range [start, end];
    // returns e() when they do not overlap.
    S prod(int ind, int start, int end, int l, int r) {
        if(l <= start && end <= r) return seg[ind];
        if(end < l || r < start) return e();
        int mid = (start + end) / 2;
        return op(prod(ind * 2, start, mid, l, r), prod(ind * 2 + 1, mid + 1, end, l, r));
    }
    // op(a[l], ..., a[r]) over the inclusive range [l, r].
    S prod(int l, int r) { return prod(1, 0, N - 1, l, r); }
};

void solve() {
    int n, q;
    cin >> n >> q;
    vector<int> d(n);
    for(int i = 0; i < n; ++i) {
        cin >> d[i];
    }

    vector<vector<int>> adj(n);
    for(int u, v, i = 1; i < n; ++i) {
        cin >> u >> v;
        --u; --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    vector<vector<pair<int, int>>> queries(n);
    for(int c, p, i = 0; i < q; ++i) {
        cin >> c >> p;
        --c;
        queries[c].push_back({ p, i });
    }

    vector<int> height(n, 1);
    function<void(int, int)> dfs1 = [&](int u, int p) {
        for(int v : adj[u]) {
            if(v != p) {
                dfs1(v, u);
                height[u] = max(height[u], height[v] + 1);
            }
        }
    }; 
    dfs1(0, -1);

    vector<int> ans(q), path;
    segtree seg(n + 5);
    function<void(int, int, int)> dfs2 = [&](int u, int pr, int depth) {
        path.push_back(u);
        seg.set(depth, {d[u], d[u], 0});

        for(auto &[p, i] : queries[u]) {
            int lo = 0, hi = depth + 1;
            while(hi - lo > 1) {
                int mid = (lo + hi) / 2;
                if(seg.prod(depth - mid, depth).diff <= p) {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            ans[i] = height[path[depth - lo]];
        }

        for(int v : adj[u]) {
            if(v != pr) {
                dfs2(v, u, depth + 1);
            }
        }
        path.pop_back();
    }; 
    dfs2(0, -1, 0);

    for(int i = 0; i < q; ++i) {
        cout << ans[i] << '\n';
    }
}

// Program entry: untie cin and drop C-stdio sync for fast I/O, then solve
// each test case in turn.
signed main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);

    int testCases = 1;
    cin >> testCases;
    for(int remaining = testCases; remaining != 0; --remaining)
        solve();
}


Tester's Code:

import sys
 
def solve():
    """Reference checker: answer every query with binary lifting.

    Reads all test cases from stdin (t, then per case: ``n q``, the n defense
    values, n - 1 edges, and q queries ``c p``). Cities are 1-indexed and city
    1 is the capital; index 0 of every table acts as a "no node" sentinel.
    Writes one answer per line to stdout.
    """
    raw = sys.stdin.read().split()
    if not raw:
        return

    stream = iter(raw)

    def read_int():
        return int(next(stream))

    lines = []
    for _ in range(read_int()):
        n, q = read_int(), read_int()
        defense = [0] + [read_int() for _ in range(n)]

        graph = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            a, b = read_int(), read_int()
            graph[a].append(b)
            graph[b].append(a)

        # BFS from the capital; `bfs` is both the queue and the visit order
        # (appending while iterating a list is well-defined in Python).
        par = [0] * (n + 1)
        kids = [[] for _ in range(n + 1)]
        bfs = [1]
        for node in bfs:
            for nb in graph[node]:
                if nb != par[node]:
                    par[nb] = node
                    kids[node].append(nb)
                    bfs.append(nb)

        # Subtree heights bottom-up (a leaf counts as 1 city).
        sub_h = [1] * (n + 1)
        for node in reversed(bfs):
            sub_h[node] = max((sub_h[c] + 1 for c in kids[node]), default=1)

        # Smallest L with 2**L > n, plus one spare level.
        levels = n.bit_length() + 1

        # Level 0: direct parent and min/max defense over the edge.
        jump0 = par[:]
        lo0 = [0] * (n + 1)
        hi0 = [0] * (n + 1)
        for node in range(1, n + 1):
            above = par[node]
            if above:
                lo0[node] = min(defense[node], defense[above])
                hi0[node] = max(defense[node], defense[above])
            else:
                lo0[node] = hi0[node] = defense[node]

        # Level k is two level-(k - 1) jumps glued together.
        jump, lo_tab, hi_tab = [jump0], [lo0], [hi0]
        for _ in range(1, levels):
            pj, plo, phi = jump[-1], lo_tab[-1], hi_tab[-1]
            nj = [0] * (n + 1)
            nlo = [0] * (n + 1)
            nhi = [0] * (n + 1)
            for node in range(1, n + 1):
                mid = pj[node]
                if mid:
                    nj[node] = pj[mid]
                    nlo[node] = min(plo[node], plo[mid])
                    nhi[node] = max(phi[node], phi[mid])
                else:
                    nlo[node] = plo[node]
                    nhi[node] = phi[node]
            jump.append(nj)
            lo_tab.append(nlo)
            hi_tab.append(nhi)

        # Greedy climb: take each jump (largest first) that keeps max - min <= p.
        for _ in range(q):
            at = read_int()
            power = read_int()
            lo = hi = defense[at]
            for k in reversed(range(levels)):
                nxt = jump[k][at]
                if not nxt:
                    continue
                cand_lo = min(lo, lo_tab[k][at])
                cand_hi = max(hi, hi_tab[k][at])
                if cand_hi - cand_lo <= power:
                    at, lo, hi = nxt, cand_lo, cand_hi
            lines.append(str(sub_h[at]))

    sys.stdout.write('\n'.join(lines) + '\n')
 
# Run the checker only when executed as a script, not when imported.
if __name__ == '__main__':
    solve()