树上差分

树上差分模板题。注意：点差分与边差分的处理方式略有不同——点差分在 u、v 处各加 k，在 LCA 处减 k，并在 LCA 的父节点处再减 k；边差分把每条边的值记在其较深的端点上，因此在 u、v 处各加 k，只需在 LCA 处减 2k。

#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// Global limits and storage shared by every routine in this file.
const int N = 1e5 + 10;   // array capacity indexed by vertex id (presumably max n = 1e5 plus slack)
const int up = 18;        // highest binary-lifting level used; 2^18 > 1e5, so any depth gap fits
// Lowest set bit of x. The argument is fully parenthesized so that calls like
// lowbit(a + b) expand correctly; the unparenthesized form expanded to
// (a + b & (-a + b)), which is wrong.
#define lowbit(x) ((x) & (-(x)))
typedef long long ll;
// Adjacency-list arc: target vertex and index of the next arc leaving the
// same source vertex (-1 terminates the list).
struct Edge{
    int to, next;
}edge[N << 1];            // two directed arcs per undirected edge
int head[N], tot, dep[N]; // head[v]: first arc out of v; tot: arcs allocated; dep[v]: depth (root has depth 1)
ll ev[N], nv[N];          // edge / vertex difference arrays; hold final totals after Dfs
int fat[N][20];           // fat[v][j]: 2^j-th ancestor of v (0 above the root); columns 0..up are used, so up must stay < 20
int n, m, k;              // vertex count, operation count; k is unused (every use is a shadowing local/parameter)
void addedge(int from, int to)
{
    edge[tot].to = to;
    edge[tot].next = head[from];
    head[from] = tot++;
}
// Fills dep[] and fat[][0] (immediate parent) for every vertex reachable from
// u, treating `fa` as u's parent (pass -1 when u is the root). dep[u] itself
// must already be set by the caller.
//
// Despite the name, this is an iterative breadth-first traversal. The
// recursive version needed one stack frame per tree level, and a path-shaped
// tree with ~1e5 vertices can overflow the default stack on many judges. The
// resulting dep[]/fat[][0] values are the same, because they depend only on
// the tree shape, not on the visiting order.
void dfs(int u, int fa)
{
    // Each entry is (vertex, parent of that vertex); the vector acts as a FIFO queue.
    std::vector<std::pair<int, int> > q;
    q.push_back(std::make_pair(u, fa));
    for (size_t front = 0; front < q.size(); ++front)
    {
        // Copy out before push_back below, which may reallocate the vector.
        int x = q[front].first, p = q[front].second;
        for (int i = head[x]; i != -1; i = edge[i].next)
        {
            int v = edge[i].to;
            if (v == p) continue; // don't walk back up to the parent
            dep[v] = dep[x] + 1;
            fat[v][0] = x;
            q.push_back(std::make_pair(v, x));
        }
    }
}
void Dfs(int u, int fa)
{
    for (int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        if (v == fa) continue;
        Dfs(v, u);
        ev[u] += ev[v];
        nv[u] += nv[v];
    }
}
// Builds the binary-lifting table for vertices 1..n from the parent column
// fat[][0]: the 2^lvl-th ancestor is the 2^(lvl-1)-th ancestor of the
// 2^(lvl-1)-th ancestor. Levels are filled in increasing order so that each
// level reads only the finished previous one. Row 0 is never written here and
// serves as the sentinel "ancestor" above the root.
void dp()
{
    int lvl = 1;
    while (lvl <= up)
    {
        for (int node = 1; node <= n; ++node)
        {
            int mid = fat[node][lvl - 1];
            fat[node][lvl] = fat[mid][lvl - 1];
        }
        ++lvl;
    }
}
// Lowest common ancestor of u and v via binary lifting.
// Requires dep[] and the full fat[][] table (built by dfs() then dp()).
int lca(int u, int v)
{
    // Arrange for u to be the deeper endpoint.
    if (dep[u] < dep[v]) swap(u, v);

    // Raise u by exactly the depth gap, one set bit of the gap at a time.
    const int gap = dep[u] - dep[v];
    for (int j = up; j >= 0; j--)
        if (gap & (1 << j))
            u = fat[u][j];

    if (u == v) return u;

    // Jump both endpoints as high as possible while they stay distinct;
    // they then end up as two children of the answer.
    for (int j = up; j >= 0; j--)
    {
        const int au = fat[u][j];
        const int av = fat[v][j];
        if (au != av)
        {
            u = au;
            v = av;
        }
    }
    return fat[u][0];
}
// Per-test-case reset: both difference arrays are zeroed, every adjacency
// list is emptied, and arc numbering restarts at 0.
void init()
{
    memset(nv, 0, sizeof nv);
    memset(ev, 0, sizeof ev);
    // Every byte 0xFF makes every int -1, the end-of-list marker.
    memset(head, 0xff, sizeof head);
    tot = 0;
}
// Vertex-difference update: after Dfs's subtree summation, every vertex on
// the path u..v (both endpoints and their LCA included) has gained k.
// The markers at u and v cover the two branches. The LCA would be counted
// twice, so one k is removed there; the LCA's parent removes the last k so
// that nothing above the path changes. When u or v is itself the LCA, the
// +k and -k at that vertex cancel, which matches the single-branch case.
// fat[root][0] is 0, so a root LCA writes to the unused slot nv[0].
void add1(int u, int v, int k)
{
    const int top = lca(u, v);
    nv[u] += k;
    nv[v] += k;
    nv[top] -= k;
    nv[fat[top][0]] -= k;
}
// Edge-difference update: each edge's value lives at its deeper endpoint.
// After Dfs's subtree summation, every edge on the path u..v has gained k.
// The LCA receives -2k because both branches' markers flow into it, and the
// edge above the LCA is not on the path.
void add2(int u, int v, int k)
{
    int meet = lca(u, v);
    ev[u] += k;
    ev[v] += k;
    ev[meet] -= k + k;
}
// Reads t test cases. Each has n vertices, the n-1 tree edges, then m
// operations of the form "<op> u v k": op with '1' as its 4th character is a
// vertex-path add, otherwise an edge-path add. For each case it prints the
// final value of every vertex, then of every edge in input order.
int main()
{
    int t;
    if (scanf("%d", &t) != 1) return 0; // no input: nothing to do
    for (int j = 1; j <= t; j++)
    {
        // Stop cleanly on truncated input instead of looping on garbage values.
        if (scanf("%d%d", &n, &m) != 2) return 0;
        init();
        for (int i = 1; i < n; i++)
        {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;
            // Added back-to-back: arcs 2(i-1) and 2(i-1)+1 are the two halves
            // of input edge i-1. The edge output loop below relies on this.
            addedge(u, v);
            addedge(v, u);
        }
        dep[1] = 1;
        fat[1][0] = 0; // vertex 0 is the sentinel parent of the root
        dfs(1, -1);
        dp();
        char op[10];
        int u, v, k;
        for (int i = 1; i <= m; i++)
        {
            // %9s bounds the write to op[] (9 chars + terminating NUL).
            if (scanf("%9s%d%d%d", op, &u, &v, &k) != 4) return 0;
            // Check the length first so op[3] is never read past the token's end.
            if (strlen(op) > 3 && op[3] == '1'){
                add1(u, v, k);
            }
            else{
                add2(u, v, k);
            }
        }
        Dfs(1, -1); // turn difference markers into final totals
        printf("Case #%d:\n", j);
        for (int i = 1; i <= n; i++)
        {
            if (i > 1) putchar(' ');
            printf("%lld", nv[i]);
        }
        puts("");
        for (int i = 0; i < n - 1; i++)
        {
            // An edge's total is stored at its deeper endpoint.
            u = edge[i << 1].to, v = edge[i << 1 | 1].to;
            if (dep[u] < dep[v]) swap(u, v);
            if (i) putchar(' ');
            printf("%lld", ev[u]);
        }
        puts("");
    }
    return 0;
}