LCA+线段树

由于至多只有1个出现奇数次的树,那我们可以用所有数的异或来的得到这个出现奇数次数,用线段树维护到根节点的异或和,询问的时候询问节点,LCA在这个过程中被异或了两次,再异或一下LCA消除影响,注意:询问的是单个节点,即线段树的叶子节点。

还有就是原数字可能出现0,我们统一+1处理,询问的时候再减一,显然这样不影响答案。

#include <iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<string>
#include<bitset>
#include<cmath>
#include<map>
using namespace std;
const int N = 1e5 + 10;
const int up = 18;
#define lowbit(x) (x&(-x))
typedef long long ll;
struct Edge{
    int to, next;
}edge[N << 1];
int head[N], tot, dep[N],a[N],in[N],out[N],top;
int fat[N][20];
int c[N << 2], add[N << 2];
int n, m;
void addedge(int from, int to)
{
    edge[tot].to = to;
    edge[tot].next = head[from];
    head[from] = tot++;
}
void dfs(int u, int fa)
{
    in[u] = ++top;
    for (int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].to;
        if (v == fa) continue;
        dep[v] = dep[u] + 1;
        fat[v][0] = u;
        dfs(v, u);
    }
    out[u] = top;
}
void pushdown(int rt)
{
    if (c[rt]){
        c[rt << 1] ^= c[rt];
        c[rt <<1|1] ^= c[rt];
        c[rt] = 0;
    }
}
void change(int rt, int l, int r, int x, int y, int val)
{
    if (x <= l&&r <= y){
        c[rt] ^= val;
        return;
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    if (x <= mid) change(rt << 1, l, mid, x, y, val);
    if (y > mid) change(rt << 1 | 1, mid + 1, r, x, y, val);
}
int query(int rt, int l, int r, int pos)
{
    if (l==r){
        return c[rt];
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    int ans1 = 0,ans2=0;
    if (pos<= mid) return  query(rt << 1, l, mid, pos);
    else return query(rt << 1 | 1, mid + 1, r,pos);
    return ans1^ans2;
}
void dp()
{
    for (int j = 1; j <= up; j++)
    {
        for (int i = 1; i <= n; i++)
        {
            fat[i][j] = fat[fat[i][j - 1]][j - 1];
        }
    }
}
int lca(int u, int v)
{
    if (dep[u] < dep[v]) swap(u, v);
    for (int j = up; j >= 0; j--)
    {
        if ((dep[u] - dep[v]) >> j & 1){
            u = fat[u][j];
        }
    }
    if (u == v) return u;
    for (int j = up; j >= 0; j--)
    {
        if (fat[u][j] != fat[v][j]){
            u = fat[u][j];
            v = fat[v][j];
        }
    }
    return fat[u][0];
}
void init()
{
    tot = 0;
    top = 0;
    memset(head, -1, sizeof(head));
    memset(c, 0, sizeof(c));

}
void solve()
{
    dep[1] = 1;
    dfs(1, -1);
    dp();
}
int main()
{
    int t;
    scanf("%d", &t);
    for (int j = 1; j <= t; j++)
    {
        scanf("%d%d", &n, &m);
        init();
        for (int i = 1; i < n; i++)
        {
            int u ,v ;
            scanf("%d%d", &u, &v);
            addedge(u, v);
            addedge(v, u);
        }
        solve();
        for (int i = 1; i <= n; i++){
            scanf("%d", &a[i]);
            a[i]++;
            change(1, 1, n, in[i], out[i], a[i]);
        }
        while (m--)
        {
            int op, x, y;
            scanf("%d%d%d", &op, &x, &y);
            if (op == 1){
                int Lca = lca(x, y);
                int ans = query(1, 1, n, in[x]);
                ans ^= query(1, 1, n, in[y]);
                ans ^= a[Lca];
                printf("%d\n", ans - 1);
            }
            else{
                y++;
                change(1, 1, n, in[x], out[x], a[x] ^ y);
                a[x] = y;
            }
        }
        
    }
    return 0;
}