// ssh_key/mpint.rs

1//! Multiple precision integer
2
3use crate::{Error, Result};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6use encoding::{CheckedSum, Decode, Encode, Reader, Writer};
7use subtle::{Choice, ConstantTimeEq};
8use zeroize::Zeroize;
9
10#[cfg(any(feature = "dsa", feature = "rsa"))]
11use zeroize::Zeroizing;
12
13/// Multiple precision integer, a.k.a. "mpint".
14///
15/// This type is used for representing the big integer components of
16/// DSA and RSA keys.
17///
/// Described in [RFC4251 § 5](https://datatracker.ietf.org/doc/html/rfc4251#section-5):
19///
20/// > Represents multiple precision integers in two's complement format,
21/// > stored as a string, 8 bits per byte, MSB first.  Negative numbers
22/// > have the value 1 as the most significant bit of the first byte of
23/// > the data partition.  If the most significant bit would be set for
24/// > a positive number, the number MUST be preceded by a zero byte.
25/// > Unnecessary leading bytes with the value 0 or 255 MUST NOT be
26/// > included.  The value zero MUST be stored as a string with zero
27/// > bytes of data.
28/// >
29/// > By convention, a number that is used in modular computations in
30/// > Z_n SHOULD be represented in the range 0 <= x < n.
31///
32/// ## Examples
33///
34/// | value (hex)     | representation (hex) |
35/// |-----------------|----------------------|
36/// | 0               | `00 00 00 00`
37/// | 9a378f9b2e332a7 | `00 00 00 08 09 a3 78 f9 b2 e3 32 a7`
38/// | 80              | `00 00 00 02 00 80`
39/// |-1234            | `00 00 00 02 ed cc`
40/// | -deadbeef       | `00 00 00 05 ff 21 52 41 11`
// NOTE: `PartialEq`/`Eq` are implemented by hand elsewhere in this file (via
// `ConstantTimeEq`) rather than derived. The derived `PartialOrd`/`Ord` compare
// the encoded bytes of `inner` lexicographically, which is not the same thing
// as ordering by numeric value (e.g. `[0x02]` sorts after `[0x01, 0x00]`).
#[derive(Clone, PartialOrd, Ord)]
pub struct Mpint {
    /// Inner big endian-serialized integer value.
    ///
    /// Holds only the integer's data bytes, without the 4-byte length prefix;
    /// an empty slice represents the value zero.
    inner: Box<[u8]>,
}
46
47impl Mpint {
48    /// Create a new multiple precision integer from the given
49    /// big endian-encoded byte slice.
50    ///
51    /// Note that this method expects a leading zero on positive integers whose
52    /// MSB is set, but does *NOT* expect a 4-byte length prefix.
53    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
54        bytes.try_into()
55    }
56
57    /// Create a new multiple precision integer from the given big endian
58    /// encoded byte slice representing a positive integer.
59    ///
60    /// The input may begin with leading zeros, which will be stripped when
61    /// converted to [`Mpint`] encoding.
62    pub fn from_positive_bytes(mut bytes: &[u8]) -> Result<Self> {
63        let mut inner = Vec::with_capacity(bytes.len());
64
65        while bytes.first().copied() == Some(0) {
66            bytes = &bytes[1..];
67        }
68
69        match bytes.first().copied() {
70            Some(n) if n >= 0x80 => inner.push(0),
71            _ => (),
72        }
73
74        inner.extend_from_slice(bytes);
75        inner.into_boxed_slice().try_into()
76    }
77
78    /// Get the big integer data encoded as big endian bytes.
79    ///
80    /// This slice will contain a leading zero if the value is positive but the
81    /// MSB is also set. Use [`Mpint::as_positive_bytes`] to ensure the number
82    /// is positive and strip the leading zero byte if it exists.
83    pub fn as_bytes(&self) -> &[u8] {
84        &self.inner
85    }
86
87    /// Get the bytes of a positive integer.
88    ///
89    /// # Returns
90    /// - `Some(bytes)` if the number is positive. The leading zero byte will be stripped.
91    /// - `None` if the value is negative
92    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
93        match self.as_bytes() {
94            [0x00, rest @ ..] => Some(rest),
95            [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
96            _ => None,
97        }
98    }
99}
100
101impl AsRef<[u8]> for Mpint {
102    fn as_ref(&self) -> &[u8] {
103        self.as_bytes()
104    }
105}
106
107impl ConstantTimeEq for Mpint {
108    fn ct_eq(&self, other: &Self) -> Choice {
109        self.as_ref().ct_eq(other.as_ref())
110    }
111}
112
// Marker impl only: the comparison itself lives in the hand-written
// `PartialEq` impl, which goes through `ConstantTimeEq::ct_eq`.
impl Eq for Mpint {}
114
115impl PartialEq for Mpint {
116    fn eq(&self, other: &Self) -> bool {
117        self.ct_eq(other).into()
118    }
119}
120
121impl Decode for Mpint {
122    type Error = Error;
123
124    fn decode(reader: &mut impl Reader) -> Result<Self> {
125        Vec::decode(reader)?.into_boxed_slice().try_into()
126    }
127}
128
129impl Encode for Mpint {
130    fn encoded_len(&self) -> encoding::Result<usize> {
131        [4, self.as_bytes().len()].checked_sum()
132    }
133
134    fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> {
135        self.as_bytes().encode(writer)?;
136        Ok(())
137    }
138}
139
140impl TryFrom<&[u8]> for Mpint {
141    type Error = Error;
142
143    fn try_from(bytes: &[u8]) -> Result<Self> {
144        Vec::from(bytes).into_boxed_slice().try_into()
145    }
146}
147
148impl TryFrom<Box<[u8]>> for Mpint {
149    type Error = Error;
150
151    fn try_from(bytes: Box<[u8]>) -> Result<Self> {
152        match &*bytes {
153            // Unnecessary leading 0
154            [0x00] => Err(Error::FormatEncoding),
155            // Unnecessary leading 0
156            [0x00, n, ..] if *n < 0x80 => Err(Error::FormatEncoding),
157            _ => Ok(Self { inner: bytes }),
158        }
159    }
160}
161
162impl Zeroize for Mpint {
163    fn zeroize(&mut self) {
164        self.inner.zeroize();
165    }
166}
167
168impl fmt::Debug for Mpint {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        write!(f, "Mpint({self:X})")
171    }
172}
173
174impl fmt::Display for Mpint {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        write!(f, "{self:X}")
177    }
178}
179
180impl fmt::LowerHex for Mpint {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        for byte in self.as_bytes() {
183            write!(f, "{byte:02x}")?;
184        }
185        Ok(())
186    }
187}
188
189impl fmt::UpperHex for Mpint {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        for byte in self.as_bytes() {
192            write!(f, "{byte:02X}")?;
193        }
194        Ok(())
195    }
196}
197
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<bigint::BigUint> for Mpint {
    type Error = Error;

    /// Owned-value convenience wrapper around the `TryFrom<&BigUint>`
    /// conversion.
    fn try_from(uint: bigint::BigUint) -> Result<Mpint> {
        let borrowed = &uint;
        Self::try_from(borrowed)
    }
}
206
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<&bigint::BigUint> for Mpint {
    type Error = Error;

    /// Serialize the unsigned integer to big endian bytes and re-encode them
    /// with [`Mpint::from_positive_bytes`].
    ///
    /// The intermediate byte buffer is held in a [`Zeroizing`] wrapper for the
    /// duration of the conversion.
    fn try_from(uint: &bigint::BigUint) -> Result<Mpint> {
        let be_bytes = Zeroizing::new(uint.to_bytes_be());
        Self::from_positive_bytes(&be_bytes)
    }
}
216
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<Mpint> for bigint::BigUint {
    type Error = Error;

    /// Owned-value convenience wrapper around the `TryFrom<&Mpint>`
    /// conversion.
    fn try_from(mpint: Mpint) -> Result<bigint::BigUint> {
        let borrowed = &mpint;
        Self::try_from(borrowed)
    }
}
225
#[cfg(any(feature = "dsa", feature = "rsa"))]
impl TryFrom<&Mpint> for bigint::BigUint {
    type Error = Error;

