题目
原题链接
题目大意
有 n n n 个数 a 1 … a n a_1…a_n a 1 … a n 和 m m m 个数 b 1 … b m b_1…b_m b 1 … b m 和一个质数 p p p 。第 i i i 个集合是这样生成的:一开始只有一个 1。每次找集合内的一个元素 c c c 和一个下标 j ( j ∈ [ 1 , m ] ) j\ (\ j∈[1,m]) j ( j ∈ [ 1 , m ] ) ,若 c × a i b j m o d p c×a_i^{b_j}\ mod\ p c × a i b j m o d p 不在集合里,则加进去。求这 n n n 个集合的并集大小。
n ≤ 1 0 4 , m ≤ 1 0 5 , a i < p ≤ 1 0 9 , b i < 1 0 9 n \leq 10^4, m \leq 10^5, a_i<p≤10^9, b_i<10^9 n ≤ 1 0 4 , m ≤ 1 0 5 , a i < p ≤ 1 0 9 , b i < 1 0 9
思路
我们考虑单独一个集合, 实际上这个集合的元素可以表示为 a i ∑ j = 1 m k j b j m o d p a_i^{\sum_{j=1}^{m}k_jb_j}\ mod\ p a i ∑ j = 1 m k j b j m o d p . 由欧拉定理, 我们知道 a i a_i a i 的指数实际上是在膜 p − 1 p-1 p − 1 意义下的. 如果我们记 B = g c d ( p − 1 , b 1 , . . . b m ) B=gcd(p-1,b_1,...b_m) B = g c d ( p − 1 , b 1 , . . . b m ) , 那么任意一个元素, 我们都可以表示成 a i k B m o d p a_{i}^{kB}\ mod\ p a i k B m o d p . 注意到对于任意集合, B B B 都是一个定值, 于是我们可以预处理每一个 a i a_{i} a i 为 a i B a_{i}^{B} a i B , 这样之后我们只需要合并所有的 ( a i B ) k m o d p (a_{i}^{B})^k\ mod\ p ( a i B ) k m o d p 即可.
这个时候我们有一个很直接的思路, 对于每一个 a i B a_i^{B} a i B , 我们可以枚举它的所有指数, 用一个 m a p map m a p 记录对应的值, 来统计答案. 对于指数枚举的范围, 我们可以求出 a i B a_i^{B} a i B 膜 p p p 的阶. 不幸的是, 这样会tle. 原因也很显然 : 如果每个 a i B a_i^B a i B 都是 p p p 的原根, 那么我们的复杂度就来到了 O ( n p ) O(np) O ( n p ) , 这显然是我们不能接受的. 我们要考虑优化.
可以发现由于底数不同导致我们在统计答案时需要额外枚举底数, 那么我们可不可以让底数统一呢, 答案是可以的. 我们可以利用 p p p 的一个原根 g g g 来表示每一个 a i B a_i^{B} a i B . 即设 g A i = a i B g^{A_i}=a_i^B g A i = a i B . 这样我们的问题就来到了统计 g k A i m o d p g^{kA_i}\ mod\ p g k A i m o d p 的个数了. 同样的, 我们设 A i ′ = g c d ( A i , p − 1 ) A{^i}'=gcd(A^i,p-1) A i ′ = g c d ( A i , p − 1 ) . 那么我们的问题就变成了 : 有一个序列 { A i ′ } \{\ {A^i}'\} { A i ′ } , 求 1 − p − 1 1-p-1 1 − p − 1 中至少是序列里面其中一个数的倍数的个数. 这个可以用容斥解决.
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 #include <bits/stdc++.h> using i64 = long long ;template <class T >T power (T a, T b, T p) { T res = 1 ; for ( ; b; b /= 2 , a = 1LL * a * a % p) if (b & 1 ) res = 1LL * res * a % p; return res; } void repeater () { int n, m, p; std::cin >> n >> m >> p; std::vector<int > a (n) ; for (int i = 0 ; i < n; i++) std::cin >> a[i]; int ph = p - 1 , g = ph; for (int i = 0 ; i < m; i++) { int x; std::cin >> x; g = std::gcd (g, x); } for (auto &i : a) i = power (i, g, p); std::vector<int > fac; for (int i = 1 ; i * i <= ph; i++) { if (ph % i) continue ; fac.emplace_back (i); if (i * i != ph) fac.emplace_back (ph / i); } sort (fac.begin (), fac.end ()); std::vector<int > A (n) ; for (int i = 0 ; i < n; i++) { int t = 0 ; for (auto j : fac) { if (power (a[i], j, p) == 1 ) { t = j; break ; } } if (a[i] == 1 ) t = 1 ; A[i] = ph / t; } sort (A.begin (), A.end ()); reverse (A.begin (), A.end ()); A.erase (std::unique (A.begin (), A.end ()), A.end ()); int sz = fac.size (); std::vector<int > vis (sz) , f (sz) ; for (int i = 0 ; i < sz; i++) { for (auto j : A) { if (fac[i] % j == 0 ) { vis[i] = 1 ; break ; } } } int ans = 0 ; for (int i = sz - 1 ; i >= 0 ; i--) { if (!vis[i]) continue ; f[i] = ph / fac[i]; for (int j = i + 1 ; j < sz; j++) { if (fac[j] % fac[i] == 0 ) f[i] -= f[j]; } ans += f[i]; } std::cout << ans << "\n" ; } int main () { std::ios::sync_with_stdio (false ); std::cin.tie (nullptr ); int t = 1 ; while (t--) repeater (); return 0 ; }