1use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "bigint")]
8use crate::Uint;
9
10#[cfg(feature = "ctutils")]
11use ctutils::{Choice, CtEq};
12
13#[cfg(any(feature = "bigint", feature = "zeroize"))]
14use zeroize::Zeroize;
15#[cfg(feature = "bigint")]
16use zeroize::Zeroizing;
17
/// Multiple precision integer, a.k.a. `mpint`.
///
/// Wraps the serialized bytes of the integer exactly as they are validated by
/// the `TryFrom` conversions or produced by [`Mpint::from_positive_bytes`].
///
/// `Clone` is always derived. When the `ctutils` feature is enabled, `Ord`
/// and `PartialOrd` are derived as well.
///
/// NOTE(review): the derived ordering compares the serialized bytes
/// lexicographically, which is not a numeric ordering of the integers they
/// encode - confirm this is what callers expect.
#[derive(Clone)]
#[cfg_attr(feature = "ctutils", derive(Ord, PartialOrd))]
pub struct Mpint {
    /// Serialized integer value, most significant byte first.
    inner: Box<[u8]>,
}
49
50impl Mpint {
51 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
59 bytes.try_into()
60 }
61
62 #[must_use]
68 pub fn from_positive_bytes(mut bytes: &[u8]) -> Self {
69 while bytes.first().copied() == Some(0) {
71 bytes = &bytes[1..];
72 }
73
74 let inner = match bytes.first().copied() {
76 Some(n) if n >= 0x80 => {
77 let mut inner = Vec::with_capacity(bytes.len().saturating_add(1));
78 inner.push(0);
79 inner.extend_from_slice(bytes);
80 inner
81 }
82 _ => Vec::from(bytes),
83 };
84
85 Self {
86 inner: inner.into_boxed_slice(),
87 }
88 }
89
90 #[must_use]
96 pub fn as_bytes(&self) -> &[u8] {
97 &self.inner
98 }
99
100 #[must_use]
106 pub fn as_positive_bytes(&self) -> Option<&[u8]> {
107 match self.as_bytes() {
108 [0x00, rest @ ..] => Some(rest),
109 [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
110 _ => None,
111 }
112 }
113
114 #[must_use]
116 pub fn is_positive(&self) -> bool {
117 self.as_positive_bytes().is_some()
118 }
119}
120
121impl AsRef<[u8]> for Mpint {
122 fn as_ref(&self) -> &[u8] {
123 self.as_bytes()
124 }
125}
126
/// Equality check that delegates to the `CtEq` implementation for byte
/// slices, applied to the serialized bytes of both values.
#[cfg(feature = "ctutils")]
impl CtEq for Mpint {
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = self.as_ref();
        let rhs: &[u8] = other.as_ref();
        lhs.ct_eq(rhs)
    }
}
133
/// Marker impl: equality on [`Mpint`] is a full equivalence relation.
#[cfg(feature = "ctutils")]
impl Eq for Mpint {}

