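// 树上差分 — difference arrays on a tree (template problem).
// Note: vertex updates and edge updates are differenced slightly differently. For a path u–v
// whose LCA is t, a vertex update adds k at u and v and subtracts k at t and at t's parent;
// an edge update (each edge's value is kept at its deeper endpoint) adds k at u and v and
// subtracts 2k at t. One subtree-sum pass then recovers every vertex value and edge value.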
#include <iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<string>
#include<bitset>
#include<cmath>
using namespace std;
// Capacity for vertex-indexed arrays (vertices are numbered 1..n).
const int N = 1e5 + 10;
// Highest binary-lifting level used; 2^up must exceed the maximum depth difference.
const int up = 18;
// Lowest set bit of x. The argument is parenthesized so expressions such as
// lowbit(a + b) expand correctly (the old form `(x&(-x))` did not).
#define lowbit(x) ((x) & -(x))
typedef long long ll;

// Adjacency-list edge (chained forward star): target vertex and index of the next
// edge leaving the same source, or -1 at the end of the chain.
struct Edge{
    int to, next;
}edge[N << 1];   // each undirected edge is stored twice, at indices 2i and 2i+1

// head[v]: first edge index leaving v (-1 if none); tot: number of stored edges;
// dep[v]: depth of v (root has depth 1).
int head[N], tot, dep[N];
// ev[v]: edge-update difference values, later summed over subtrees (the edge to v's parent);
// nv[v]: vertex-update difference values, later summed over subtrees.
ll ev[N], nv[N];
// fat[v][j]: the 2^j-th ancestor of v, with 0 standing for "above the root".
// Sized from `up` so the table always covers every level that dp()/lca() touch.
int fat[N][up + 1];
int n, m, k;
void addedge(int from, int to)
{
edge[tot].to = to;
edge[tot].next = head[from];
head[from] = tot++;
}
void dfs(int u, int fa)
{
for (int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa) continue;
dep[v] = dep[u] + 1;
fat[v][0] = u;
dfs(v, u);
}
}
void Dfs(int u, int fa)
{
for (int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa) continue;
Dfs(v, u);
ev[u] += ev[v];
nv[u] += nv[v];
}
}
// Build the binary-lifting table for vertices 1..n from the parent column fat[][0]:
// jumping 2^lvl levels is two consecutive jumps of 2^(lvl-1) levels.
void dp()
{
    for (int lvl = 1; lvl <= up; ++lvl)
        for (int node = 1; node <= n; ++node)
        {
            int mid = fat[node][lvl - 1];
            fat[node][lvl] = fat[mid][lvl - 1];
        }
}
// Lowest common ancestor of u and v, using dep[] and the binary-lifting table fat[][].
int lca(int u, int v)
{
    // Make u the deeper endpoint.
    if (dep[u] < dep[v]) swap(u, v);

    // Raise u to v's depth by decomposing the depth gap into powers of two.
    const int gap = dep[u] - dep[v];
    for (int j = up; j >= 0; --j)
        if (gap & (1 << j)) u = fat[u][j];

    if (u == v) return u;

    // Raise both while their 2^j-th ancestors differ; they end up as the two
    // children of the LCA on the u and v sides.
    for (int j = up; j >= 0; --j)
    {
        const int pu = fat[u][j];
        const int pv = fat[v][j];
        if (pu != pv)
        {
            u = pu;
            v = pv;
        }
    }
    return fat[u][0];
}
// Reset per-test-case state: clear both difference arrays, empty every adjacency
// list (head = -1) and restart the edge counter.
void init()
{
    memset(ev, 0, sizeof ev);
    memset(nv, 0, sizeof nv);
    memset(head, -1, sizeof head);
    tot = 0;
}
// Vertex-weighted path update: add k to every vertex on the path u..v,
// inclusive of both endpoints and their LCA, using difference marks in nv[].
// +k at each endpoint; -k at the LCA (it collects both endpoints' +k but must
// count once); -k at the LCA's parent to stop the effect above the path.
// When one endpoint is the LCA, its +k and -k cancel, which gives the same
// marks as handling that case separately.
void add1(int u, int v, int k)
{
    const int t = lca(u, v);
    nv[u] += k;
    nv[v] += k;
    nv[t] -= k;
    nv[fat[t][0]] -= k;
}
// Edge-weighted path update: add k to every edge on the path u..v. Each edge's
// value lives at its deeper endpoint, so the marks are +k at u and v and -2k at the LCA
// (the edges above the LCA are not on the path).
void add2(int u, int v, int k)
{
    int t = lca(u, v);
    // Compute 2*k in long long: ev[] is ll, but `2 * k` in int could overflow
    // for large |k| before being widened.
    ev[t] -= 2LL * k;
    ev[u] += k;
    ev[v] += k;
}
// Read test cases. Each has a tree of n vertices and m path updates. After
// applying every update, print each vertex's value and then each input edge's
// value (in input order).
int main()
{
    int t;
    scanf("%d", &t);
    for (int j = 1; j <= t; j++)
    {
        scanf("%d%d", &n, &m);
        init();
        // Input edge i (0-based) is stored as the pair edge[2i], edge[2i+1].
        for (int i = 1; i < n; i++)
        {
            int u, v;
            scanf("%d%d", &u, &v);
            addedge(u, v);
            addedge(v, u);
        }
        // Root the tree at vertex 1; fat[1][0] = 0 marks "no parent"
        // (index 0 is not a real vertex).
        dep[1] = 1;
        fat[1][0] = 0;
        dfs(1, -1);
        dp();
        // Zero-initialized so op[3] is never read uninitialized, even after a short token.
        char op[10] = "";
        int u, v, k;
        for (int i = 1; i <= m; i++)
        {
            // The width 9 bounds the read so an over-long token cannot overflow op[10].
            scanf("%9s%d%d%d", op, &u, &v, &k);
            // The 4th character of the operation name selects the update kind
            // ('1' -> vertex update, otherwise edge update); presumably "ADD1"/"ADD2"
            // — confirm against the problem statement.
            if (op[3] == '1'){
                add1(u, v, k);
            }
            else{
                add2(u, v, k);
            }
        }
        // Subtree sums turn the difference marks into actual values.
        Dfs(1, -1);
        printf("Case #%d:\n", j);
        for (int i = 1; i <= n; i++)
        {
            if (i - 1) putchar(' ');
            printf("%lld", nv[i]);
        }
        puts("");
        for (int i = 0; i < n - 1; i++)
        {
            // An edge's value is stored at its deeper endpoint.
            u = edge[i << 1].to, v = edge[i << 1 | 1].to;
            if (dep[u] < dep[v]) swap(u, v);
            if (i) putchar(' ');
            printf("%lld", ev[u]);
        }
        puts("");
    }
    return 0;
}
// 评论 ("Comments") — presumably a leftover footer from the web page this solution was posted on.