Editorial for Problem I

Kolar.G1.Fields

Author: reddyjeevan

Required Knowledge: Segment Tree Beats, Heavy-Light Decomposition (HLD), Binary Search

Time Complexity: $\mathcal{O}(N + Q \log^3 N)$

Editorialist: reddyjeevan

Approach:

The problem asks us to perform two operations on a tree: path range chmin updates ($a_i = \min(a_i, x)$) and path sum queries with an early-exit condition (the running sum along the path stops at the first vertex whose value exceeds $x$). We solve this by flattening the tree into an array with Heavy-Light Decomposition (HLD) and maintaining that array with Segment Tree Beats.

Segment Tree Beats

A standard lazy segment tree cannot easily apply a range chmin update ($a_i = \min(a_i, x)$) while maintaining the range sum, because the change in the sum depends on how many elements exceed $x$ and by how much. Segment Tree Beats handles this.

Each node of the segment tree stores four values for its range: the sum ($sum$), the maximum ($mx$), the strict second maximum ($smx$, the largest value strictly smaller than $mx$), and the number of elements equal to the maximum ($mxcnt$).
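A minimal sketch of the per-node summary and of how two children combine, assuming plain 64-bit sums (the setter's code additionally reduces sums modulo $10^9 + 7$) and using `LLONG_MIN` to mean "no second maximum". The names `Node`, `make_leaf` and `combine` are illustrative, not taken from the solutions.

```cpp
#include <bits/stdc++.h>

// Per-node summary for range chmin + range sum (sketch: exact 64-bit sums, no modulus).
struct Node {
    long long sum;    // sum of the range
    long long mx;     // maximum of the range
    long long smx;    // largest value strictly smaller than mx (LLONG_MIN if none)
    long long mxcnt;  // number of elements equal to mx
};

// A single element has no second maximum.
Node make_leaf(long long x) { return {x, x, LLONG_MIN, 1}; }

// Combine the summaries of two adjacent ranges a (left) and b (right).
Node combine(const Node& a, const Node& b) {
    Node r;
    r.sum = a.sum + b.sum;
    if (a.mx == b.mx) {              // maxima tie: counts add up
        r.mx = a.mx;
        r.smx = std::max(a.smx, b.smx);
        r.mxcnt = a.mxcnt + b.mxcnt;
    } else if (a.mx > b.mx) {        // b's maximum becomes a second-maximum candidate
        r.mx = a.mx;
        r.smx = std::max(a.smx, b.mx);
        r.mxcnt = a.mxcnt;
    } else {                         // a's maximum becomes a second-maximum candidate
        r.mx = b.mx;
        r.smx = std::max(b.smx, a.mx);
        r.mxcnt = b.mxcnt;
    }
    return r;
}
```

For the leaves 5, 3, 5, 2 the root gets $sum = 15$, $mx = 5$, $smx = 3$ and $mxcnt = 2$.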

The update logic for $\text{chmin}(v)$ at a node is as follows:

Break Condition ($v \ge mx$): The update has no effect. Return immediately.

Tag Condition ($smx < v < mx$, with the node's range fully inside the update range): Only the elements equal to the maximum change. Apply the update lazily: $sum = sum - (mx - v) \times mxcnt$, set $mx = v$, and return. The children receive the new maximum later, when tags are pushed down.

Recurse Condition ($v \le smx$, or the node's range only partially overlaps the update range): The update may affect more than one distinct value. Push pending tags down to both children, recurse into them, and recompute the node from its children.
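The three cases above can be sketched as a small, self-contained segment tree. This is a hedged illustration rather than the setter's code: it is 0-indexed, keeps exact 64-bit sums without a modulus, and assumes every chmin value $v$ is greater than `LLONG_MIN`. The struct name `Beats` is hypothetical.

```cpp
#include <bits/stdc++.h>

// Minimal Segment Tree Beats: range chmin (a_i = min(a_i, v)) and range sum, 0-indexed.
struct Beats {
    struct Node { long long sum, mx, smx, mxcnt; };
    int n;
    std::vector<Node> t;

    explicit Beats(const std::vector<long long>& a) : n((int)a.size()), t(4 * a.size()) {
        build(1, 0, n - 1, a);
    }

    // Recompute node i from its children (same rules as the setter's merge).
    void pull(int i) {
        const Node &L = t[2 * i], &R = t[2 * i + 1];
        Node& P = t[i];
        P.sum = L.sum + R.sum;
        if (L.mx == R.mx) { P.mx = L.mx; P.smx = std::max(L.smx, R.smx); P.mxcnt = L.mxcnt + R.mxcnt; }
        else if (L.mx > R.mx) { P.mx = L.mx; P.smx = std::max(L.smx, R.mx); P.mxcnt = L.mxcnt; }
        else { P.mx = R.mx; P.smx = std::max(R.smx, L.mx); P.mxcnt = R.mxcnt; }
    }

    // Tag condition: only copies of the maximum change, so sum and mx are fixed in O(1).
    void apply(int i, long long v) {
        if (v >= t[i].mx) return;
        t[i].sum -= (t[i].mx - v) * t[i].mxcnt;
        t[i].mx = v;
    }

    // A child whose maximum exceeds the parent's has a pending chmin; apply it.
    void push(int i) { apply(2 * i, t[i].mx); apply(2 * i + 1, t[i].mx); }

    void build(int i, int l, int r, const std::vector<long long>& a) {
        if (l == r) { t[i] = {a[l], a[l], LLONG_MIN, 1}; return; }
        int m = (l + r) / 2;
        build(2 * i, l, m, a);
        build(2 * i + 1, m + 1, r, a);
        pull(i);
    }

    void chmin(int i, int l, int r, int ql, int qr, long long v) {
        if (r < ql || qr < l || v >= t[i].mx) return;                     // break condition
        if (ql <= l && r <= qr && v > t[i].smx) { apply(i, v); return; }  // tag condition
        push(i);                                                           // recurse condition
        int m = (l + r) / 2;
        chmin(2 * i, l, m, ql, qr, v);
        chmin(2 * i + 1, m + 1, r, ql, qr, v);
        pull(i);
    }

    long long sum(int i, int l, int r, int ql, int qr) {
        if (r < ql || qr < l) return 0;
        if (ql <= l && r <= qr) return t[i].sum;
        push(i);
        int m = (l + r) / 2;
        return sum(2 * i, l, m, ql, qr) + sum(2 * i + 1, m + 1, r, ql, qr);
    }

    void chmin(int l, int r, long long v) { chmin(1, 0, n - 1, l, r, v); }
    long long sum(int l, int r) { return sum(1, 0, n - 1, l, r); }
};
```

For example, starting from [5, 1, 4, 3, 2], `chmin(0, 4, 3)` gives [3, 1, 3, 3, 2] with sum 12, and a further `chmin(1, 3, 2)` gives [3, 1, 2, 2, 2] with sum 10.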

Because the tree is forced to recurse deeper whenever $v \le smx$, it might seem that updates could be too slow. However, the cost is amortized: every such extra recursion merges at least two distinct values of the node's range ($mx$ and $smx$) into one, so the number of distinct values stored below that node decreases. A potential-function argument (for chmin updates and sum queries, without range additions) bounds the total cost of $q$ range updates on an array of length $n$ by $\mathcal{O}((n + q) \log n)$, i.e. amortized $\mathcal{O}(\log n)$ per range update.
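A compact version of the potential argument, assuming only chmin updates and sum queries (no range additions), as in this problem. Let $d(t)$ be the number of distinct values in the range of segment tree node $t$:

```latex
\begin{aligned}
\Phi &= \sum_{t} d(t), \qquad \Phi_0 \le n\,(\lceil \log_2 n \rceil + 1)
  && \text{(each element lies in one node per level)}\\
\Delta\Phi &\le +\mathcal{O}(\log n) \text{ per update}
  && \text{(only partially covered nodes can gain the new value } v\text{)}\\
\Delta\Phi &\le -1 \text{ per covered node with } v \le smx
  && \text{(its } mx \text{ and } smx \text{ both become } v\text{)}\\
\text{total work} &= \mathcal{O}\bigl(q \log n + \Phi_0 + q \log n\bigr) = \mathcal{O}\bigl((n + q) \log n\bigr)
\end{aligned}
```

With HLD, $Q$ path updates issue $\mathcal{O}(Q \log n)$ range updates, so the same argument gives $\mathcal{O}(n \log n + Q \log^2 n)$ in total for Type 1 queries.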

See the Codeforces blog on Segment Tree Beats for the full proof of this bound.

Linearizing the Tree: Heavy-Light Decomposition (HLD)

Since our nodes form a tree, we must flatten the structure into a 1D array to utilize our Segment Tree Beats logic.

We use Heavy-Light Decomposition (HLD). For every vertex, the child with the largest subtree is its heavy child, and following heavy children splits the tree into vertical "heavy chains". Assigning array positions in DFS order while always visiting the heavy child first makes each heavy chain a contiguous subarray.

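As a result, any tree path $u \to v$ can be broken down into $\mathcal{O}(\log n)$ contiguous 1D array segments: moving up across a light edge at least doubles the subtree size, so a vertex-to-root path crosses at most $\log_2 n$ light edges, and each heavy chain touched by the path contributes one segment.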
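A minimal sketch of the decomposition, assuming 1-indexed vertices rooted at vertex 1 and array positions 1..n as in the setter's code. The struct `HLD` and its method names are hypothetical. `path_segments` returns the pieces in no particular order, which is all that a Type 1 update needs.

```cpp
#include <bits/stdc++.h>

// Heavy-Light Decomposition: heavy child = child with the largest subtree,
// positions assigned in DFS order with the heavy child visited first.
struct HLD {
    int n, timer = 0;
    std::vector<std::vector<int>> adj;
    std::vector<int> par, depth, sz, heavy, head, pos;

    explicit HLD(int n) : n(n), adj(n + 1), par(n + 1), depth(n + 1), sz(n + 1),
                          heavy(n + 1, 0), head(n + 1), pos(n + 1) {}

    void add_edge(int u, int v) { adj[u].push_back(v); adj[v].push_back(u); }

    // Subtree sizes, parents, depths and heavy children.
    void dfs_size(int u, int p) {
        par[u] = p; sz[u] = 1;
        for (int w : adj[u]) if (w != p) {
            depth[w] = depth[u] + 1;
            dfs_size(w, u);
            sz[u] += sz[w];
            if (!heavy[u] || sz[w] > sz[heavy[u]]) heavy[u] = w;
        }
    }

    // Chain heads and positions; each heavy chain gets consecutive positions.
    void dfs_decompose(int u, int h) {
        head[u] = h; pos[u] = ++timer;
        if (heavy[u]) dfs_decompose(heavy[u], h);                 // continue the chain
        for (int w : adj[u])
            if (w != par[u] && w != heavy[u]) dfs_decompose(w, w);  // start a new chain
    }

    void build() { dfs_size(1, 0); dfs_decompose(1, 1); }

    // Split the path u - v into O(log n) position ranges [l, r] (unordered).
    std::vector<std::pair<int, int>> path_segments(int u, int v) const {
        std::vector<std::pair<int, int>> segs;
        while (head[u] != head[v]) {
            if (depth[head[u]] < depth[head[v]]) std::swap(u, v);  // lift the deeper chain
            segs.push_back({pos[head[u]], pos[u]});
            u = par[head[u]];
        }
        if (depth[u] > depth[v]) std::swap(u, v);
        segs.push_back({pos[u], pos[v]});
        return segs;
    }
};
```

On the tree with edges 1-2, 1-3, 2-4, 2-5, 3-6, the path $5 \to 6$ splits into the position ranges [4, 4], [5, 6] and [1, 2], which together cover the five vertices 5, 2, 1, 3, 6.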

Type 1 Queries: Path Range chmin Updates

We need to cap the production values on the path $u \to v$ at $x$, i.e. set $a_i = \min(a_i, x)$ for every vertex $i$ on the path.

Use HLD to split the path $u \to v$ into $\mathcal{O}(\log n)$ contiguous array segments.

For each segment $[l, r]$, call the chmin update on the Segment Tree Beats. The segments are disjoint, so the order in which they are processed does not matter.

Time Complexity: Amortized $\mathcal{O}(\log^2 n)$ per query. HLD breaks the path into $\mathcal{O}(\log n)$ segments, and Segment Tree Beats processes each range update in amortized $\mathcal{O}(\log n)$ time.
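For stress-testing the HLD + Segment Tree Beats solution on small inputs, a brute-force reference for Type 1 can simply walk both endpoints up to their LCA. This is a hypothetical helper (the names `path_vertices` and `naive_chmin_path` are illustrative), assuming 1-indexed parent and depth arrays with the root's parent set to 0; it runs in time proportional to the path length and is not the intended solution.

```cpp
#include <bits/stdc++.h>

// Vertices of the path u -> v in traversal order, found by climbing to the LCA.
// Assumes 1-indexed par/depth arrays with par[root] = 0.
std::vector<int> path_vertices(int u, int v, const std::vector<int>& par,
                               const std::vector<int>& depth) {
    std::vector<int> up, down;
    while (depth[u] > depth[v]) { up.push_back(u); u = par[u]; }
    while (depth[v] > depth[u]) { down.push_back(v); v = par[v]; }
    while (u != v) { up.push_back(u); down.push_back(v); u = par[u]; v = par[v]; }
    up.push_back(u);                                  // the LCA, counted once
    up.insert(up.end(), down.rbegin(), down.rend());  // then walk down towards v
    return up;
}

// Brute-force Type 1: a[w] = min(a[w], x) for every vertex w on the path.
void naive_chmin_path(std::vector<long long>& a, int u, int v, long long x,
                      const std::vector<int>& par, const std::vector<int>& depth) {
    for (int w : path_vertices(u, v, par, depth)) a[w] = std::min(a[w], x);
}
```

On the tree with edges 1-2, 1-3, 2-4, 2-5, 3-6 and values (5, 4, 3, 2, 1, 6) on vertices 1..6, capping the path $5 \to 6$ at $x = 3$ yields (3, 3, 3, 2, 1, 3); vertex 4 is off the path and keeps its value.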

Type 2 Queries: Path Queries with Early Exit

The traversal moves from $u$ to $v$ in order, adding $a_i$ to a running sum, but stops as soon as it reaches a vertex with $a_i > x$; that vertex is not added.
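The early-exit rule fits in a few lines once the path vertices are listed in traversal order. This hypothetical helper (`naive_surge` is an illustrative name) assumes the caller already has that vertex list, e.g. from a brute-force walk to the LCA, and keeps a plain 64-bit sum, whereas the setter's and tester's code print the sum modulo $10^9 + 7$.

```cpp
#include <bits/stdc++.h>

// Reference semantics of a Type 2 query: `path` lists the vertices from u to v in order.
// Adds a[w] for each vertex until the first one with a[w] > x, which is NOT added.
long long naive_surge(const std::vector<long long>& a, const std::vector<int>& path, long long x) {
    long long sum = 0;
    for (int w : path) {
        if (a[w] > x) break;  // early exit at the first blocking vertex
        sum += a[w];
    }
    return sum;
}
```

With values (5, 4, 3, 2, 1, 6) on vertices 1..6 and the path 5, 2, 1, 3, 6, the query with $x = 4$ returns $1 + 4 = 5$: vertex 1 has value $5 > 4$, so the walk stops there.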

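Because HLD breaks the path into fragments, we must process them in the exact order of the traversal. We split the path at the Lowest Common Ancestor, $\text{LCA}(u, v)$, and make sure the LCA is counted exactly once, either as the last vertex of the upward part or as the first vertex of the downward part.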

Path 1: Moving Up ($u \to \text{LCA}(u, v)$). Moving up a heavy chain means moving from higher array indices to lower ones, so we scan each segment $[l, r]$ from right to left.

Binary search over the segment (using range-maximum queries on the segment tree) to find the rightmost index $idx$ in $[l, r]$ with $a_{idx} > x$.

If found: Add the sum of the range $[idx + 1, r]$ and terminate the entire query.

If not found: Add the sum of the range $[l, r]$ and jump up to the next chain.
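The per-segment step of Path 1 can be sketched independently of the tree: given a range-maximum and a range-sum oracle (here hypothetical callables standing in for the segment tree's `query_max` and `query`), binary search for the rightmost blocking index, as the setter's `find_last` does. The function names and the pair return value are assumptions of this sketch.

```cpp
#include <bits/stdc++.h>

// Rightmost index in [l, r] whose value exceeds x, or -1.
// Invariant: every index in (r, original r] has value <= x.
template <class MaxFn>
int find_last_greater(int l, int r, long long x, MaxFn range_max) {
    int ans = -1;
    while (l <= r) {
        int mid = l + (r - l) / 2;
        if (range_max(mid, r) > x) { ans = mid; l = mid + 1; }  // a blocker lies in [mid, r]
        else r = mid - 1;                                       // [mid, r] is clear: look left
    }
    return ans;
}

// One "moving up" segment, scanned from r down to l.
// Returns {amount added, whether the walk stops inside this segment}.
template <class MaxFn, class SumFn>
std::pair<long long, bool> scan_up(int l, int r, long long x, MaxFn range_max, SumFn range_sum) {
    int idx = find_last_greater(l, r, x, range_max);
    if (idx == -1) return {range_sum(l, r), false};        // whole segment is passable
    return {idx < r ? range_sum(idx + 1, r) : 0LL, true};  // stop right after the blocker
}
```

With values [7, 1, 9, 2, 3] and $x = 5$, scanning [0, 4] from the right stops at index 2 (value 9) after adding $2 + 3 = 5$.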

Path 2: Moving Down ($\text{LCA}(u, v) \to v$). Moving down a heavy chain means moving from lower array indices to higher ones, so we scan each segment $[l, r]$ from left to right. (Note: HLD extracts these segments bottom-up, starting from $v$, so reverse the list to process them top-down.)

Binary search over the segment in the same way to find the leftmost index $idx$ in $[l, r]$ with $a_{idx} > x$.

If found: Add the sum of the range $[l, idx - 1]$ and terminate the entire query.

If not found: Add the sum of the range $[l, r]$ and move down to the next chain.
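The ordering rules (upward pieces in the order they are met, each scanned right to left; downward pieces collected bottom-up, then reversed and each scanned left to right) can be checked with a short driver. To keep the sketch short, a linear scan stands in for the segment-tree binary search; `surge_in_order` and its parameters are hypothetical.

```cpp
#include <bits/stdc++.h>

// b[p] is the value at array position p. `up` holds the u -> LCA pieces in the order they are
// met (LCA included in the last one); `down` holds the LCA -> v pieces as HLD collects them,
// i.e. bottom-up starting from v (LCA excluded). A linear scan replaces the binary search.
long long surge_in_order(const std::vector<long long>& b,
                         const std::vector<std::pair<int, int>>& up,
                         std::vector<std::pair<int, int>> down, long long x) {
    long long sum = 0;
    for (const auto& seg : up)                 // moving up: high index -> low index
        for (int i = seg.second; i >= seg.first; --i) {
            if (b[i] > x) return sum;
            sum += b[i];
        }
    std::reverse(down.begin(), down.end());    // process the downward pieces top-down
    for (const auto& seg : down)               // moving down: low index -> high index
        for (int i = seg.first; i <= seg.second; ++i) {
            if (b[i] > x) return sum;
            sum += b[i];
        }
    return sum;
}
```

In the first test case the downward pieces arrive as [4, 4] then [2, 2]; processing them without the reversal would visit position 4 before position 2 and, with $x = 5$, return 3 instead of 1.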

Time Complexity: $\mathcal{O}(\log^3 n)$ per query. HLD splits the path into $\mathcal{O}(\log n)$ segments. For a segment of size $k$, the binary search takes $\mathcal{O}(\log k)$ steps, and each step is a range-maximum query costing $\mathcal{O}(\log n)$ (queries only push tags down and never trigger the extra recursion of chmin, so this bound is not amortized). Since $k \le n$, this gives $\mathcal{O}(\log^2 n)$ per segment and $\mathcal{O}(\log^3 n)$ overall. The tester's code avoids the outer binary search by descending the segment tree directly (findFirstGT), which finds the index in $\mathcal{O}(\log n)$ per segment and $\mathcal{O}(\log^2 n)$ per query.
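The tester's `findFirstGT` replaces the outer binary search with a single descent: a subtree whose maximum is at most $x$ is skipped, and otherwise the search tries the preferred child first (left for Path 2, right for Path 1), visiting $\mathcal{O}(\log n)$ nodes per call. A simplified C++ sketch on a plain max segment tree follows; it has no chmin tags (in Segment Tree Beats, tags must be pushed down before descending, as the tester's code does), and `MaxTree` and `find` are illustrative names.

```cpp
#include <bits/stdc++.h>

// Plain max segment tree with a "first/last index with value > x in [ql, qr]" descent.
struct MaxTree {
    int n;
    std::vector<long long> mx;

    explicit MaxTree(const std::vector<long long>& a) : n((int)a.size()), mx(4 * a.size()) {
        build(1, 0, n - 1, a);
    }

    void build(int i, int l, int r, const std::vector<long long>& a) {
        if (l == r) { mx[i] = a[l]; return; }
        int m = (l + r) / 2;
        build(2 * i, l, m, a);
        build(2 * i + 1, m + 1, r, a);
        mx[i] = std::max(mx[2 * i], mx[2 * i + 1]);
    }

    // Leftmost (rev = false) or rightmost (rev = true) index in [ql, qr] with value > x, or -1.
    int find(int i, int l, int r, int ql, int qr, long long x, bool rev) const {
        if (r < ql || qr < l || mx[i] <= x) return -1;  // prune: nothing usable here
        if (l == r) return l;
        int m = (l + r) / 2;
        int first = rev ? 2 * i + 1 : 2 * i, second = rev ? 2 * i : 2 * i + 1;
        int fl = rev ? m + 1 : l, fr = rev ? r : m;      // range of the preferred child
        int sl = rev ? l : m + 1, sr = rev ? m : r;      // range of the other child
        int res = find(first, fl, fr, ql, qr, x, rev);
        return res != -1 ? res : find(second, sl, sr, ql, qr, x, rev);
    }

    int find(int ql, int qr, long long x, bool rev) const { return find(1, 0, n - 1, ql, qr, x, rev); }
};
```

With values [7, 1, 9, 2, 3] and $x = 5$, the leftmost blocker in [1, 4] is index 2, and there is no blocker in [3, 4].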

Setter's Code

#include <bits/stdc++.h>
#define int long long
#define double long double
#define ff first
#define ss second
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define sz(x) (int)x.size()
#define pb push_back
#define ppb pop_back
#define pii pair<int, int>
#define vi vector<int>
#define vpii vector<pii>
using namespace std;
const int mod = 1e9 + 7;

vi pos, head, heavy, depth, sub, par, lt, val;
int idx;
vector<vi> adj;

class segtreebeats
{
public:
    struct Node
    {
        int sum;
        int mx;
        int smx;
        int mxcnt;
    };

    vector<Node> tree;
    int n;

    segtreebeats(vi &a)
    {
        n = a.size() - 1;
        tree.resize(4 * n + 10);
        build(1, 1, n, a);
    }

    void merge(int idx)
    {
        int lc = 2 * idx;
        int rc = 2 * idx + 1;

        tree[idx].sum = (tree[lc].sum + tree[rc].sum) % mod;

        if (tree[lc].mx == tree[rc].mx)
        {
            tree[idx].mx = tree[lc].mx;
            tree[idx].smx = max(tree[lc].smx, tree[rc].smx);
            tree[idx].mxcnt = tree[lc].mxcnt + tree[rc].mxcnt;
        }
        else if (tree[lc].mx > tree[rc].mx)
        {
            tree[idx].mx = tree[lc].mx;
            tree[idx].smx = max(tree[lc].smx, tree[rc].mx);
            tree[idx].mxcnt = tree[lc].mxcnt;
        }
        else
        {
            tree[idx].mx = tree[rc].mx;
            tree[idx].smx = max(tree[rc].smx, tree[lc].mx);
            tree[idx].mxcnt = tree[rc].mxcnt;
        }
    }

    void apply_chmin(int idx, int v)
    {
        if (v >= tree[idx].mx)
            return;

        int diff = (tree[idx].mx - v) % mod;
        int total_dec = (diff * (tree[idx].mxcnt % mod)) % mod;
        tree[idx].sum = (tree[idx].sum - total_dec + mod) % mod;

        tree[idx].mx = v;
    }

    void push_down(int idx)
    {
        apply_chmin(2 * idx, tree[idx].mx);
        apply_chmin(2 * idx + 1, tree[idx].mx);
    }

    void build(int idx, int l, int r, vector<int> &a)
    {
        if (l == r)
        {
            tree[idx].sum = a[l] % mod;
            tree[idx].mx = a[l];
            tree[idx].smx = -1e18;
            tree[idx].mxcnt = 1;
            return;
        }

        int mid = (l + r) / 2;
        build(2 * idx, l, mid, a);
        build(2 * idx + 1, mid + 1, r, a);

        merge(idx);
    }

    void update(int idx, int l, int r, int ql, int qr, int v)
    {
        if (l > qr || r < ql || tree[idx].mx <= v)
            return;

        if (l >= ql && r <= qr && v > tree[idx].smx)
        {
            apply_chmin(idx, v);
            return;
        }

        push_down(idx);

        int mid = (l + r) / 2;

        update(2 * idx, l, mid, ql, qr, v);
        update(2 * idx + 1, mid + 1, r, ql, qr, v);

        merge(idx);
    }

    int query(int idx, int l, int r, int ql, int qr)
    {
        if (l > qr || r < ql)
            return 0;

        if (l >= ql && r <= qr)
            return tree[idx].sum;

        push_down(idx);

        int mid = (l + r) / 2;

        return (query(2 * idx, l, mid, ql, qr) +
                query(2 * idx + 1, mid + 1, r, ql, qr)) %
               mod;
    }

    int query_max(int idx, int l, int r, int ql, int qr)
    {
        if (ql > qr || l > qr || r < ql)
            return -1e18;

        if (l >= ql && r <= qr)
            return tree[idx].mx;

        push_down(idx);

        int mid = (l + r) / 2;

        return max(query_max(2 * idx, l, mid, ql, qr),
                   query_max(2 * idx + 1, mid + 1, r, ql, qr));
    }

    void update_path(int a, int b, int v)
    {
        while (head[a] != head[b])
        {
            if (depth[head[a]] > depth[head[b]])
            {
                swap(a, b);
            }
            update(1, 1, n, pos[head[b]], pos[b], v);
            b = par[head[b]];
        }
        if (depth[a] > depth[b])
        {
            swap(a, b);
        }
        update(1, 1, n, pos[a], pos[b], v);
    }

    int find_first(int l, int r, int v)
    {
        int idx = -1;
        while (l <= r)
        {
            int mid = (l + r) / 2;
            if (query_max(1, 1, n, l, mid) > v)
            {
                idx = mid;
                r = mid - 1;
            }
            else
            {
                l = mid + 1;
            }
        }
        return idx;
    }

    int find_last(int l, int r, int v)
    {
        int idx = -1;
        while (l <= r)
        {
            int mid = (l + r) / 2;
            if (query_max(1, 1, n, mid, r) > v)
            {
                idx = mid;
                l = mid + 1;
            }
            else
            {
                r = mid - 1;
            }
        }
        return idx;
    }

    int query_path(int a, int b, int v)
    {
        int ans = 0;
        vector<pair<int, int>> down;
        int f = 0;

        while (head[a] != head[b])
        {
            if (depth[head[a]] > depth[head[b]])
            {
                int idx = find_last(pos[head[a]], pos[a], v);

                if (idx == -1)
                {
                    ans = (ans + query(1, 1, n, pos[head[a]], pos[a])) % mod;
                    a = par[head[a]];
                }
                else
                {
                    {
                        ans = (ans + query(1, 1, n, idx + 1, pos[a])) % mod;
                    }
                    f = 1;
                    break;
                }
            }
            else
            {
                down.push_back({pos[head[b]], pos[b]});
                b = par[head[b]];
            }
        }

        if (f)
            return ans;

        if (depth[a] > depth[b])
        {
            int idx = find_last(pos[b], pos[a], v);
            if (idx == -1)
            {
                ans = (ans + query(1, 1, n, pos[b], pos[a])) % mod;
            }
            else
            {
                {
                    ans = (ans + query(1, 1, n, idx + 1, pos[a])) % mod;
                }
                return ans;
            }
        }
        else
        {
            down.push_back({pos[a], pos[b]});
        }

        for (int i = down.size() - 1; i >= 0; i--)
        {
            int l = down[i].first;
            int r = down[i].second;

            int idx = find_first(l, r, v);

            if (idx == -1)
            {
                ans = (ans + query(1, 1, n, l, r)) % mod;
            }
            else
            {
                {
                    ans = (ans + query(1, 1, n, l, idx - 1)) % mod;
                }
                break;
            }
        }
        return ans;
    }
};

void dfs(int node, int p)
{
    sub[node] = 1;
    par[node] = p;

    for (auto it : adj[node])
    {
        if (it == p)
            continue;
        depth[it] = depth[node] + 1;
        dfs(it, node);
        sub[node] += sub[it];
        if (!heavy[node] || sub[it] > sub[heavy[node]])
        {
            heavy[node] = it;
        }
    }
}

void dfshld(int node, int h)
{
    head[node] = h;
    pos[node] = idx;
    lt[idx++] = val[node];

    if (heavy[node])
    {
        dfshld(heavy[node], h);
    }
    for (auto it : adj[node])
    {
        if (it == par[node] || it == heavy[node])
            continue;
        dfshld(it, it);
    }
}

void G1solve()
{
    int n, q;
    cin >> n >> q;
    val.assign(n + 1, 0);
    head.assign(n + 1, 0);
    pos.assign(n + 1, 0);
    sub.assign(n + 1, 0);
    depth.assign(n + 1, 0);
    par.assign(n + 1, 0);
    heavy.assign(n + 1, 0);
    lt.assign(n + 1, 0);

    for (int i = 1; i <= n; i++)
        cin >> val[i];

    adj.assign(n + 1, vector<int>());
    for (int i = 0; i < n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }

    idx = 1;
    depth[1] = 1;
    dfs(1, 0);
    dfshld(1, 1);

    segtreebeats st(lt);
    while (q--)
    {
        int t;
        cin >> t;
        if (t == 1)
        {
            int a, b, v;
            cin >> a >> b >> v;
            st.update_path(a, b, v);
        }
        else
        {
            int a, b, v;
            cin >> a >> b >> v;
            cout << st.query_path(a, b, v) << "\n";
        }
    }
}

int32_t main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t = 1;
    while (t--)
    {
        G1solve();
    }
    return 0;
}

Tester's Code

import java.io.*;
import java.util.*;

public class Main {
    static final int MOD = 1_000_000_007;
    static int[] head, pos, parent, depth, heavy, size, nodeAt;
    static long[] treeSum, treeMax, treeMax2, treeMaxCnt;
    static int curPos, n;

    public static void main(String[] args) throws IOException {
        FastReader fr = new FastReader(System.in);
        PrintWriter out = new PrintWriter(System.out);

        String ns = fr.next();
        if (ns == null) return;
        n = Integer.parseInt(ns);
        int q = fr.nextInt();
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; i++) a[i] = fr.nextLong();

        List<Integer>[] adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
        for (int i = 0; i < n - 1; i++) {
            int u = fr.nextInt(), v = fr.nextInt();
            adj[u].add(v); adj[v].add(u);
        }

        head = new int[n + 1]; pos = new int[n + 1]; parent = new int[n + 1];
        depth = new int[n + 1]; heavy = new int[n + 1]; size = new int[n + 1];
        nodeAt = new int[n + 1];
        
        dfsSize(1, 0, 0, adj);
        dfsHld(1, 1, adj);

        treeSum = new long[4 * n]; treeMax = new long[4 * n];
        treeMax2 = new long[4 * n]; treeMaxCnt = new long[4 * n];
        build(1, 0, n - 1, a);

        while (q-- > 0) {
            int type = fr.nextInt();
            int u = fr.nextInt(), v = fr.nextInt();
            long x = fr.nextLong();
            if (type == 1) updatePath(u, v, x);
            else out.println(querySurge(u, v, x));
        }
        out.flush();
    }

    static void dfsSize(int u, int p, int d, List<Integer>[] adj) {
        parent[u] = p; depth[u] = d; size[u] = 1;
        for (int v : adj[u]) {
            if (v != p) {
                dfsSize(v, u, d + 1, adj);
                size[u] += size[v];
                if (heavy[u] == 0 || size[v] > size[heavy[u]]) heavy[u] = v;
            }
        }
    }

    static void dfsHld(int u, int h, List<Integer>[] adj) {
        head[u] = h; pos[u] = curPos; nodeAt[curPos++] = u;
        if (heavy[u] != 0) {
            dfsHld(heavy[u], h, adj);
            for (int v : adj[u]) {
                if (v != parent[u] && v != heavy[u]) dfsHld(v, v, adj);
            }
        }
    }

    static void build(int v, int tl, int tr, long[] a) {
        if (tl == tr) {
            treeSum[v] = treeMax[v] = a[nodeAt[tl]];
            treeMax2[v] = Long.MIN_VALUE; treeMaxCnt[v] = 1; // leaf has no second max; -1 would force recursion into a leaf when x <= -1
            return;
        }
        int tm = (tl + tr) / 2;
        build(2 * v, tl, tm, a); 
        build(2 * v + 1, tm + 1, tr, a);
        pushUp(v);
    }

    static void pushUp(int v) {
        int lc = 2 * v, rc = 2 * v + 1;
        treeSum[v] = (treeSum[lc] + treeSum[rc]) % MOD;
        if (treeMax[lc] == treeMax[rc]) {
            treeMax[v] = treeMax[lc];
            treeMaxCnt[v] = treeMaxCnt[lc] + treeMaxCnt[rc];
            treeMax2[v] = Math.max(treeMax2[lc], treeMax2[rc]);
        } else if (treeMax[lc] > treeMax[rc]) {
            treeMax[v] = treeMax[lc];
            treeMaxCnt[v] = treeMaxCnt[lc];
            treeMax2[v] = Math.max(treeMax2[lc], treeMax[rc]);
        } else {
            treeMax[v] = treeMax[rc];
            treeMaxCnt[v] = treeMaxCnt[rc];
            treeMax2[v] = Math.max(treeMax[lc], treeMax2[rc]);
        }
    }

    static void putTag(int v, long x) {
        if (treeMax[v] <= x) return;
        long diff = (treeMax[v] - x) % MOD;
        treeSum[v] = (treeSum[v] - (treeMaxCnt[v] % MOD) * diff % MOD + MOD) % MOD;
        treeMax[v] = x;
    }

    static void pushDown(int v) {
        putTag(2 * v, treeMax[v]);
        putTag(2 * v + 1, treeMax[v]);
    }

    static void update(int v, int tl, int tr, int l, int r, long x) {
        if (l > tr || r < tl || treeMax[v] <= x) return;
        if (l <= tl && tr <= r && treeMax2[v] < x) {
            putTag(v, x); return;
        }
        pushDown(v);
        int tm = (tl + tr) / 2;
        update(2 * v, tl, tm, l, r, x);
        update(2 * v + 1, tm + 1, tr, l, r, x);
        pushUp(v);
    }

    static long getSum(int v, int tl, int tr, int l, int r) {
        if (l > tr || r < tl) return 0;
        if (l <= tl && tr <= r) return treeSum[v];
        pushDown(v);
        int tm = (tl + tr) / 2;
        return (getSum(2 * v, tl, tm, l, r) + getSum(2 * v + 1, tm + 1, tr, l, r)) % MOD;
    }

    static int findFirstGT(int v, int tl, int tr, int l, int r, long x, boolean rev) {
        if (l > tr || r < tl || treeMax[v] <= x) return -1;
        if (tl == tr) return tl;
        pushDown(v);
        int tm = (tl + tr) / 2;
        if (rev) {
            int res = findFirstGT(2 * v + 1, tm + 1, tr, l, r, x, true);
            return (res != -1) ? res : findFirstGT(2 * v, tl, tm, l, r, x, true);
        } else {
            int res = findFirstGT(2 * v, tl, tm, l, r, x, false);
            return (res != -1) ? res : findFirstGT(2 * v + 1, tm + 1, tr, l, r, x, false);
        }
    }

    static void updatePath(int u, int v, long x) {
        while (head[u] != head[v]) {
            if (depth[head[u]] < depth[head[v]]) { int t = u; u = v; v = t; }
            update(1, 0, n - 1, pos[head[u]], pos[u], x);
            u = parent[head[u]];
        }
        if (depth[u] < depth[v]) { int t = u; u = v; v = t; }
        update(1, 0, n - 1, pos[v], pos[u], x);
    }

    static long querySurge(int u, int v, long x) {
        int lca = getLCA(u, v);
        long totalSum = 0;
        int curr = u;
        while (true) {
            int l = pos[head[curr]], r = pos[curr];
            if (head[curr] == head[lca]) l = pos[lca];
            int idx = findFirstGT(1, 0, n - 1, l, r, x, true);
            if (idx != -1) return (totalSum + getSum(1, 0, n - 1, idx + 1, r)) % MOD;
            totalSum = (totalSum + getSum(1, 0, n - 1, l, r)) % MOD;
            if (head[curr] == head[lca]) break;
            curr = parent[head[curr]];
        }
        List<int[]> segments = new ArrayList<>();
        curr = v;
        while (head[curr] != head[lca]) {
            segments.add(new int[]{pos[head[curr]], pos[curr]});
            curr = parent[head[curr]];
        }
        if (pos[curr] > pos[lca]) segments.add(new int[]{pos[lca] + 1, pos[curr]});
        for (int i = segments.size() - 1; i >= 0; i--) {
            int l = segments.get(i)[0], r = segments.get(i)[1];
            int idx = findFirstGT(1, 0, n - 1, l, r, x, false);
            if (idx != -1) return (totalSum + getSum(1, 0, n - 1, l, idx - 1)) % MOD;
            totalSum = (totalSum + getSum(1, 0, n - 1, l, r)) % MOD;
        }
        return totalSum;
    }

    static int getLCA(int u, int v) {
        while (head[u] != head[v]) {
            if (depth[head[u]] > depth[head[v]]) u = parent[head[u]];
            else v = parent[head[v]];
        }
        return depth[u] < depth[v] ? u : v;
    }

    static class FastReader {
        BufferedReader br;
        StringTokenizer st;
        FastReader(InputStream in) { br = new BufferedReader(new InputStreamReader(in)); }
        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    String line = br.readLine();
                    if (line == null) return null;
                    st = new StringTokenizer(line);
                } catch (IOException e) {}
            }
            return st.nextToken();
        }
        int nextInt() { return Integer.parseInt(next()); }
        long nextLong() { return Long.parseLong(next()); }
    }
}