ssh_encoding/
mpint.rs

1//! Multiple precision integer
2
3use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "bigint")]
8use crate::Uint;
9
10#[cfg(feature = "subtle")]
11use subtle::{Choice, ConstantTimeEq};
12
13#[cfg(any(feature = "bigint", feature = "zeroize"))]
14use zeroize::Zeroize;
15#[cfg(feature = "bigint")]
16use zeroize::Zeroizing;
17
/// Multiple precision integer, a.k.a. `mpint`.
///
/// Described in [RFC4251 § 5](https://datatracker.ietf.org/doc/html/rfc4251#section-5):
///
/// > Represents multiple precision integers in two's complement format,
/// > stored as a string, 8 bits per byte, MSB first.  Negative numbers
/// > have the value 1 as the most significant bit of the first byte of
/// > the data partition.  If the most significant bit would be set for
/// > a positive number, the number MUST be preceded by a zero byte.
/// > Unnecessary leading bytes with the value 0 or 255 MUST NOT be
/// > included.  The value zero MUST be stored as a string with zero
/// > bytes of data.
/// >
/// > By convention, a number that is used in modular computations in
/// > Z_n SHOULD be represented in the range 0 <= x < n.
///
/// This type stores only the data partition (big-endian two's complement,
/// *without* the 4-byte length prefix); the length prefix is added when the
/// value is encoded.
///
/// ## Examples
///
/// The "representation" column shows the full wire encoding, including the
/// 4-byte big-endian length prefix.
///
/// | value (hex)     | representation (hex)                  |
/// |-----------------|---------------------------------------|
/// | 0               | `00 00 00 00`                         |
/// | 9a378f9b2e332a7 | `00 00 00 08 09 a3 78 f9 b2 e3 32 a7` |
/// | 80              | `00 00 00 02 00 80`                   |
/// | -1234           | `00 00 00 02 ed cc`                   |
/// | -deadbeef       | `00 00 00 05 ff 21 52 41 11`          |
///
/// ## Ordering
///
/// With the `subtle` feature enabled, `Ord`/`PartialOrd` are derived and
/// therefore compare the stored bytes lexicographically. This is *not*
/// numeric order: for example `01` (1) compares greater than `00 80` (128),
/// and negative values (leading byte `>= 0x80`) compare greater than
/// positive ones.
#[cfg_attr(not(feature = "subtle"), derive(Clone))]
#[cfg_attr(feature = "subtle", derive(Clone, Ord, PartialOrd))] // TODO: constant time (Partial)`Ord`?
pub struct Mpint {
    /// Inner big endian-serialized integer value: the `mpint` data partition
    /// without its length prefix. The constructors in this module route it
    /// through the `TryFrom<Box<[u8]>>` validation before storing it.
    inner: Box<[u8]>,
}
49
50impl Mpint {
51    /// Create a new multiple precision integer from the given
52    /// big endian-encoded byte slice.
53    ///
54    /// Note that this method expects a leading zero on positive integers whose
55    /// MSB is set, but does *NOT* expect a 4-byte length prefix.
56    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
57        bytes.try_into()
58    }
59
60    /// Create a new multiple precision integer from the given big endian
61    /// encoded byte slice representing a positive integer.
62    ///
63    /// The input may begin with leading zeros, which will be stripped when
64    /// converted to [`Mpint`] encoding.
65    pub fn from_positive_bytes(mut bytes: &[u8]) -> Result<Self> {
66        let mut inner = Vec::with_capacity(bytes.len());
67
68        while bytes.first().copied() == Some(0) {
69            bytes = &bytes[1..];
70        }
71
72        match bytes.first().copied() {
73            Some(n) if n >= 0x80 => inner.push(0),
74            _ => (),
75        }
76
77        inner.extend_from_slice(bytes);
78        inner.into_boxed_slice().try_into()
79    }
80
81    /// Get the big integer data encoded as big endian bytes.
82    ///
83    /// This slice will contain a leading zero if the value is positive but the
84    /// MSB is also set. Use [`Mpint::as_positive_bytes`] to ensure the number
85    /// is positive and strip the leading zero byte if it exists.
86    pub fn as_bytes(&self) -> &[u8] {
87        &self.inner
88    }
89
90    /// Get the bytes of a positive integer.
91    ///
92    /// # Returns
93    /// - `Some(bytes)` if the number is positive. The leading zero byte will be stripped.
94    /// - `None` if the value is negative
95    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
96        match self.as_bytes() {
97            [0x00, rest @ ..] => Some(rest),
98            [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
99            _ => None,
100        }
101    }
102
103    /// Is this [`Mpint`] positive?
104    pub fn is_positive(&self) -> bool {
105        self.as_positive_bytes().is_some()
106    }
107}
108
109impl AsRef<[u8]> for Mpint {
110    fn as_ref(&self) -> &[u8] {
111        self.as_bytes()
112    }
113}
114
#[cfg(feature = "subtle")]
impl ConstantTimeEq for Mpint {
    /// Compare the two encodings with `subtle`'s slice comparison.
    ///
    /// Per `subtle`'s documentation, the slice comparison short-circuits when
    /// the lengths differ, so only the contents (not the lengths) are
    /// compared in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
121
#[cfg(feature = "subtle")]
impl PartialEq for Mpint {
    /// Equality is defined via [`ConstantTimeEq::ct_eq`] on the encoded bytes.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

#[cfg(feature = "subtle")]
impl Eq for Mpint {}
131
132impl Decode for Mpint {
133    type Error = Error;
134
135    fn decode(reader: &mut impl Reader) -> Result<Self> {
136        Vec::decode(reader)?.into_boxed_slice().try_into()
137    }
138}
139
140impl Encode for Mpint {
141    fn encoded_len(&self) -> Result<usize> {
142        [4, self.as_bytes().len()].checked_sum()
143    }
144
145    fn encode(&self, writer: &mut impl Writer) -> Result<()> {
146        self.as_bytes().encode(writer)?;
147        Ok(())
148    }
149}
150
151impl TryFrom<&[u8]> for Mpint {
152    type Error = Error;
153
154    fn try_from(bytes: &[u8]) -> Result<Self> {
155        Vec::from(bytes).into_boxed_slice().try_into()
156    }
157}
158
159impl TryFrom<Box<[u8]>> for Mpint {
160    type Error = Error;
161
162    fn try_from(bytes: Box<[u8]>) -> Result<Self> {
163        match &*bytes {
164            // Unnecessary leading 0
165            [0x00] => Err(Error::MpintEncoding),
166            // Unnecessary leading 0
167            [0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
168            _ => Ok(Self { inner: bytes }),
169        }
170    }
171}
172
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
    /// Overwrite the stored data bytes with zeros.
    fn zeroize(&mut self) {
        let Self { inner } = self;
        inner.zeroize();
    }
}
179
180impl fmt::Debug for Mpint {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        write!(f, "Mpint({self:X})")
183    }
184}
185
186impl fmt::Display for Mpint {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(f, "{self:X}")
189    }
190}
191
192impl fmt::LowerHex for Mpint {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        for byte in self.as_bytes() {
195            write!(f, "{byte:02x}")?;
196        }
197        Ok(())
198    }
199}
200
201impl fmt::UpperHex for Mpint {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        for byte in self.as_bytes() {
204            write!(f, "{byte:02X}")?;
205        }
206        Ok(())
207    }
208}
209
#[cfg(feature = "bigint")]
impl TryFrom<Uint> for Mpint {
    type Error = Error;

