Editorial for Problem D

Cut to the Chase 3

Author : TheoGermal

Required Knowledge : Maths, Lazy Segment Tree

Time Complexity : $O(N + Q \cdot \log_2 N)$ — $O(N)$ to build the two trees and $O(\log_2 N)$ per operation.

Editorialist : tnaveen2308

Approach:

At first glance, this problem appears to require a complex data structure to handle both the range updates and queries efficiently. Let's analyze the update operation more carefully.

For a given update operation with parameters $L$, $R$, and $x$, the increment applied to $A[i]$ (where $L \leq i \leq R$) is: $A[i] \mathrel{+}= (i - L + 1) \cdot x$

We can rewrite this as: $A[i] \mathrel{+}= i \cdot x + (1 - L) \cdot x$

This is equivalent to two separate operations:

  • Add $i \cdot x$ to $A[i]$ for each $i$ in the range $[L, R]$ — an amount that depends on the position $i$.
  • Add $(1 - L) \cdot x$ to $A[i]$ for each $i$ in the range $[L, R]$ — the same constant for every position.

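Now we can use two separate segment trees with lazy propagation:

  • First segment tree ($ST1$): an ordinary range-add / range-sum tree that handles the constant term $(1 - L) \cdot x$.
  • Second segment tree ($ST2$): handles the term $i \cdot x$. It keeps a coefficient $c_i$ for every position and maintains the sum of $i \cdot c_i$. Adding $x$ to every coefficient of a node covering $[lo, hi]$ raises that node's sum by $x \cdot \frac{(lo + hi)(hi - lo + 1)}{2}$, the sum of the indices it covers, so lazy propagation works exactly as in an ordinary range-add tree.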

When we need to compute the sum of the elements in the range $[L, R]$, we query both segment trees over $[L, R]$ and add the results.

Code:

#include <bits/stdc++.h>

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;
// Order-statistics set built on GNU's policy-based red-black tree: it keeps
// distinct keys of type T in ascending order and additionally provides
// find_by_order(k) (iterator to the k-th smallest key, 0-based) and
// order_of_key(x) (number of keys strictly less than x).
// Not referenced by the visible solution.
template <typename T>
using pbds = tree<T, null_type, less<T>, rb_tree_tag,
                  tree_order_statistics_node_update>;
// #define cerr if(false)cerr
// Shorthand macros used by the rest of the file. Note that `int` itself is
// redefined to `long long` from here on, which is why the entry point is
// declared as `signed main`.
#define int long long
#define pb push_back
#define F first
#define S second
#define yes cout << "Yes\n"
#define no cout << "No\n"
// The argument and the whole expansion are parenthesized so that yn() acts as
// one expression even for compound arguments such as yn(a = b).
#define yn(x) ((x) ? yes : no)
// Bounds are parenthesized so compound arguments (e.g. a ?: expression) bind
// as intended inside the loop header.
#define f(i, s, e) for (int i = (s); i < (e); i++)
#define vi vector<int>
#define vb vector<bool>
#define pii pair<int, int>
#define vpi vector<pii>
// Parenthesized so that e.g. all(*p) becomes (*p).begin(), not *p.begin().
#define all(x) (x).begin(), (x).end()
#define minele(x) *min_element(all(x))
#define maxele(x) *max_element(all(x))
// endl becomes a plain newline character, so it does not flush the stream.
#define endl '\n'
#define in(v, x) binary_search(all(v), x)

// Contest boilerplate constants; none of them is referenced by the visible
// solution. `int` is `long long` at this point in the file.
const int N = 200001;        // 2e5 + 1
const int MOD = 1000000007;  // 1e9 + 7
const int inf = LLONG_MAX;
const int minf = LLONG_MIN;

// Debug output to stderr. It is compiled out when ONLINE_JUDGE is defined
// (presumably by the judge's build command; confirm for the target judge).
// Each macro is wrapped in do { ... } while (0) so that `debug(x);` is a
// single statement; otherwise `if (c) debug(x);` would only guard the
// first of the two expanded statements.
#ifndef ONLINE_JUDGE
#define debug(x)                \
    do {                        \
        cerr << (#x) << " is "; \
        _print(x);              \
    } while (0)
#define dbg(x...)               \
    do {                        \
        cerr << (#x) << " is "; \
        _print(x);              \
    } while (0)
#else
// A function-like macro may not be redefined with a different parameter
// list, so dbg is defined once here, in its variadic form only.
#define debug(x)
#define dbg(x...)
#endif

// Debug-prints a single value to stderr, with no trailing newline, so that it
// can also be used for elements inside container output.
template <typename T>
void _print(T a) {
    cerr << a;
}
// Writes "t1, t2, ..., tn" to stderr without a trailing newline.
template <typename T>
void _print_joined(T a) {
    cerr << a;
}
template <typename T1, typename... T2>
void _print_joined(T1 t1, T2... t2) {
    cerr << t1 << ", ";
    _print_joined(t2...);
}
// Debug-prints two or more values, comma-separated, then ends the line.
// The newline is written here, outside the recursion, so that it appears
// exactly once per call regardless of the number of arguments.
template <typename T1, typename... T2>
void _print(T1 t1, T2... t2) {
    _print_joined(t1, t2...);
    cerr << endl;
}
// Writes a value to stdout, followed by a single space.
template <typename T>
void print(T a) {
    cout << a;
    cout << ' ';
}
// Writes a value to stdout, followed by `endl` (which this file defines as
// a plain '\n').
template <typename T>
void println(T a) {
    cout << a;
    cout << endl;
}
// Reads whitespace-separated values into every slot of an already-sized
// vector; the vector's current size decides how many values are read.
template <class T>
istream &operator>>(istream &is, vector<T> &a) {
    for (size_t k = 0; k < a.size(); ++k) is >> a[k];
    return is;
}
// Writes every element followed by one space (so the output ends with a
// trailing space and no newline).
template <class T>
ostream &operator<<(ostream &os, const vector<T> &a) {
    for (auto it = a.begin(); it != a.end(); ++it) os << *it << ' ';
    return os;
}

// Raises a to b when b is larger; otherwise leaves a unchanged.
template <typename T>
void chmax(T &a, T b) {
    if (a < b) a = b;
}
// Lowers a to b when b is smaller; otherwise leaves a unchanged.
template <typename T>
void chmin(T &a, T b) {
    if (b < a) a = b;
}

// Forward declarations of every container overload of _print, so that the
// overloads defined later can call one another regardless of definition
// order (e.g. when printing a vector of maps). unordered_map is included so
// that it is reachable from the other overloads as well.
template <class T, class V>
void _print(pair<T, V> p);
template <class T>
void _print(vector<T> v);
template <class T>
void _print(set<T> v);
template <class T, class V>
void _print(map<T, V> v);
template <class T>
void _print(multiset<T> v);
template <class T, class V>
void _print(unordered_map<T, V> v);
// Debug-prints a pair as "{first,second} " (note the trailing space), using
// the matching _print overload for each component.
template <class T, class V>
void _print(pair<T, V> p) {
    cerr << '{';
    _print(p.first);
    cerr << ',';
    _print(p.second);
    cerr << "} ";
}
// Shared implementation of the container printers below. It writes
// "[ e1 e2 ... ]" followed by a newline to stderr, printing each element
// through whichever _print overload matches its type. Elements are visited
// by const reference, so the only per-element copy is the by-value parameter
// of that _print call.
template <class C>
void _print_range(const C &c) {
    cerr << "[ ";
    for (const auto &i : c) {
        _print(i);
        cerr << " ";
    }
    cerr << "]";
    cerr << endl;
}
template <class T>
void _print(vector<T> v) {
    _print_range(v);
}
template <class T>
void _print(set<T> v) {
    _print_range(v);
}
template <class T>
void _print(multiset<T> v) {
    _print_range(v);
}
template <class T, class V>
void _print(map<T, V> v) {
    _print_range(v);
}
template <class T, class V>
void _print(unordered_map<T, V> v) {
    _print_range(v);
}

// Speeds up stream I/O. It unsyncs the C++ streams from C stdio, and it
// unties cin from cout so that reading input does not first flush pending
// output.
void fast() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
}

// Lazy segment tree over positions 0..n-1 supporting
//   update(l, r, x): add x to every position in [l, r]
//   query(l, r):     sum of the positions in [l, r]
// Nodes use 0-based heap numbering: the children of node ind are 2*ind+1 and
// 2*ind+2. lazy[ind] is an addition owed to every position of node ind's
// segment that has not yet been folded into tree[ind].
class ST1 {
   public:
    int n;
    std::vector<int> lazy, v, tree;

    // Fills tree[ind] for the segment [low, high] from the initial values v.
    void build(int low, int high, int ind) {
        if (low == high) {
            tree[ind] = v[low];
            return;
        }
        int mid = (low + high) >> 1;

        build(low, mid, 2 * ind + 1);
        build(mid + 1, high, 2 * ind + 2);

        tree[ind] = tree[2 * ind + 1] + tree[2 * ind + 2];
    }

    // Builds the tree over a copy of v. build() cannot handle an empty range,
    // so an empty v yields an empty tree on which update() does nothing and
    // query() returns 0.
    ST1(const std::vector<int> &v) {
        this->v = v;
        this->n = v.size();
        lazy.resize(4 * n, 0);
        tree.resize(4 * n);
        if (n > 0) this->build(0, n - 1, 0);
    }

    // Folds the pending addition of node ind (segment [low, high]) into its
    // sum and passes it on to the children.
    void push(int low, int high, int ind) {
        if (!lazy[ind]) return;
        tree[ind] += (high - low + 1) * lazy[ind];
        if (low != high) {
            lazy[2 * ind + 1] += lazy[ind];
            lazy[2 * ind + 2] += lazy[ind];
        }
        lazy[ind] = 0;
    }

    // Adds value to the positions of [l, r] that lie in node ind = [low, high].
    void update(int low, int high, int l, int r, int value, int ind) {
        push(low, high, ind);
        if (low > r || high < l) return;
        if (low >= l && high <= r) {
            // Fully covered: apply the addition here and defer it to children.
            lazy[ind] += value;
            push(low, high, ind);
            return;
        }
        int mid = (low + high) >> 1;
        update(low, mid, l, r, value, 2 * ind + 1);
        update(mid + 1, high, l, r, value, 2 * ind + 2);
        tree[ind] = tree[2 * ind + 1] + tree[2 * ind + 2];
    }

    // Sum of the positions of [l, r] that lie in node ind = [low, high].
    int query(int low, int high, int l, int r, int ind) {
        push(low, high, ind);
        if (low > r || high < l) return 0;
        if (low >= l && high <= r) return tree[ind];
        int mid = (low + high) >> 1;
        return query(low, mid, l, r, 2 * ind + 1) +
               query(mid + 1, high, l, r, 2 * ind + 2);
    }

    void update(int l, int r, int x) {
        if (n == 0) return;
        update(0, n - 1, l, r, x, 0);
    }
    int query(int l, int r) {
        if (n == 0) return 0;
        return query(0, n - 1, l, r, 0);
    }
};

// Lazy segment tree over positions 0..n-1 that keeps a coefficient c[i] per
// position (initially v[i]) and answers sums of i * c[i]:
//   update(l, r, x): c[i] += x for every i in [l, r], i.e. adds x * i there
//   query(l, r):     sum of i * c[i] over i in [l, r]
// The layout matches a plain range-add tree (the children of node ind are
// 2*ind+1 and 2*ind+2). The only difference is the weight of a pending
// addition x on a node covering [low, high]: it is x * (low + ... + high)
// instead of x * (high - low + 1).
// NOTE(review): node sums grow roughly like x * n^2 / 2 per update. `int` is
// `long long` in this file; confirm that the problem limits keep sums in range.
class ST2 {
   public:
    int n;
    std::vector<int> lazy, v, tree;

    // low + (low + 1) + ... + high. Either the length or (low + high) is even,
    // so the division by 2 is exact.
    static int index_sum(int low, int high) {
        int sz = high - low + 1;
        return (sz * (high + low)) / 2;
    }

    // Fills tree[ind] for the segment [low, high] with the sum of i * v[i].
    void build(int low, int high, int ind) {
        if (low == high) {
            tree[ind] = low * v[low];
            return;
        }
        int mid = (low + high) >> 1;

        build(low, mid, 2 * ind + 1);
        build(mid + 1, high, 2 * ind + 2);

        tree[ind] = tree[2 * ind + 1] + tree[2 * ind + 2];
    }

    // Builds the tree over a copy of v. build() cannot handle an empty range,
    // so an empty v yields an empty tree on which update() does nothing and
    // query() returns 0.
    ST2(const std::vector<int> &v) {
        this->v = v;
        this->n = v.size();
        lazy.resize(4 * n, 0);
        tree.resize(4 * n);
        if (n > 0) this->build(0, n - 1, 0);
    }

    // Folds the pending coefficient addition of node ind (segment
    // [low, high]) into its sum and passes it on to the children.
    void push(int low, int high, int ind) {
        if (!lazy[ind]) return;
        tree[ind] += lazy[ind] * index_sum(low, high);
        if (low != high) {
            lazy[2 * ind + 1] += lazy[ind];
            lazy[2 * ind + 2] += lazy[ind];
        }
        lazy[ind] = 0;
    }

    // Adds value to c[i] for the i of [l, r] that lie in node ind = [low, high].
    void update(int low, int high, int l, int r, int value, int ind) {
        push(low, high, ind);
        if (low > r || high < l) return;
        if (low >= l && high <= r) {
            // Fully covered: apply the addition here and defer it to children.
            lazy[ind] += value;
            push(low, high, ind);
            return;
        }
        int mid = (low + high) >> 1;
        update(low, mid, l, r, value, 2 * ind + 1);
        update(mid + 1, high, l, r, value, 2 * ind + 2);
        tree[ind] = tree[2 * ind + 1] + tree[2 * ind + 2];
    }

    // Sum of i * c[i] for the i of [l, r] that lie in node ind = [low, high].
    int query(int low, int high, int l, int r, int ind) {
        push(low, high, ind);
        if (low > r || high < l) return 0;
        if (low >= l && high <= r) return tree[ind];
        int mid = (low + high) >> 1;
        return query(low, mid, l, r, 2 * ind + 1) +
               query(mid + 1, high, l, r, 2 * ind + 2);
    }

    void update(int l, int r, int x) {
        if (n == 0) return;
        update(0, n - 1, l, r, x, 0);
    }
    int query(int l, int r) {
        if (n == 0) return 0;
        return query(0, n - 1, l, r, 0);
    }
};

void solve() {
    int n, q;
    cin >> n >> q;

    vector<int> v(n + 1);

    ST1 st1(v);
    ST2 st2(v);

    while (q--) {
        int choice;
        cin >> choice;

        if (choice == 1) {
            int l, r, a;
            cin >> l >> r >> a;
            st1.update(l, r, (1 - l) * a);
            st2.update(l, r, a);
        } else {
            int l, r;
            cin >> l >> r;
            cout << st1.query(l, r) + st2.query(l, r) << endl;
        }
    }
}

// Entry point. It is declared `signed` because `int` is macro-expanded to
// `long long` in this file. It runs a single test case.
signed main() {
    fast();
    int tests = 1;
    // cin >> tests;  // enable for multi-test input
    for (; tests > 0; --tests) solve();
    return 0;
}