test/7B.py
2026-07-01 19:46:51 +03:00

57 lines
No EOL
1.3 KiB
Python

import sys
def f_lca(u, v_node, depth=None, up_table=None):
    """Return the lowest common ancestor of nodes ``u`` and ``v_node``.

    Uses a binary-lifting table: ``up_table[x][k]`` is the 2**k-th ancestor
    of ``x``, with 0 standing for "above the root" (row 0 is all zeros).
    The number of lifting levels is taken from the table itself, so it works
    with a table of any width.

    Args:
        u, v_node: 1-based node ids.
        depth: per-node depth list (root has the smallest depth). Defaults to
            the legacy module-level ``d`` for two-argument callers.
        up_table: binary-lifting table as described above. Defaults to the
            legacy module-level ``up`` for two-argument callers.

    Returns:
        The id of the lowest common ancestor.
    """
    if depth is None:
        depth = d  # legacy global-state calling convention
    if up_table is None:
        up_table = up
    # Make u the deeper node, then lift it to v_node's depth bit by bit.
    if depth[u] < depth[v_node]:
        u, v_node = v_node, u
    diff = depth[u] - depth[v_node]
    k = 0
    while diff:
        if diff & 1:
            u = up_table[u][k]
        diff >>= 1
        k += 1
    if u == v_node:
        return u
    # Lift both nodes as high as possible while they stay on different
    # branches; afterwards their common parent is the LCA.
    for k in range(len(up_table[u]) - 1, -1, -1):
        if up_table[u][k] != up_table[v_node][k]:
            u = up_table[u][k]
            v_node = up_table[v_node][k]
    return up_table[u][0]


def solve(data):
    """Answer path-sum queries for the tree encoded in the token list ``data``.

    Token layout: n, then the values of nodes 1..n, then n-1 undirected edges
    as pairs, then m, then m query pairs (x, y). Each answer is the sum of
    node values on the path from x to y, both endpoints included. The tree
    is rooted at node 1.

    Returns:
        The answers, one per line, joined with "\\n" (empty string if m == 0).
    """
    n = int(data[0])
    v = [0] + [int(x) for x in data[1 : n + 1]]
    g = [[] for _ in range(n + 1)]
    idx = n + 1
    for _ in range(n - 1):
        a, b = int(data[idx]), int(data[idx + 1])
        g[a].append(b)
        g[b].append(a)
        idx += 2
    m = int(data[idx])
    idx += 1

    depth = [0] * (n + 1)
    prefix = [0] * (n + 1)  # prefix[x] = sum of values on the root..x path
    par = [0] * (n + 1)  # par[root] = 0 and prefix[0] = 0 act as sentinels
    depth[1] = 1
    prefix[1] = v[1]
    order = [1]
    # Appending to a list while iterating it visits the new items too,
    # so this loop is a BFS from the root.
    for u in order:
        for w in g[u]:
            if w != par[u]:
                par[w] = u
                depth[w] = depth[u] + 1
                prefix[w] = prefix[u] + v[w]
                order.append(w)

    # A depth difference is at most n - 1 < 2**n.bit_length(), so this many
    # levels always suffices (the old fixed 18 broke for depths >= 2**18).
    log = max(1, n.bit_length())
    up_table = [[0] * log for _ in range(n + 1)]
    for i in range(1, n + 1):
        up_table[i][0] = par[i]
    for k in range(1, log):
        for i in range(1, n + 1):
            up_table[i][k] = up_table[up_table[i][k - 1]][k - 1]

    out = []
    for _ in range(m):
        x, y = int(data[idx]), int(data[idx + 1])
        idx += 2
        a = f_lca(x, y, depth, up_table)
        # root..x + root..y counts the root..a path twice; subtract root..a
        # once and root..par(a) once so that a itself is counted exactly once.
        out.append(str(prefix[x] + prefix[y] - prefix[a] - prefix[par[a]]))
    return "\n".join(out)


if __name__ == "__main__":
    print(solve(sys.stdin.read().split()))