Skip to main content

ssh_encoding/
mpint.rs

1//! Multiple precision integer.
2
3use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "bigint")]
8use crate::Uint;
9
10#[cfg(feature = "ctutils")]
11use ctutils::{Choice, CtEq};
12
13#[cfg(any(feature = "bigint", feature = "zeroize"))]
14use zeroize::Zeroize;
15#[cfg(feature = "bigint")]
16use zeroize::Zeroizing;
17
/// Multiple precision integer, a.k.a. `mpint`.
///
/// Described in [RFC4251 § 5](https://datatracker.ietf.org/doc/html/rfc4251#section-5):
///
/// > Represents multiple precision integers in two's complement format,
/// > stored as a string, 8 bits per byte, MSB first.  Negative numbers
/// > have the value 1 as the most significant bit of the first byte of
/// > the data partition.  If the most significant bit would be set for
/// > a positive number, the number MUST be preceded by a zero byte.
/// > Unnecessary leading bytes with the value 0 or 255 MUST NOT be
/// > included.  The value zero MUST be stored as a string with zero
/// > bytes of data.
/// >
/// > By convention, a number that is used in modular computations in
/// > Z_n SHOULD be represented in the range 0 <= x < n.
///
/// ## Examples
///
/// | value (hex)     | representation (hex)                  |
/// |-----------------|---------------------------------------|
/// | 0               | `00 00 00 00`                         |
/// | 9a378f9b2e332a7 | `00 00 00 08 09 a3 78 f9 b2 e3 32 a7` |
/// | 80              | `00 00 00 02 00 80`                   |
/// | -1234           | `00 00 00 02 ed cc`                   |
/// | -deadbeef       | `00 00 00 05 ff 21 52 41 11`          |
// `Clone` is derived unconditionally. `Ord`/`PartialOrd` are only derived together with the
// `ctutils` feature, which is also what provides the `Eq`/`PartialEq` impls they depend on.
//
// NOTE: the derived ordering compares the serialized bytes lexicographically, so it does not
// follow numeric order for values of different lengths or signs.
// TODO: constant time (Partial)`Ord`?
#[derive(Clone)]
#[cfg_attr(feature = "ctutils", derive(Ord, PartialOrd))]
pub struct Mpint {
    /// Big endian serialization of the integer (data bytes only, no length prefix).
    inner: Box<[u8]>,
}
49
50impl Mpint {
51    /// Create a new multiple precision integer from the given big endian-encoded byte slice.
52    ///
53    /// Note that this method expects a leading zero on positive integers whose MSB is set, but does
54    /// *NOT* expect a 4-byte length prefix.
55    ///
56    /// # Errors
57    /// Returns [`Error::MpintEncoding`] in the event of an unnecessary leading `0`.
58    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
59        bytes.try_into()
60    }
61
62    /// Create a new multiple precision integer from the given big endian encoded byte slice
63    /// representing a positive integer.
64    ///
65    /// The input may begin with leading zeros, which will be stripped when converted to [`Mpint`]
66    /// encoding.
67    #[must_use]
68    pub fn from_positive_bytes(mut bytes: &[u8]) -> Self {
69        // Strip leading zeros
70        while bytes.first().copied() == Some(0) {
71            bytes = &bytes[1..];
72        }
73
74        // Add a leading zero to the output if necessary
75        let inner = match bytes.first().copied() {
76            Some(n) if n >= 0x80 => {
77                let mut inner = Vec::with_capacity(bytes.len().saturating_add(1));
78                inner.push(0);
79                inner.extend_from_slice(bytes);
80                inner
81            }
82            _ => Vec::from(bytes),
83        };
84
85        Self {
86            inner: inner.into_boxed_slice(),
87        }
88    }
89
90    /// Get the big integer data encoded as big endian bytes.
91    ///
92    /// This slice will contain a leading zero if the value is positive but the
93    /// MSB is also set. Use [`Mpint::as_positive_bytes`] to ensure the number
94    /// is positive and strip the leading zero byte if it exists.
95    #[must_use]
96    pub fn as_bytes(&self) -> &[u8] {
97        &self.inner
98    }
99
100    /// Get the bytes of a positive integer.
101    ///
102    /// # Returns
103    /// - `Some(bytes)` if the number is positive. The leading zero byte will be stripped.
104    /// - `None` if the value is negative
105    #[must_use]
106    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
107        match self.as_bytes() {
108            [0x00, rest @ ..] => Some(rest),
109            [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
110            _ => None,
111        }
112    }
113
114    /// Is this [`Mpint`] positive?
115    #[must_use]
116    pub fn is_positive(&self) -> bool {
117        self.as_positive_bytes().is_some()
118    }
119}
120
121impl AsRef<[u8]> for Mpint {
122    fn as_ref(&self) -> &[u8] {
123        self.as_bytes()
124    }
125}
126
/// Equality check that delegates to the `CtEq` implementation for byte slices, applied to the
/// stored bytes of both values.
// NOTE(review): how differing lengths are treated is decided by `ctutils` — confirm there.
#[cfg(feature = "ctutils")]
impl CtEq for Mpint {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
133
// `==` is only available with the `ctutils` feature and is routed through `CtEq::ct_eq`.
#[cfg(feature = "ctutils")]
impl Eq for Mpint {}

#[cfg(feature = "ctutils")]
impl PartialEq for Mpint {
    /// Compare via [`CtEq::ct_eq`] and convert the resulting `Choice` into a `bool`.
    fn eq(&self, other: &Self) -> bool {
        CtEq::ct_eq(self, other).into()
    }
}
143
144impl Decode for Mpint {
145    type Error = Error;
146
147    fn decode(reader: &mut impl Reader) -> Result<Self> {
148        Vec::decode(reader)?.into_boxed_slice().try_into()
149    }
150}
151
152impl Encode for Mpint {
153    fn encoded_len(&self) -> Result<usize> {
154        [4, self.as_bytes().len()].checked_sum()
155    }
156
157    fn encode(&self, writer: &mut impl Writer) -> Result<()> {
158        self.as_bytes().encode(writer)?;
159        Ok(())
160    }
161}
162
163impl TryFrom<&[u8]> for Mpint {
164    type Error = Error;
165
166    fn try_from(bytes: &[u8]) -> Result<Self> {
167        Vec::from(bytes).into_boxed_slice().try_into()
168    }
169}
170
171impl TryFrom<Box<[u8]>> for Mpint {
172    type Error = Error;
173
174    fn try_from(bytes: Box<[u8]>) -> Result<Self> {
175        match &*bytes {
176            // Unnecessary leading 0
177            [0x00] => Err(Error::MpintEncoding),
178            // Unnecessary leading 0
179            [0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
180            _ => Ok(Self { inner: bytes }),
181        }
182    }
183}
184
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
    /// Zeroize the stored bytes by calling `zeroize` on the inner buffer.
    fn zeroize(&mut self) {
        let Self { inner } = self;
        inner.zeroize();
    }
}
191
192impl fmt::Debug for Mpint {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        write!(f, "Mpint({self:X})")
195    }
196}
197
198impl fmt::Display for Mpint {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        write!(f, "{self:X}")
201    }
202}
203
204impl fmt::LowerHex for Mpint {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        for byte in self.as_bytes() {
207            write!(f, "{byte:02x}")?;
208        }
209        Ok(())
210    }
211}
212
213impl fmt::UpperHex for Mpint {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        for byte in self.as_bytes() {
216            write!(f, "{byte:02X}")?;
217        }
218        Ok(())
219    }
220}
221
#[cfg(feature = "bigint")]
impl From<&Uint> for Mpint {
    /// Serialize `uint` as big endian bytes and normalize them with
    /// [`Mpint::from_positive_bytes`].
    fn from(uint: &Uint) -> Mpint {
        // The intermediate serialization is held in `Zeroizing` so it is wiped when dropped.
        let be_bytes = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&be_bytes)
    }
}
229
#[cfg(feature = "bigint")]
impl From<Uint> for Mpint {
    /// Owned counterpart of the `From<&Uint>` conversion; forwards to it.
    fn from(uint: Uint) -> Mpint {
        Self::from(&uint)
    }
}
236
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    /// Owned counterpart of the `TryFrom<&Mpint>` conversion; forwards to it.
    fn try_from(mpint: Mpint) -> Result<Uint> {
        Self::try_from(&mpint)
    }
}
245
#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for Uint {
    type Error = Error;

    /// Convert the bytes returned by [`Mpint::as_positive_bytes`] into a `Uint`.
    ///
    /// # Errors
    /// Returns [`Error::MpintEncoding`] whenever [`Mpint::as_positive_bytes`] returns `None`.
    fn try_from(mpint: &Mpint) -> Result<Uint> {
        // TODO(tarcieri): enforce a maximum size?
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(Uint::from_be_slice_vartime(magnitude)),
            None => Err(Error::MpintEncoding),
        }
    }
}
256
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    /// Zero is the empty string.
    #[test]
    fn decode_0() {
        let n = Mpint::from_bytes(b"").unwrap();
        assert_eq!(b"", n.as_bytes());
    }

    #[test]
    fn reject_extra_leading_zeroes() {
        assert!(Mpint::from_bytes(&hex!("00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 01")).is_err());
    }

    /// A positive value with a clear MSB is stored and returned unchanged.
    #[test]
    fn decode_9a378f9b2e332a7() {
        let n = Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7")).unwrap();
        assert_eq!(&hex!("09 a3 78 f9 b2 e3 32 a7"), n.as_bytes());
        assert_eq!(
            &hex!("09 a3 78 f9 b2 e3 32 a7"),
            n.as_positive_bytes().unwrap()
        );
        assert!(n.is_positive());
    }

    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();

        // The pad byte is kept in the raw encoding...
        assert_eq!(&hex!("00 80"), n.as_bytes());

        // ...and stripped from the positive view.
        assert_eq!(&hex!("80"), n.as_positive_bytes().unwrap());
        assert!(n.is_positive());
    }

    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        assert_eq!(Mpint::from_positive_bytes(&hex!("00")).as_ref(), b"");
        assert_eq!(Mpint::from_positive_bytes(&hex!("00 00")).as_ref(), b"");
        assert_eq!(Mpint::from_positive_bytes(&hex!("00 01")).as_ref(), b"\x01");
    }

    /// A magnitude whose MSB is set gets exactly one pad byte, however many zeros preceded it.
    #[test]
    fn from_positive_bytes_pads_msb() {
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("80")).as_ref(),
            b"\x00\x80"
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 00 ff 01")).as_ref(),
            b"\x00\xff\x01"
        );
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(n.as_positive_bytes().is_none());
        assert!(!n.is_positive());
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(n.as_positive_bytes().is_none());
        assert!(!n.is_positive());
    }
}