1use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "bigint")]
8use crate::Uint;
9
10#[cfg(feature = "subtle")]
11use subtle::{Choice, ConstantTimeEq};
12
13#[cfg(any(feature = "bigint", feature = "zeroize"))]
14use zeroize::Zeroize;
15#[cfg(feature = "bigint")]
16use zeroize::Zeroizing;
17
/// Multiple precision integer, a.k.a. `mpint`.
///
/// Stored as a big endian, two's complement byte string (the SSH `mpint`
/// representation): a positive value whose most significant bit is set
/// carries a leading `0x00`, and the value zero is the empty string.
///
/// `PartialEq`/`Eq` are only implemented with the `subtle` feature, where they
/// compare the encoded bytes via `ConstantTimeEq`.
// NOTE(review): the derived `Ord`/`PartialOrd` compare the encoded bytes
// lexicographically, which does not match numeric order when lengths or signs
// differ (e.g. `01 00` = 256 sorts before `02` = 2, and `80` = -128 sorts after
// `01` = 1). They are also not constant time.
#[cfg_attr(not(feature = "subtle"), derive(Clone))]
#[cfg_attr(feature = "subtle", derive(Clone, Ord, PartialOrd))] pub struct Mpint {
    /// Inner big endian-serialized integer value.
    inner: Box<[u8]>,
}
49
50impl Mpint {
51 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
57 bytes.try_into()
58 }
59
60 pub fn from_positive_bytes(mut bytes: &[u8]) -> Result<Self> {
66 let mut inner = Vec::with_capacity(bytes.len());
67
68 while bytes.first().copied() == Some(0) {
69 bytes = &bytes[1..];
70 }
71
72 match bytes.first().copied() {
73 Some(n) if n >= 0x80 => inner.push(0),
74 _ => (),
75 }
76
77 inner.extend_from_slice(bytes);
78 inner.into_boxed_slice().try_into()
79 }
80
81 pub fn as_bytes(&self) -> &[u8] {
87 &self.inner
88 }
89
90 pub fn as_positive_bytes(&self) -> Option<&[u8]> {
96 match self.as_bytes() {
97 [0x00, rest @ ..] => Some(rest),
98 [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
99 _ => None,
100 }
101 }
102
103 pub fn is_positive(&self) -> bool {
105 self.as_positive_bytes().is_some()
106 }
107}
108
109impl AsRef<[u8]> for Mpint {
110 fn as_ref(&self) -> &[u8] {
111 self.as_bytes()
112 }
113}
114
#[cfg(feature = "subtle")]
impl ConstantTimeEq for Mpint {
    /// Compare the encoded bytes of both integers using `subtle`'s slice
    /// `ConstantTimeEq` implementation.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
121
// Marker impl: equality is total because the `PartialEq` impl compares the
// encoded bytes (via `ConstantTimeEq`). Only available with `subtle`.
#[cfg(feature = "subtle")]
impl Eq for Mpint {}
124
#[cfg(feature = "subtle")]
impl PartialEq for Mpint {
    /// Equality of the encoded bytes, evaluated via [`ConstantTimeEq::ct_eq`].
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}
131
132impl Decode for Mpint {
133 type Error = Error;
134
135 fn decode(reader: &mut impl Reader) -> Result<Self> {
136 Vec::decode(reader)?.into_boxed_slice().try_into()
137 }
138}
139
140impl Encode for Mpint {
141 fn encoded_len(&self) -> Result<usize> {
142 [4, self.as_bytes().len()].checked_sum()
143 }
144
145 fn encode(&self, writer: &mut impl Writer) -> Result<()> {
146 self.as_bytes().encode(writer)?;
147 Ok(())
148 }
149}
150
151impl TryFrom<&[u8]> for Mpint {
152 type Error = Error;
153
154 fn try_from(bytes: &[u8]) -> Result<Self> {
155 Vec::from(bytes).into_boxed_slice().try_into()
156 }
157}
158
159impl TryFrom<Box<[u8]>> for Mpint {
160 type Error = Error;
161
162 fn try_from(bytes: Box<[u8]>) -> Result<Self> {
163 match &*bytes {
164 [0x00] => Err(Error::MpintEncoding),
166 [0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
168 _ => Ok(Self { inner: bytes }),
169 }
170 }
171}
172
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
    /// Overwrite the stored integer bytes using the `Box<[u8]>` zeroize impl.
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.inner);
    }
}
179
180impl fmt::Debug for Mpint {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 write!(f, "Mpint({self:X})")
183 }
184}
185
186impl fmt::Display for Mpint {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 write!(f, "{self:X}")
189 }
190}
191
192impl fmt::LowerHex for Mpint {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 for byte in self.as_bytes() {
195 write!(f, "{byte:02x}")?;
196 }
197 Ok(())
198 }
199}
200
201impl fmt::UpperHex for Mpint {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 for byte in self.as_bytes() {
204 write!(f, "{byte:02X}")?;
205 }
206 Ok(())
207 }
208}
209
#[cfg(feature = "bigint")]
impl TryFrom<Uint> for Mpint {
    type Error = Error;