/// `==` is routed through the `CtEq` implementation instead of a derived
/// comparison, and the resulting `Choice` is converted to `bool`.
#[cfg(feature = "ctutils")]
impl PartialEq for Mpint {
    fn eq(&self, other: &Self) -> bool {
        let verdict = self.ct_eq(other);
        verdict.into()
    }
}
143
144impl Decode for Mpint {
145 type Error = Error;
146
147 fn decode(reader: &mut impl Reader) -> Result<Self> {
148 Vec::decode(reader)?.into_boxed_slice().try_into()
149 }
150}
151
152impl Encode for Mpint {
153 fn encoded_len(&self) -> Result<usize> {
154 [4, self.as_bytes().len()].checked_sum()
155 }
156
157 fn encode(&self, writer: &mut impl Writer) -> Result<()> {
158 self.as_bytes().encode(writer)?;
159 Ok(())
160 }
161}
162
163impl TryFrom<&[u8]> for Mpint {
164 type Error = Error;
165
166 fn try_from(bytes: &[u8]) -> Result<Self> {
167 Vec::from(bytes).into_boxed_slice().try_into()
168 }
169}
170
171impl TryFrom<Box<[u8]>> for Mpint {
172 type Error = Error;
173
174 fn try_from(bytes: Box<[u8]>) -> Result<Self> {
175 match &*bytes {
176 [0x00] => Err(Error::MpintEncoding),
178 [0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
180 _ => Ok(Self { inner: bytes }),
181 }
182 }
183}
184
/// Zeroizing an [`Mpint`] zeroizes its stored byte buffer.
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
    fn zeroize(&mut self) {
        let Self { inner } = self;
        inner.zeroize();
    }
}
191
192impl fmt::Debug for Mpint {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 write!(f, "Mpint({self:X})")
195 }
196}
197
198impl fmt::Display for Mpint {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 write!(f, "{self:X}")
201 }
202}
203
204impl fmt::LowerHex for Mpint {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 for byte in self.as_bytes() {
207 write!(f, "{byte:02x}")?;
208 }
209 Ok(())
210 }
211}
212
213impl fmt::UpperHex for Mpint {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 for byte in self.as_bytes() {
216 write!(f, "{byte:02X}")?;
217 }
218 Ok(())
219 }
220}
221
/// Conversion from a borrowed `Uint`: its big endian bytes are passed through
/// [`Mpint::from_positive_bytes`].
#[cfg(feature = "bigint")]
impl From<&Uint> for Mpint {
    fn from(uint: &Uint) -> Mpint {
        // The intermediate byte buffer is wrapped in `Zeroizing` so it is
        // wiped when it goes out of scope.
        let be_bytes = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&be_bytes)
    }
}
229
/// Conversion from an owned `Uint`; forwards to the `From<&Uint>` impl.
#[cfg(feature = "bigint")]
impl From<Uint> for Mpint {
    fn from(uint: Uint) -> Mpint {
        Self::from(&uint)
    }
}
236
/// Fallible conversion from an owned [`Mpint`]; forwards to the
/// `TryFrom<&Mpint>` impl.
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    fn try_from(mpint: Mpint) -> Result<Uint> {
        Self::try_from(&mpint)
    }
}
245
/// Fallible conversion from a borrowed [`Mpint`].
///
/// # Errors
/// Returns [`Error::MpintEncoding`] when `Mpint::as_positive_bytes` yields
/// `None`.
///
/// NOTE(review): `as_positive_bytes` also returns `None` for the empty (zero)
/// encoding, so a zero `Mpint` is rejected here - confirm this is intended.
#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for Uint {
    type Error = Error;

    fn try_from(mpint: &Mpint) -> Result<Uint> {
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(Uint::from_be_slice_vartime(magnitude)),
            None => Err(Error::MpintEncoding),
        }
    }
}
256
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    /// The empty byte string is accepted and stored as-is.
    #[test]
    fn decode_0() {
        let zero = Mpint::from_bytes(b"").unwrap();
        assert_eq!(b"", zero.as_bytes());
    }

    /// A leading `00` that is not followed by a byte with the high bit set
    /// must be rejected.
    #[test]
    fn reject_extra_leading_zeroes() {
        let lone_zero = Mpint::from_bytes(&hex!("00"));
        assert!(lone_zero.is_err());

        let double_zero = Mpint::from_bytes(&hex!("00 00"));
        assert!(double_zero.is_err());

        let padded_one = Mpint::from_bytes(&hex!("00 01"));
        assert!(padded_one.is_err());
    }

    /// A value whose first byte is below `0x80` parses successfully.
    #[test]
    fn decode_9a378f9b2e332a7() {
        let parsed = Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7"));
        assert!(parsed.is_ok());
    }

    /// `00 80` is accepted and its positive bytes drop the leading zero.
    #[test]
    fn decode_80() {
        let value = Mpint::from_bytes(&hex!("00 80")).unwrap();

        assert_eq!(&hex!("80"), value.as_positive_bytes().unwrap());
    }

    /// `from_positive_bytes` removes every leading zero byte.
    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        let single = Mpint::from_positive_bytes(&hex!("00"));
        assert_eq!(single.as_ref(), b"");

        let double = Mpint::from_positive_bytes(&hex!("00 00"));
        assert_eq!(double.as_ref(), b"");

        let one = Mpint::from_positive_bytes(&hex!("00 01"));
        assert_eq!(one.as_ref(), b"\x01");
    }

    /// A first byte of `0x80` or above has no positive-bytes view.
    #[test]
    fn decode_neg_1234() {
        let value = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(value.as_positive_bytes().is_none());
    }

    /// Same as above, for a longer input starting with `ff`.
    #[test]
    fn decode_neg_deadbeef() {
        let value = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(value.as_positive_bytes().is_none());
    }
}