树链剖分

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
using p = std::pair<int, int>;

// Capacities: node-indexed arrays hold maxn slots; the edge pool holds maxm
// slots because main() stores every undirected tree edge twice (u->v, v->u).
constexpr int maxn = 100000 + 10;
constexpr int maxm = 200000 + 10;

int mod;          // modulus applied to every stored sum (read from input)
int ecnt;         // number of edge records handed out by addEdge so far
int v[maxn];      // initial node values, 1-indexed by node id
int head[maxn];   // first edge index of each node's adjacency list (-1 = empty)

// Filled by dfs1: depth, subtree size, parent and heavy child of each node.
int dep[maxn];
int siz[maxn];
int fa[maxn];
int son[maxn];

// Filled by dfs2: timestamp counter, dfs position of each node, top node of
// its heavy chain, and node values laid out in dfs-position order.
int tim;
int dfn[maxn];
int top[maxn];
int w[maxn];

// Adjacency-list edge record: target node plus index of the next edge that
// leaves the same source node.
struct edge {
    int to;
    int nxt;
};
edge edges[maxm];

// Segment-tree node over dfs positions [l, r].
struct node {
    int l;
    int r;
    int sum;  // aggregated sum of the covered positions
    int lz;   // pending addition not yet propagated to the children
};
node tree[maxn << 2];

template<typename T = int>
inline const T read()
{
    // Read the next decimal integer from stdin.
    // Any non-digit characters before the number are skipped; a '-' among
    // them makes the result negative. Returns 0 if EOF is hit before a digit
    // (the original spun forever there because `char` cannot hold EOF).
    T x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    // Accumulate with the sign already applied so the most negative value
    // of T (e.g. -2147483648 for int) is representable without overflow.
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + f * (ch - '0');
        ch = getchar();
    }
    return x;
}

template<typename T>
inline void write(T x, bool ln)
{
    // Print integer x in decimal to stdout, followed by '\n' when ln is true.
    // Digits are peeled off without negating x, so the minimum value of a
    // signed T (where -x overflows) prints correctly.
    char buf[48];
    int len = 0;
    const bool neg = x < 0;
    do {
        // For negative x, x % 10 lies in [-9, 0] (truncating division).
        const int d = static_cast<int>(x % 10);
        buf[len++] = static_cast<char>('0' + (neg ? -d : d));
        x /= 10;
    } while (x != 0);
    if (neg) putchar('-');
    while (len > 0) putchar(buf[--len]);
    if (ln) putchar('\n');
}

inline void addEdge(int u, int v)
{
    // Append a directed edge u -> v by taking the next free slot of the edge
    // pool and linking it in front of u's adjacency list.
    const int slot = ecnt;
    ++ecnt;
    edges[slot].nxt = head[u];
    edges[slot].to = v;
    head[u] = slot;
}

inline int ls(int cur)
{
    // Heap-style layout: the left child of node cur is stored at 2 * cur.
    const int left = cur + cur;
    return left;
}

inline int rs(int cur)
{
    // Heap-style layout: the right child of node cur is stored at 2 * cur + 1.
    const int right = cur + cur + 1;
    return right;
}

inline void push_up(int cur)
{
    tree[cur].sum = tree[ls(cur)].sum + tree[rs(cur)].sum;
}

inline void push_down(int cur)
{
    // Propagate node cur's pending addition to both children, then clear it.
    const long long tag = tree[cur].lz;
    if (tag == 0) return;
    // Apply the tag to one child: accumulate its lazy tag and add
    // (number of covered positions) * tag to its sum. The product is formed
    // in 64 bits; in int it overflows once length * tag exceeds INT_MAX.
    auto apply = [&](int child) {
        const long long len = tree[child].r - tree[child].l + 1;
        tree[child].lz = static_cast<int>((tree[child].lz + tag) % mod);
        tree[child].sum = static_cast<int>((tree[child].sum + len * tag) % mod);
    };
    apply(ls(cur));
    apply(rs(cur));
    tree[cur].lz = 0;
}