    /// Owned variant: delegates to the `TryFrom<&Uint>` conversion.
    fn try_from(uint: Uint) -> Result<Mpint> {
        Self::try_from(&uint)
    }
}
218
#[cfg(feature = "bigint")]
impl TryFrom<&Uint> for Mpint {
    type Error = Error;

    /// Serialize `uint` big-endian and build a non-negative [`Mpint`] from it
    /// via [`Mpint::from_positive_bytes`], which strips any leading zero bytes.
    fn try_from(uint: &Uint) -> Result<Mpint> {
        // Wrapped in `Zeroizing` so the temporary serialized copy is wiped on drop.
        let be_bytes = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&be_bytes)
    }
}
228
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    /// Owned variant: delegates to the `TryFrom<&Mpint>` conversion.
    fn try_from(mpint: Mpint) -> Result<Uint> {
        Self::try_from(&mpint)
    }
}
238#[cfg(feature = "bigint")]
239impl TryFrom<&Mpint> for Uint {
240    type Error = Error;
241
242    fn try_from(mpint: &Mpint) -> Result<Uint> {
243        // TODO(tarcieri): enforce a maximum size?
244        let bytes = mpint.as_positive_bytes().ok_or(Error::MpintEncoding)?;
245        Ok(Uint::from_be_slice_vartime(bytes))
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    #[test]
    fn decode_0() {
        let n = Mpint::from_bytes(b"").unwrap();
        assert_eq!(b"", n.as_bytes());
        // Zero is not strictly positive.
        assert!(!n.is_positive());
    }

    #[test]
    fn reject_extra_leading_zeroes() {
        assert!(Mpint::from_bytes(&hex!("00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 00")).is_err());
        assert!(Mpint::from_bytes(&hex!("00 01")).is_err());
    }

    #[test]
    fn decode_9a378f9b2e332a7() {
        let bytes = hex!("09 a3 78 f9 b2 e3 32 a7");
        let n = Mpint::from_bytes(&bytes).unwrap();

        // Stored verbatim: MSB is clear, so no padding byte is involved.
        assert_eq!(&bytes, n.as_bytes());
        assert_eq!(Some(&bytes[..]), n.as_positive_bytes());
        assert!(n.is_positive());
    }

    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();

        // Leading zero stripped
        assert_eq!(&hex!("80"), n.as_positive_bytes().unwrap());
        assert!(n.is_positive());
    }

    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 00")).unwrap().as_ref(),
            b""
        );
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 01")).unwrap().as_ref(),
            b"\x01"
        );
    }

    #[test]
    fn from_positive_bytes_pads_msb() {
        // A set MSB gets a 0x00 prefix so the value is not read as negative.
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("80")).unwrap().as_bytes(),
            &hex!("00 80")
        );
        // Redundant zeros are stripped before the padding decision.
        assert_eq!(
            Mpint::from_positive_bytes(&hex!("00 00 ff")).unwrap().as_bytes(),
            &hex!("00 ff")
        );
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(n.as_positive_bytes().is_none());
        assert!(!n.is_positive());
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(n.as_positive_bytes().is_none());
        assert!(!n.is_positive());
    }
}