    /// By-value convenience wrapper around the `TryFrom<&Uint>` conversion.
    fn try_from(uint: Uint) -> Result<Mpint> {
        (&uint).try_into()
    }
}
218
#[cfg(feature = "bigint")]
impl TryFrom<&Uint> for Mpint {
    type Error = Error;

    /// Encode an unsigned big integer as a positive `mpint`.
    ///
    /// The intermediate big endian bytes are held in `Zeroizing`, so they are
    /// wiped when dropped.
    fn try_from(uint: &Uint) -> Result<Mpint> {
        let be_bytes = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&be_bytes)
    }
}
228
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    /// By-value convenience wrapper around the `TryFrom<&Mpint>` conversion.
    fn try_from(mpint: Mpint) -> Result<Uint> {
        Self::try_from(&mpint)
    }
}
237
#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for Uint {
    type Error = Error;

    /// Convert a non-negative [`Mpint`] into a [`Uint`].
    ///
    /// Zero is encoded as the empty `mpint`, which `as_positive_bytes` does not
    /// report as positive. It is handled explicitly here so that converting a
    /// zero `Uint` to an `Mpint` and back round-trips.
    ///
    /// # Errors
    /// Returns [`Error::MpintEncoding`] if the integer is negative.
    fn try_from(mpint: &Mpint) -> Result<Uint> {
        let bytes: &[u8] = match mpint.as_bytes() {
            // Pass a single zero byte rather than an empty slice, so the
            // resulting `Uint` gets a non-empty (8-bit) precision, as for any
            // other one-byte value.
            [] => &[0x00],
            _ => mpint.as_positive_bytes().ok_or(Error::MpintEncoding)?,
        };

        Ok(Uint::from_be_slice_vartime(bytes))
    }
}
248
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    #[test]
    fn decode_0() {
        let n = Mpint::from_bytes(b"").unwrap();
        assert_eq!(b"", n.as_bytes())
    }

    #[test]
    fn reject_extra_leading_zeroes() {
        assert!(Mpint::from_bytes(&hex!("00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 01")).is_err());
    }

    #[test]
    fn decode_9a378f9b2e332a7() {
        assert!(Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7")).is_ok());
    }

    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();

        assert_eq!(&hex!("80"), n.as_positive_bytes().unwrap())
    }

    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 01")).unwrap().as_ref(),
            b"\x01"
        );
    }

    #[test]
    fn from_positive_bytes_adds_sign_byte() {
        // Without the `0x00` prefix the set high bit would read as negative.
        let n = Mpint::from_positive_bytes(&hex!("80")).unwrap();
        assert_eq!(n.as_bytes(), &hex!("00 80"));
        assert_eq!(n.as_positive_bytes().unwrap(), &hex!("80"));

        // Leading zeroes are stripped before deciding on the sign byte.
        let n = Mpint::from_positive_bytes(&hex!("00 00 ff 01")).unwrap();
        assert_eq!(n.as_bytes(), &hex!("00 ff 01"));

        // No sign byte when the high bit is clear.
        let n = Mpint::from_positive_bytes(&hex!("7f")).unwrap();
        assert_eq!(n.as_bytes(), &hex!("7f"));

        // Empty input is zero.
        assert!(Mpint::from_positive_bytes(&[]).unwrap().as_bytes().is_empty());
    }

    #[test]
    fn sign_detection() {
        assert!(Mpint::from_bytes(&hex!("00 80")).unwrap().is_positive());
        assert!(Mpint::from_bytes(&hex!("7f")).unwrap().is_positive());
        assert!(!Mpint::from_bytes(&hex!("80")).unwrap().is_positive());
        // Zero (the empty encoding) is not reported as positive.
        assert!(!Mpint::from_bytes(b"").unwrap().is_positive());
    }

    #[test]
    fn hex_formatting() {
        let n = Mpint::from_bytes(&hex!("00 ab cd")).unwrap();
        assert_eq!(alloc::format!("{n}"), "00ABCD");
        assert_eq!(alloc::format!("{n:x}"), "00abcd");
        assert_eq!(alloc::format!("{n:X}"), "00ABCD");
        assert_eq!(alloc::format!("{n:?}"), "Mpint(00ABCD)");
    }

    #[test]
    // Example from RFC 4251
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(n.as_positive_bytes().is_none());
    }

    #[test]
    // Example from RFC 4251
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(n.as_positive_bytes().is_none());
    }
}