    /// Build an unsigned big integer from the positive bytes of `mpint`.
    ///
    /// # Errors
    /// Returns [`Error::Crypto`] when [`Mpint::as_positive_bytes`] yields
    /// `None`.
    fn try_from(mpint: &Mpint) -> Result<bigint::BigUint> {
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(bigint::BigUint::from_bytes_be(magnitude)),
            None => Err(Error::Crypto),
        }
    }
}
237
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    /// The empty string decodes to an `Mpint` with no data bytes.
    #[test]
    fn decode_0() {
        let zero = Mpint::from_bytes(b"").unwrap();
        assert!(zero.as_bytes().is_empty());
    }

    /// Encodings with a superfluous leading zero byte are refused.
    #[test]
    fn reject_extra_leading_zeroes() {
        for non_canonical in [&hex!("00")[..], &hex!("00 00")[..], &hex!("00 01")[..]] {
            assert!(Mpint::from_bytes(non_canonical).is_err());
        }
    }

    /// A positive value whose first byte has a clear MSB decodes as-is.
    #[test]
    fn decode_9a378f9b2e332a7() {
        let decoded = Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7"));
        assert!(decoded.is_ok());
    }

    /// `00 80` is accepted, and the padding byte is removed again by
    /// `as_positive_bytes`.
    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();
        let magnitude = n.as_positive_bytes().unwrap();

        // Leading zero stripped
        assert_eq!(&hex!("80"), magnitude);
    }

    /// `from_positive_bytes` tolerates leading zeros and removes them.
    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        let cases: [(&[u8], &[u8]); 3] = [
            (&hex!("00"), b""),
            (&hex!("00 00"), b""),
            (&hex!("00 01"), b"\x01"),
        ];

        for (input, expected) in cases {
            let mpint = Mpint::from_positive_bytes(input).unwrap();
            assert_eq!(AsRef::<[u8]>::as_ref(&mpint), expected);
        }
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }
}