test/9E.py
2026-07-01 19:46:51 +03:00

105 lines
No EOL
2.5 KiB
Python

import sys
# Slurp all of standard input and split it on whitespace into tokens.
input_data = sys.stdin.read().split()

# The first two tokens give the lengths of the two sequences that follow.
n, k = int(input_data[0]), int(input_data[1])

# `idx` is a cursor pointing at the next unread token.
idx = 2
a = list(map(int, input_data[idx:idx + n]))
idx += n
b = list(map(int, input_data[idx:idx + k]))
idx += k

# Next comes the number of queries, then the query values themselves.
q = int(input_data[idx])
idx += 1
queries = list(map(int, input_data[idx:idx + q]))
# Prime modulus: 998244353 = 119 * 2**23 + 1, so MOD - 1 is divisible by every
# power of two up to 2**23. That bounds the transform lengths `ntt` supports.
MOD = 998244353
# 3 is a primitive root modulo MOD; its powers supply the roots of unity.
G = 3


def ntt(a, invert):
    """Apply an in-place number-theoretic transform modulo MOD.

    Args:
        a: Mutable sequence of ints whose length is a power of two (at most
            2**23). It is overwritten with the transform; every written value
            lies in ``[0, MOD)``.
        invert: When true, compute the inverse transform (inverse roots and a
            final division by ``len(a)``); otherwise the forward transform.

    Returns:
        None. The result is stored in ``a``.

    Raises:
        ValueError: If ``len(a)`` is not a power of two, or is too large for
            MOD to contain a root of unity of that order.
    """
    n = len(a)
    # The butterflies below index a[i + j + half]; any other length would read
    # out of range or silently mix unrelated entries.
    if n & (n - 1):
        raise ValueError("ntt: length must be a power of two, got %d" % n)
    if n > 1 and (MOD - 1) % n:
        raise ValueError("ntt: length %d exceeds the supported maximum" % n)

    # Bit-reversal permutation: `j` tracks the bit-reversed value of `i` by
    # performing a "reversed increment" from the top bit downwards.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Iterative butterflies over blocks of size 2, 4, ..., n.
    step = 2
    while step <= n:
        half = step >> 1
        # Root of unity of order `step` (its modular inverse when inverting).
        wlen = pow(G, (MOD - 1) // step, MOD)
        if invert:
            wlen = pow(wlen, MOD - 2, MOD)

        # Precompute wlen**0 .. wlen**(half - 1) once per level.
        w = [1] * half
        for t in range(1, half):
            w[t] = w[t - 1] * wlen % MOD

        for start in range(0, n, step):
            for t in range(half):
                lo = start + t
                hi = lo + half
                u = a[lo]
                v = a[hi] * w[t] % MOD
                a[lo] = (u + v) % MOD
                # Python's % already yields a value in [0, MOD) here.
                a[hi] = (u - v) % MOD
        step <<= 1

    if invert:
        # Dividing by n is multiplying by n**(MOD - 2) (Fermat's little theorem).
        n_inv = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD
def multiply(a, b):
    """Multiply two polynomials modulo MOD using the number-theoretic transform.

    Args:
        a: Coefficient sequence of the first polynomial.
        b: Coefficient sequence of the second polynomial.

    Returns:
        A new list of ``len(a) + len(b) - 1`` coefficients of the product,
        each reduced modulo MOD. An empty list is returned when either
        operand is empty. The inputs are not modified.
    """
    # An empty operand has no coefficients to convolve. Without this guard the
    # padding below can get a negative count, leaving an operand longer than
    # the transform size.
    if not a or not b:
        return []

    sz = len(a) + len(b) - 1
    # Smallest power of two that is >= sz (1 when sz == 1).
    n = 1 << (sz - 1).bit_length()

    # Work on zero-padded copies so the caller's lists stay untouched.
    a_pad = list(a) + [0] * (n - len(a))
    b_pad = list(b) + [0] * (n - len(b))
    ntt(a_pad, False)
    ntt(b_pad, False)

    # Pointwise product in the transformed domain, then transform back.
    prod = [x * y % MOD for x, y in zip(a_pad, b_pad)]
    ntt(prod, True)
    return prod[:sz]
# Exclusive upper bound on the values of `a` that are tallied.
max_a = 300005

# counts[x] = number of entries of `a` equal to x, for 0 <= x < max_a.
counts = [0] * max_a
for x in a:
    # The lower bound matters: a negative x would otherwise wrap around and
    # increment an unrelated slot near the end of `counts`.
    if 0 <= x < max_a:
        counts[x] += 1

# (value, multiplicity) for every value >= 1 that occurs in `a`, in ascending
# order. Collected once so each H below filters this short list instead of
# rescanning the whole `counts` table.
_present = [(x, c) for x, c in enumerate(counts) if x >= 1 and c > 0]

# final_ans[i] accumulates, over every H in `b`, the contribution to query i.
final_ans = [0] * len(queries)
for H in b:
    # One linear factor 1 + counts[x] * t per distinct value x with 1 <= x < H.
    # NOTE(review): the factor is linear whatever the multiplicity of x is --
    # confirm this matches the intended counting.
    polynoms = [[1, c] for x, c in _present if x < H]

    # Multiply the factors pairwise, level by level, so the operands of each
    # product have similar sizes.
    while len(polynoms) > 1:
        next_level = [
            multiply(polynoms[i], polynoms[i + 1])
            for i in range(0, len(polynoms) - 1, 2)
        ]
        if len(polynoms) % 2:
            # An odd factor out is carried to the next level unchanged.
            next_level.append(polynoms[-1])
        polynoms = next_level

    # With no factors the product is the constant polynomial 1.
    P = polynoms[0] if polynoms else [1]

    for idx_q, Q in enumerate(queries):
        # Coefficient index selected by this query for the current H.
        white_needed = Q // 2 - H - 1
        if 0 <= white_needed < len(P):
            ways = P[white_needed] * pow(2, white_needed, MOD) % MOD
            final_ans[idx_q] = (final_ans[idx_q] + ways) % MOD

# One answer per line, emitted with a single write; nothing for zero queries.
if final_ans:
    print("\n".join(map(str, final_ans)))