1use crate::{Error, Result};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6use encoding::{CheckedSum, Decode, Encode, Reader, Writer};
7use subtle::{Choice, ConstantTimeEq};
8use zeroize::Zeroize;
9
10#[cfg(any(feature = "dsa", feature = "rsa"))]
11use zeroize::Zeroizing;
12
13#[derive(Clone, PartialOrd, Ord)]
42pub struct Mpint {
43 inner: Box<[u8]>,
45}
46
47impl Mpint {
48 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
54 bytes.try_into()
55 }
56
57 pub fn from_positive_bytes(mut bytes: &[u8]) -> Result<Self> {
63 let mut inner = Vec::with_capacity(bytes.len());
64
65 while bytes.first().copied() == Some(0) {
66 bytes = &bytes[1..];
67 }
68
69 match bytes.first().copied() {
70 Some(n) if n >= 0x80 => inner.push(0),
71 _ => (),
72 }
73
74 inner.extend_from_slice(bytes);
75 inner.into_boxed_slice().try_into()
76 }
77
78 pub fn as_bytes(&self) -> &[u8] {
84 &self.inner
85 }
86
87 pub fn as_positive_bytes(&self) -> Option<&[u8]> {
93 match self.as_bytes() {
94 [0x00, rest @ ..] => Some(rest),
95 [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
96 _ => None,
97 }
98 }
99}
100
101impl AsRef<[u8]> for Mpint {
102 fn as_ref(&self) -> &[u8] {
103 self.as_bytes()
104 }
105}
106
107impl ConstantTimeEq for Mpint {
108 fn ct_eq(&self, other: &Self) -> Choice {
109 self.as_ref().ct_eq(other.as_ref())
110 }
111}
112
113impl Eq for Mpint {}
114
115impl PartialEq for Mpint {
116 fn eq(&self, other: &Self) -> bool {
117 self.ct_eq(other).into()
118 }
119}
120
121impl Decode for Mpint {
122 type Error = Error;
123
124 fn decode(reader: &mut impl Reader) -> Result<Self> {
125 Vec::decode(reader)?.into_boxed_slice().try_into()
126 }
127}
128
129impl Encode for Mpint {
130 fn encoded_len(&self) -> encoding::Result<usize> {
131 [4, self.as_bytes().len()].checked_sum()
132 }
133
134 fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> {
135 self.as_bytes().encode(writer)?;
136 Ok(())
137 }
138}
139
140impl TryFrom<&[u8]> for Mpint {
141 type Error = Error;
142
143 fn try_from(bytes: &[u8]) -> Result<Self> {
144 Vec::from(bytes).into_boxed_slice().try_into()
145 }
146}
147
148impl TryFrom<Box<[u8]>> for Mpint {
149 type Error = Error;
150
151 fn try_from(bytes: Box<[u8]>) -> Result<Self> {
152 match &*bytes {
153 [0x00] => Err(Error::FormatEncoding),
155 [0x00, n, ..] if *n < 0x80 => Err(Error::FormatEncoding),
157 _ => Ok(Self { inner: bytes }),
158 }
159 }
160}
161
162impl Zeroize for Mpint {
163 fn zeroize(&mut self) {
164 self.inner.zeroize();
165 }
166}
167
168impl fmt::Debug for Mpint {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 write!(f, "Mpint({self:X})")
171 }
172}
173
174impl fmt::Display for Mpint {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 write!(f, "{self:X}")
177 }
178}
179
180impl fmt::LowerHex for Mpint {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 for byte in self.as_bytes() {
183 write!(f, "{byte:02x}")?;
184 }
185 Ok(())
186 }
187}
188
189impl fmt::UpperHex for Mpint {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 for byte in self.as_bytes() {
192 write!(f, "{byte:02X}")?;
193 }
194 Ok(())
195 }
196}
197
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<bigint::BigUint> for Mpint {
    type Error = Error;

    /// Convert an owned unsigned big integer by delegating to the
    /// by-reference conversion.
    fn try_from(uint: bigint::BigUint) -> Result<Mpint> {
        <Mpint as TryFrom<&bigint::BigUint>>::try_from(&uint)
    }
}
206
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<&bigint::BigUint> for Mpint {
    type Error = Error;

    /// Serialize `uint` as big endian bytes and build a non-negative `mpint`
    /// from them.
    fn try_from(uint: &bigint::BigUint) -> Result<Mpint> {
        // The temporary serialization is wrapped in `Zeroizing` so it is
        // wiped when dropped.
        let be_bytes = Zeroizing::new(uint.to_bytes_be());
        Self::from_positive_bytes(&be_bytes)
    }
}
216
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<Mpint> for bigint::BigUint {
    type Error = Error;

    /// Convert an owned `Mpint` by delegating to the by-reference conversion.
    fn try_from(mpint: Mpint) -> Result<bigint::BigUint> {
        (&mpint).try_into()
    }
}
225
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<&Mpint> for bigint::BigUint {
    type Error = Error;

    /// Interpret the `mpint` as an unsigned big integer.
    ///
    /// Fails with [`Error::Crypto`] whenever [`Mpint::as_positive_bytes`]
    /// yields `None` (negative values and zero).
    fn try_from(mpint: &Mpint) -> Result<bigint::BigUint> {
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(bigint::BigUint::from_bytes_be(magnitude)),
            None => Err(Error::Crypto),
        }
    }
}
237
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    /// Example `0` from RFC 4251: zero is the empty string.
    #[test]
    fn decode_0() {
        let n = Mpint::from_bytes(b"").unwrap();
        assert_eq!(b"", n.as_bytes())
    }

    #[test]
    fn reject_extra_leading_zeroes() {
        assert!(Mpint::from_bytes(&hex!("00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 01")).is_err());
    }

    /// Example `9a378f9b2e332a7` from RFC 4251.
    #[test]
    fn decode_9a378f9b2e332a7() {
        assert!(Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7")).is_ok());
    }

    /// Example `80` from RFC 4251.
    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();

        assert_eq!(&hex!("80"), n.as_positive_bytes().unwrap())
    }

    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 01")).unwrap().as_ref(),
            b"\x01"
        );
    }

    /// A set high bit must be protected by a `0x00` pad byte, including when
    /// leading zeroes were stripped first.
    #[test]
    fn from_positive_bytes_pads_high_bit() {
        let n = Mpint::from_positive_bytes(&hex!("80")).unwrap();
        assert_eq!(n.as_bytes(), &hex!("00 80"));
        assert_eq!(n.as_positive_bytes().unwrap(), &hex!("80"));

        let n = Mpint::from_positive_bytes(&hex!("00 00 ff 01")).unwrap();
        assert_eq!(n.as_bytes(), &hex!("00 ff 01"));
        assert_eq!(n.as_positive_bytes().unwrap(), &hex!("ff 01"));
    }

    /// Example `-1234` from RFC 4251: negative values have no positive bytes.
    #[test]
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(n.as_positive_bytes().is_none());
    }

    /// Example `-deadbeef` from RFC 4251: negative values have no positive bytes.
    #[test]
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(n.as_positive_bytes().is_none());
    }
}