题目

原题链接

题目大意

nn 个数 a1ana_1…a_nmm 个数 b1bmb_1…b_m 和一个质数 pp。第 ii 个集合是这样生成的:一开始只有一个 1。每次找集合内的一个元素 cc 和一个下标 j ( j[1,m])j\ (\ j∈[1,m]),若 c×aibj mod pc×a_i^{b_j}\ mod\ p 不在集合里,则加进去。求这 nn 个集合的并集大小。
n104,m105,ai<p109,bi<109n \leq 10^4, m \leq 10^5, a_i<p≤10^9, b_i<10^9

思路

我们考虑单独一个集合, 实际上这个集合的元素可以表示为 aij=1mkjbj mod pa_i^{\sum_{j=1}^{m}k_jb_j}\ mod\ p. 由欧拉定理, 我们知道 aia_i 的指数实际上是在膜 p1p-1 意义下的. 如果我们记 B=gcd(p1,b1,...bm)B=gcd(p-1,b_1,...b_m), 那么任意一个元素, 我们都可以表示成 aikB mod pa_{i}^{kB}\ mod\ p. 注意到对于任意集合, BB 都是一个定值, 于是我们可以预处理每一个 aia_{i}aiBa_{i}^{B}, 这样之后我们只需要合并所有的 (aiB)k mod p(a_{i}^{B})^k\ mod\ p 即可.

这个时候我们有一个很直接的思路, 对于每一个 aiBa_i^{B}, 我们可以枚举它的所有指数, 用一个 mapmap 记录对应的值, 来统计答案. 对于指数枚举的范围, 我们可以求出 aiBa_i^{B}pp 的阶. 不幸的是, 这样会tle. 原因也很显然 : 如果每个 aiBa_i^B 都是 pp 的原根, 那么我们的复杂度就来到了 O(np)O(np), 这显然是我们不能接受的. 我们要考虑优化.

可以发现由于底数不同导致我们在统计答案时需要额外枚举底数, 那么我们可不可以让底数统一呢, 答案是可以的. 我们可以利用 pp 的一个原根 gg 来表示每一个 aiBa_i^{B}. 即设 gAi=aiBg^{A_i}=a_i^B. 这样我们的问题就来到了统计 gkAi mod pg^{kA_i}\ mod\ p 的个数了. 同样的, 我们设 Ai=gcd(Ai,p1)A{^i}'=gcd(A^i,p-1). 那么我们的问题就变成了 : 有一个序列 { Ai}\{\ {A^i}'\}, 求 1p11-p-1 中至少是序列里面其中一个数的倍数的个数. 这个可以用容斥解决.

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>
using i64 = long long;

template <class T>
T power(T a, T b, T p)
{
T res = 1;
for( ; b; b /= 2, a = 1LL * a * a % p)
if(b & 1) res = 1LL * res * a % p;

return res;
}

void repeater()
{
int n, m, p; std::cin >> n >> m >> p;
std::vector<int> a(n);
for(int i = 0; i < n; i++) std::cin >> a[i];

int ph = p - 1, g = ph;
for(int i = 0; i < m; i++)
{
int x; std::cin >> x;
g = std::gcd(g, x);
}

for(auto &i : a) i = power(i, g, p);

std::vector<int> fac;
for(int i = 1; i * i <= ph; i++)
{
if(ph % i) continue;
fac.emplace_back(i);
if(i * i != ph) fac.emplace_back(ph / i);
}
sort(fac.begin(), fac.end());

std::vector<int> A(n);
for(int i = 0; i < n; i++)
{
int t = 0;
for(auto j : fac)
{
if(power(a[i], j, p) == 1)
{
t = j;
break;
}
}
if(a[i] == 1) t = 1;
A[i] = ph / t;
}

sort(A.begin(), A.end());
reverse(A.begin(), A.end());
A.erase(std::unique(A.begin(), A.end()), A.end());

int sz = fac.size();
std::vector<int> vis(sz), f(sz);
for(int i = 0; i < sz; i++)
{
for(auto j : A)
{
if(fac[i] % j == 0)
{
vis[i] = 1;
break;
}
}
}

int ans = 0;
for(int i = sz - 1; i >= 0; i--)
{
if(!vis[i]) continue;
f[i] = ph / fac[i];
for(int j = i + 1; j < sz; j++)
{
if(fac[j] % fac[i] == 0)
f[i] -= f[j];
}
ans += f[i];
}
std::cout << ans << "\n";
}

int main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);

int t = 1; //std::cin >> t;
while(t--) repeater();

return 0;
}