1 | use super::monty::monty_modpow; |
2 | use super::BigUint; |
3 | |
4 | use crate::big_digit::{self, BigDigit}; |
5 | |
6 | use num_integer::Integer; |
7 | use num_traits::{One, Pow, ToPrimitive, Zero}; |
8 | |
9 | impl Pow<&BigUint> for BigUint { |
10 | type Output = BigUint; |
11 | |
12 | #[inline ] |
13 | fn pow(self, exp: &BigUint) -> BigUint { |
14 | if self.is_one() || exp.is_zero() { |
15 | BigUint::one() |
16 | } else if self.is_zero() { |
17 | Self::ZERO |
18 | } else if let Some(exp: u64) = exp.to_u64() { |
19 | self.pow(exp) |
20 | } else if let Some(exp: u128) = exp.to_u128() { |
21 | self.pow(exp) |
22 | } else { |
23 | // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given |
24 | // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address! |
25 | panic!("memory overflow" ) |
26 | } |
27 | } |
28 | } |
29 | |
30 | impl Pow<BigUint> for BigUint { |
31 | type Output = BigUint; |
32 | |
33 | #[inline ] |
34 | fn pow(self, exp: BigUint) -> BigUint { |
35 | Pow::pow(self, &exp) |
36 | } |
37 | } |
38 | |
39 | impl Pow<&BigUint> for &BigUint { |
40 | type Output = BigUint; |
41 | |
42 | #[inline ] |
43 | fn pow(self, exp: &BigUint) -> BigUint { |
44 | if self.is_one() || exp.is_zero() { |
45 | BigUint::one() |
46 | } else if self.is_zero() { |
47 | BigUint::ZERO |
48 | } else { |
49 | self.clone().pow(exp) |
50 | } |
51 | } |
52 | } |
53 | |
54 | impl Pow<BigUint> for &BigUint { |
55 | type Output = BigUint; |
56 | |
57 | #[inline ] |
58 | fn pow(self, exp: BigUint) -> BigUint { |
59 | Pow::pow(self, &exp) |
60 | } |
61 | } |
62 | |
// Implements `Pow<$T>` and `Pow<&$T>` for both `BigUint` and `&BigUint`, where
// `$T` is a primitive unsigned integer type used as the exponent.
macro_rules! pow_impl {
    ($T:ty) => {
        impl Pow<$T> for BigUint {
            type Output = BigUint;

            /// Raises `self` to the power `exp` with binary (square-and-multiply)
            /// exponentiation, reusing `self` as the running square.
            fn pow(self, exp: $T) -> BigUint {
                if exp == 0 {
                    return BigUint::one();
                }

                let mut square = self;
                let mut bits = exp;

                // Low zero bits only square the base; no product is needed yet.
                while bits & 1 == 0 {
                    square = &square * &square;
                    bits >>= 1;
                }

                // The exponent was a power of two, so the squared base is the answer.
                if bits == 1 {
                    return square;
                }

                // The lowest set bit seeds the product. Every higher bit squares
                // the base, and the square is multiplied in when that bit is set.
                let mut product = square.clone();
                while bits > 1 {
                    bits >>= 1;
                    square = &square * &square;
                    if bits & 1 == 1 {
                        product *= &square;
                    }
                }
                product
            }
        }

        impl Pow<&$T> for BigUint {
            type Output = BigUint;

            /// Dereferences the exponent and defers to the by-value impl.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <Self as Pow<$T>>::pow(self, *exp)
            }
        }

        impl Pow<$T> for &BigUint {
            type Output = BigUint;

            /// Borrowing variant: clones the base only when the exponent is non-zero.
            #[inline]
            fn pow(self, exp: $T) -> BigUint {
                match exp {
                    0 => BigUint::one(),
                    _ => <BigUint as Pow<$T>>::pow(self.clone(), exp),
                }
            }
        }

        impl Pow<&$T> for &BigUint {
            type Output = BigUint;

            /// Dereferences the exponent and defers to the by-value-exponent impl.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <Self as Pow<$T>>::pow(self, *exp)
            }
        }
    };
}
126 | |
127 | pow_impl!(u8); |
128 | pow_impl!(u16); |
129 | pow_impl!(u32); |
130 | pow_impl!(u64); |
131 | pow_impl!(usize); |
132 | pow_impl!(u128); |
133 | |
134 | pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint { |
135 | assert!( |
136 | !modulus.is_zero(), |
137 | "attempt to calculate with zero modulus!" |
138 | ); |
139 | |
140 | if modulus.is_odd() { |
141 | // For an odd modulus, we can use Montgomery multiplication in base 2^32. |
142 | monty_modpow(x, y:exponent, m:modulus) |
143 | } else { |
144 | // Otherwise do basically the same as `num::pow`, but with a modulus. |
145 | plain_modpow(base:x, &exponent.data, modulus) |
146 | } |
147 | } |
148 | |
149 | fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint { |
150 | assert!( |
151 | !modulus.is_zero(), |
152 | "attempt to calculate with zero modulus!" |
153 | ); |
154 | |
155 | let i = match exp_data.iter().position(|&r| r != 0) { |
156 | None => return BigUint::one(), |
157 | Some(i) => i, |
158 | }; |
159 | |
160 | let mut base = base % modulus; |
161 | for _ in 0..i { |
162 | for _ in 0..big_digit::BITS { |
163 | base = &base * &base % modulus; |
164 | } |
165 | } |
166 | |
167 | let mut r = exp_data[i]; |
168 | let mut b = 0u8; |
169 | while r.is_even() { |
170 | base = &base * &base % modulus; |
171 | r >>= 1; |
172 | b += 1; |
173 | } |
174 | |
175 | let mut exp_iter = exp_data[i + 1..].iter(); |
176 | if exp_iter.len() == 0 && r.is_one() { |
177 | return base; |
178 | } |
179 | |
180 | let mut acc = base.clone(); |
181 | r >>= 1; |
182 | b += 1; |
183 | |
184 | { |
185 | let mut unit = |exp_is_odd| { |
186 | base = &base * &base % modulus; |
187 | if exp_is_odd { |
188 | acc *= &base; |
189 | acc %= modulus; |
190 | } |
191 | }; |
192 | |
193 | if let Some(&last) = exp_iter.next_back() { |
194 | // consume exp_data[i] |
195 | for _ in b..big_digit::BITS { |
196 | unit(r.is_odd()); |
197 | r >>= 1; |
198 | } |
199 | |
200 | // consume all other digits before the last |
201 | for &r in exp_iter { |
202 | let mut r = r; |
203 | for _ in 0..big_digit::BITS { |
204 | unit(r.is_odd()); |
205 | r >>= 1; |
206 | } |
207 | } |
208 | r = last; |
209 | } |
210 | |
211 | debug_assert_ne!(r, 0); |
212 | while !r.is_zero() { |
213 | unit(r.is_odd()); |
214 | r >>= 1; |
215 | } |
216 | } |
217 | acc |
218 | } |
219 | |
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // Each case compares a digit-vector exponent with a `u32` exponent built
    // from the same low bits, using 8-bit "digits". The two exponents differ,
    // but both are >= 8 and congruent mod 8. Since 0x1100 = 2^8 * 17 and 2 has
    // order 8 mod 17, `2^k mod 0x1100` depends only on `k mod 8` once k >= 8.
    // So both sides must agree.
    let check = |exp_digits: &[BigDigit], exp_small: u32| {
        assert_eq!(
            two.pow(exp_small) % &modulus,
            plain_modpow(two, exp_digits, &modulus)
        );
    };

    check(&[0, 0b1], 0b1_00000000_u32);
    check(&[0, 0b10], 0b10_00000000_u32);
    check(&[0, 0b110010], 0b110010_00000000_u32);
    check(&[0b1, 0b1], 0b1_00000001_u32);
    check(&[0b1100, 0, 0b1], 0b1_00000000_00001100_u32);
}
251 | |
#[test]
fn test_pow_biguint() {
    // 5^3 = 125, computed with an owned `BigUint` base and exponent.
    let result = BigUint::from(5u8).pow(BigUint::from(3u8));
    assert_eq!(result, BigUint::from(125u8));
}
259 | |