void build(int cur, int l, int r)
{
    tree[cur].l = l;
    tree[cur].r = r;
    if (l == r) {
        tree[cur].sum = w[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(ls(cur), l, mid);
    build(rs(cur), mid + 1, r);
    push_up(cur);
}

void update(int cur, int l, int r, int v)
{
    // Add v to every dfs position in [l, r], which must lie inside node cur's
    // range; the recursion below only ever passes exact sub-ranges.
    if (tree[cur].l == l and tree[cur].r == r) {
        // Normalise v into [0, mod) and form len * add in 64 bits: with a
        // large mod, (r - l + 1) * v overflows int.
        const long long add = (static_cast<long long>(v) % mod + mod) % mod;
        const long long len = r - l + 1;
        tree[cur].sum = static_cast<int>((tree[cur].sum + len * add) % mod);
        tree[cur].lz = static_cast<int>((tree[cur].lz + add) % mod);
        return;
    }
    push_down(cur);
    const int mid = (tree[cur].l + tree[cur].r) >> 1;
    if (r <= mid) {
        update(ls(cur), l, r, v);
    } else if (l > mid) {
        update(rs(cur), l, r, v);
    } else {
        // The range straddles mid: split it into two exact sub-ranges.
        update(ls(cur), l, mid, v);
        update(rs(cur), mid + 1, r, v);
    }
    push_up(cur);
}

int query(int cur, int l, int r)
{
    if (tree[cur].l == l and tree[cur].r == r) {
        return tree[cur].sum;
    }
    push_down(cur);
    int mid = (tree[cur].l + tree[cur].r) >> 1;
    if (r <= mid) {
        return query(ls(cur), l, r);
    }
    if (l > mid) {
        return query(rs(cur), l, r);
    }
    return (query(ls(cur), l, mid) + query(rs(cur), mid + 1, r)) % mod;
}

void dfs1(int cur, int pre)
{
    // First decomposition pass, rooted at the caller's choice of root.
    // Records parent, depth and subtree size of cur. Picks as heavy child
    // (son) the first child with the largest subtree.
    fa[cur] = pre;
    dep[cur] = dep[pre] + 1;
    siz[cur] = 1;
    int heaviest = -1;
    for (int e = head[cur]; e != -1; e = edges[e].nxt) {
        const int child = edges[e].to;
        if (child != pre) {
            dfs1(child, cur);
            siz[cur] += siz[child];
            // Strict comparison keeps the earliest child on ties.
            if (heaviest < siz[child]) {
                heaviest = siz[child];
                son[cur] = child;
            }
        }
    }
}

void dfs2(int cur, int tp)
{
    // Second decomposition pass. Assigns pre-order dfs positions, visiting
    // the heavy child first so every heavy chain occupies consecutive
    // positions. Records the chain top of each node and copies node values
    // into w[] in dfs-position order.
    top[cur] = tp;
    dfn[cur] = ++tim;
    w[dfn[cur]] = v[cur];
    const int heavy = son[cur];
    if (heavy == 0) {
        return;  // leaf: no heavy child recorded by dfs1
    }
    dfs2(heavy, tp);  // heavy child continues the current chain
    for (int e = head[cur]; e != -1; e = edges[e].nxt) {
        const int child = edges[e].to;
        if (child == fa[cur] || child == heavy) {
            continue;
        }
        dfs2(child, child);  // each light child starts a new chain
    }
}

void update_chain(int u, int v, int val)
{
    // Add val to every node on the tree path between u and v.
    // Normalise val into [0, mod) in 64 bits: a plain `val %= mod` leaves
    // negative inputs negative, which would push stored sums below zero.
    val = static_cast<int>((static_cast<long long>(val) % mod + mod) % mod);
    // Lift the endpoint whose chain top is deeper, one whole chain at a time,
    // until both endpoints share a chain.
    while (top[u] not_eq top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            swap(u, v);
        }
        update(1, dfn[top[u]], dfn[u], val);
        u = fa[top[u]];
    }
    // Same chain: the shallower node has the smaller dfs position.
    if (dep[u] > dep[v]) {
        swap(u, v);
    }
    update(1, dfn[u], dfn[v], val);
}

int query_chain(int u, int v)
{
    // Return the sum of node values on the tree path between u and v,
    // modulo mod. The running total is 64-bit and reduced after every chain:
    // adding several results below mod in an int can overflow.
    long long res = 0;
    while (top[u] not_eq top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            swap(u, v);
        }
        res = (res + query(1, dfn[top[u]], dfn[u])) % mod;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) {
        swap(u, v);
    }
    res = (res + query(1, dfn[u], dfn[v])) % mod;
    return static_cast<int>(res);
}

inline void update_son(int cur, int val)
{
    update(1, dfn[cur], dfn[cur] + siz[cur] - 1, val);
}

inline int query_son(int cur)
{
    return query(1, dfn[cur], dfn[cur] + siz[cur] - 1);
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
#endif
    // -1 marks an empty adjacency list for every node.
    memset(head, -1, sizeof head);

    // Header: node count, operation count, root, modulus.
    const int n = read();
    int m = read();
    const int r = read();
    mod = read();

    for (int id = 1; id <= n; ++id) {
        v[id] = read();
    }
    // n - 1 undirected edges, each stored in both directions.
    for (int k = 1; k < n; ++k) {
        const int a = read();
        const int b = read();
        addEdge(a, b);
        addEdge(b, a);
    }

    dfs1(r, r);
    dfs2(r, r);
    build(1, 1, n);

    // Operations: 1 = path add, 2 = path sum, 3 = subtree add,
    // anything else = subtree sum.
    while (m--) {
        const int op = read();
        switch (op) {
        case 1: {
            const int x = read();
            const int y = read();
            const int z = read();
            update_chain(x, y, z);
            break;
        }
        case 2: {
            const int x = read();
            const int y = read();
            write(query_chain(x, y), true);
            break;
        }
        case 3: {
            const int x = read();
            const int z = read();
            update_son(x, z);
            break;
        }
        default: {
            const int x = read();
            write(query_son(x), true);
            break;
        }
        }
    }
    return 0;
}

最后更新于