sunset/
sshwire.rs

1//! SSH wire format reading/writing.
2//!
3//! Used in conjunction with [`sunset_sshwire_derive`] and the [`packet`](crate::packets) format
4//! definitions.
5//!
6//! SSH wire format is described in [RFC4251](https://tools.ietf.org/html/rfc4251) SSH Architecture
7
8#[allow(unused_imports)]
9use {
10    crate::error::{Error, Result, TrapBug},
11    log::{debug, error, info, log, trace, warn},
12};
13
14use core::convert::AsRef;
15use core::fmt::{self, Debug, Display};
16use core::str::FromStr;
17use digest::Output;
18use pretty_hex::PrettyHex;
19use snafu::{prelude::*, Location};
20
21use ascii::{AsAsciiStr, AsciiChar, AsciiStr};
22
23#[cfg(feature = "arbitrary")]
24use arbitrary::{Arbitrary, Unstructured};
25
26use digest::Digest;
27
28use crate::*;
29use packets::{Packet, ParseContext};
30
31/// A generic destination for serializing, used similarly to `serde::Serializer`
32pub trait SSHSink {
33    fn push(&mut self, v: &[u8]) -> WireResult<()>;
34}
35
36/// A generic source for a packet, used similarly to `serde::Deserializer`
37pub trait SSHSource<'de> {
38    fn take(&mut self, len: usize) -> WireResult<&'de [u8]>;
39    fn remaining(&self) -> usize;
40    fn ctx(&mut self) -> &mut ParseContext;
41}
42
43/// Encodes the type in SSH wire format
44pub trait SSHEncode {
45    /// Encode data
46    ///
47    /// The state of the `SSHSink` is undefined after an error is returned, data may
48    /// have been partially encoded.
49    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()>;
50}
51
52/// For enums with an externally provided name
53pub trait SSHEncodeEnum {
54    /// Returns the current variant, used for encoding parent structs.
55    /// Fails if it is Unknown
56    fn variant_name(&self) -> WireResult<&'static str>;
57}
58
59/// Decodes `struct` and `enum`s without an externally provided enum name
60pub trait SSHDecode<'de>: Sized {
61    /// Decode data
62    ///
63    /// The state of the `SSHSource` is undefined after an error is returned, data may
64    /// have been partially consumed.
65    fn dec<S>(s: &mut S) -> WireResult<Self>
66    where
67        S: SSHSource<'de>;
68}
69
70/// Decodes enums with an externally provided name
71pub trait SSHDecodeEnum<'de>: Sized {
72    /// `var` is the variant name to decode, as raw bytes off the wire.
73    fn dec_enum<S>(s: &mut S, var: &'de [u8]) -> WireResult<Self>
74    where
75        S: SSHSource<'de>;
76}
77
/// The subset of [`Error`] that `SSHEncode` and `SSHDecode` can produce.
///
/// The size of this enum has a large effect on compiled code size,
/// so unused elements are avoided.
#[derive(Debug)]
pub enum WireError {
    NoRoom,
    RanOut,
    BadString,
    BadName,
    UnknownVariant,
    PacketWrong,
    SSHProto,
    BadKeyFormat,
    UnknownPacket { number: u8 },
}
102
103impl From<WireError> for Error {
104    fn from(w: WireError) -> Self {
105        match w {
106            WireError::NoRoom => error::NoRoom.build(),
107            WireError::RanOut => error::RanOut.build(),
108            WireError::BadString => Error::BadString,
109            WireError::BadName => Error::BadName,
110            WireError::SSHProto => error::SSHProto.build(),
111            WireError::PacketWrong => error::PacketWrong.build(),
112            WireError::BadKeyFormat => Error::BadKeyFormat,
113            WireError::UnknownVariant => Error::bug_err_msg("Can't encode Unknown"),
114            WireError::UnknownPacket { number } => Error::UnknownPacket { number },
115        }
116    }
117}
118
119pub type WireResult<T> = core::result::Result<T, WireError>;
120
121///////////////////////////////////////////////
122
123/// Parses a [`Packet`] from a borrowed `&[u8]` byte buffer.
124pub fn packet_from_bytes<'a>(b: &'a [u8], ctx: &ParseContext) -> Result<Packet<'a>> {
125    let ctx = ParseContext { seen_unknown: false, ..ctx.clone() };
126    let mut s = DecodeBytes { input: b, parse_ctx: ctx };
127    let p = Packet::dec(&mut s)?;
128
129    if s.input.len() != 0 && !s.ctx().seen_unknown {
130        // No length check if the packet had an unknown variant
131        // - it skipped parsing the remainder of the packet.
132        Err(Error::WrongPacketLength)
133    } else {
134        Ok(p)
135    }
136}
137
138pub fn read_ssh<'a, T: SSHDecode<'a>>(
139    b: &'a [u8],
140    ctx: Option<ParseContext>,
141) -> Result<T> {
142    let mut s = DecodeBytes { input: b, parse_ctx: ctx.unwrap_or_default() };
143    Ok(T::dec(&mut s)?)
144}
145
146pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize> {
147    let mut s = EncodeBytes { target };
148    value.enc(&mut s)?;
149    let end_len = s.target.len();
150    debug_assert!(target.len() >= end_len);
151    Ok(target.len() - end_len)
152}
153
/// Appends the encoding of `value` to the end of `target`.
///
/// The vector is first grown by the encoded length, then the new
/// region is filled in.
#[cfg(feature = "std")]
pub fn ssh_push_vec(target: &mut Vec<u8>, value: &dyn SSHEncode) -> Result<()> {
    let start = target.len();
    let enc_len = length_enc(value)? as usize;
    target.resize(start + enc_len, 0);
    write_ssh(&mut target[start..], value).map(|_| ())
}
162
163/// Hashes the SSH wire format representation of `value`, with a `u32` length prefix.
164pub fn hash_ser_length(
165    hash_ctx: &mut impl SSHWireDigestUpdate,
166    value: &dyn SSHEncode,
167) -> Result<()> {
168    let len: u32 = length_enc(value)?;
169    hash_ctx.digest_update(&len.to_be_bytes());
170    hash_ser(hash_ctx, value)
171}
172
173/// Hashes the SSH wire format representation of `value`
174///
175/// Will only fail if `value.enc()` can return an error.
176pub fn hash_ser(
177    hash_ctx: &mut impl SSHWireDigestUpdate,
178    value: &dyn SSHEncode,
179) -> Result<()> {
180    let mut s = EncodeHash { hash_ctx };
181    value.enc(&mut s)?;
182    Ok(())
183}
184
185/// Returns `WireError::NoRoom` if larger than `u32`
186pub fn length_enc(value: &dyn SSHEncode) -> WireResult<u32> {
187    let mut s = EncodeLen { pos: 0 };
188    value.enc(&mut s)?;
189    s.pos.try_into().map_err(|_| WireError::NoRoom)
190}
191
192struct EncodeBytes<'a> {
193    target: &'a mut [u8],
194}
195
196impl<'a> SSHSink for EncodeBytes<'a> {
197    fn push(&mut self, v: &[u8]) -> WireResult<()> {
198        if v.len() > self.target.len() {
199            return Err(WireError::NoRoom);
200        }
201        // keep the borrow checker happy
202        let tmp = core::mem::replace(&mut self.target, &mut []);
203        let t;
204        (t, self.target) = tmp.split_at_mut(v.len());
205        t.copy_from_slice(v);
206        Ok(())
207    }
208}
209
210struct EncodeLen {
211    pos: usize,
212}
213
214impl SSHSink for EncodeLen {
215    fn push(&mut self, v: &[u8]) -> WireResult<()> {
216        self.pos = self.pos.checked_add(v.len()).ok_or(WireError::NoRoom)?;
217        Ok(())
218    }
219}
220
221struct EncodeHash<'a> {
222    hash_ctx: &'a mut dyn SSHWireDigestUpdate,
223}
224
225impl SSHSink for EncodeHash<'_> {
226    fn push(&mut self, v: &[u8]) -> WireResult<()> {
227        self.hash_ctx.digest_update(v);
228        Ok(())
229    }
230}
231
232struct DecodeBytes<'a> {
233    input: &'a [u8],
234    parse_ctx: ParseContext,
235}
236
237impl<'de> SSHSource<'de> for DecodeBytes<'de> {
238    fn take(&mut self, len: usize) -> WireResult<&'de [u8]> {
239        if len > self.input.len() {
240            return Err(WireError::RanOut);
241        }
242        let t;
243        (t, self.input) = self.input.split_at(len);
244        Ok(t)
245    }
246
247    fn remaining(&self) -> usize {
248        self.input.len()
249    }
250
251    fn ctx(&mut self) -> &mut ParseContext {
252        &mut self.parse_ctx
253    }
254}
255
256// Hashes a slice to be treated as a mpint. Has u32 length prefix
257// and an extra 0x00 byte if the MSB is set.
258pub fn hash_mpint(hash_ctx: &mut dyn SSHWireDigestUpdate, m: &[u8]) {
259    let pad = !m.is_empty() && (m[0] & 0x80) != 0;
260    let l = m.len() as u32 + pad as u32;
261    hash_ctx.digest_update(&l.to_be_bytes());
262    if pad {
263        hash_ctx.digest_update(&[0x00]);
264    }
265    hash_ctx.digest_update(m);
266}
267
268///////////////////////////////////////////////
269
270/// A SSH style binary string. Serialized as `u32` length followed by the bytes
271/// of the slice.
272/// Application API
273#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
274#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
275pub struct BinString<'a>(pub &'a [u8]);
276
277impl AsRef<[u8]> for BinString<'_> {
278    fn as_ref(&self) -> &[u8] {
279        self.0
280    }
281}
282
283impl Debug for BinString<'_> {
284    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
285        write!(f, "BinString(len={})", self.0.len())
286    }
287}
288
289impl SSHEncode for BinString<'_> {
290    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
291        (self.0.len() as u32).enc(s)?;
292        self.0.enc(s)
293    }
294}
295
296impl<'de> SSHDecode<'de> for BinString<'de> {
297    fn dec<S>(s: &mut S) -> WireResult<Self>
298    where
299        S: sshwire::SSHSource<'de>,
300    {
301        let len = u32::dec(s)? as usize;
302        Ok(BinString(s.take(len)?))
303    }
304}
305
306impl<const N: usize> SSHEncode for heapless::String<N> {
307    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
308        self.as_str().enc(s)
309    }
310}
311
312/// A text string that may be presented to a user or used
313/// for things such as a password, username, exec command, TCP hostname, etc.
314///
315/// The SSH protocol defines it to be UTF-8, though
316/// in some applications it could be treated as ASCII-only.
317/// Sunset treats it as an opaque `&[u8]`, leaving
318/// interpretation to the application.
319///
320/// Note that SSH protocol identifiers in [`Packet`]
321/// are `&str` rather than `TextString`, and always defined as ASCII. For
322/// example `"publickey"`, `"ssh-ed25519"`.
323///
324/// Application API
325#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
326#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
327pub struct TextString<'a>(pub &'a [u8]);
328
329impl<'a> TextString<'a> {
330    /// Returns the UTF-8 decoded string, using [`core::str::from_utf8`]
331    ///
332    /// Don't call this if you are avoiding including UTF-8 routines in
333    /// the binary.
334    pub fn as_str(&self) -> Result<&'a str> {
335        core::str::from_utf8(self.0).map_err(|_| Error::BadString)
336    }
337
338    pub fn as_ascii(&self) -> Result<&'a str> {
339        self.0.as_ascii_str().map_err(|_| Error::BadString).map(|s| s.as_str())
340    }
341}
342
343impl<'a> AsRef<[u8]> for TextString<'a> {
344    fn as_ref(&self) -> &'a [u8] {
345        self.0
346    }
347}
348
349impl<'a> From<&'a str> for TextString<'a> {
350    fn from(s: &'a str) -> Self {
351        TextString(s.as_bytes())
352    }
353}
354
355impl Debug for TextString<'_> {
356    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
357        let s = core::str::from_utf8(self.0);
358        if let Ok(s) = s {
359            write!(f, "TextString(\"{}\")", s.escape_default())
360        } else {
361            write!(f, "TextString(not utf8!, {:#?})", self.0.hex_dump())
362        }
363    }
364}
365
366impl Display for TextString<'_> {
367    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
368        let s = core::str::from_utf8(self.0);
369        if let Ok(s) = s {
370            write!(f, "\"{}\"", s.escape_default())
371        } else {
372            write!(f, "{:?}", self)
373        }
374    }
375}
376
377impl SSHEncode for TextString<'_> {
378    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
379        (self.0.len() as u32).enc(s)?;
380        self.0.enc(s)
381    }
382}
383
384impl<'de> SSHDecode<'de> for TextString<'de> {
385    fn dec<S>(s: &mut S) -> WireResult<Self>
386    where
387        S: sshwire::SSHSource<'de>,
388    {
389        let len = u32::dec(s)? as usize;
390        Ok(TextString(s.take(len)?))
391    }
392}
393
/// Wraps a data structure `B` that is serialized with a `u32` length prefix,
/// such as a public key blob.
#[derive(PartialEq, Clone)]
pub struct Blob<B>(pub B);

#[cfg(feature = "arbitrary")]
impl<'arb, B: Arbitrary<'arb>> Arbitrary<'arb> for Blob<B> {
    fn arbitrary(u: &mut Unstructured<'arb>) -> arbitrary::Result<Self> {
        let inner = B::arbitrary(u)?;
        Ok(Blob(inner))
    }
}

impl<B> AsRef<B> for Blob<B> {
    fn as_ref(&self) -> &B {
        let Blob(inner) = self;
        inner
    }
}
410
411impl<T: SSHEncode> SSHEncode for &T {
412    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
413        (*self).enc(s)
414    }
415}
416
417impl<B: SSHEncode + Debug> Debug for Blob<B> {
418    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
419        if let Ok(len) = sshwire::length_enc(&self.0) {
420            write!(f, "Blob(len={len}, {:?})", self.0)
421        } else {
422            write!(f, "Blob(len>u32, {:?})", self.0)
423        }
424    }
425}
426
427impl<B: SSHEncode> SSHEncode for Blob<B> {
428    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
429        let len: u32 = sshwire::length_enc(&self.0)?;
430        len.enc(s)?;
431        self.0.enc(s)
432    }
433}
434
435impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob<B> {
436    fn dec<S>(s: &mut S) -> WireResult<Self>
437    where
438        S: sshwire::SSHSource<'de>,
439    {
440        let len = u32::dec(s)? as usize;
441        let rem1 = s.remaining();
442        let inner = SSHDecode::dec(s)?;
443        let rem2 = s.remaining();
444
445        // Sanity check the length matched
446        let used_len = rem1 - rem2;
447        if used_len != len {
448            if s.ctx().seen_unknown {
449                // Skip over unconsumed bytes in the blob.
450                // This can occur with Unknown variants
451                let extra = len.checked_sub(used_len).ok_or(WireError::SSHProto)?;
452                s.take(extra)?;
453            } else {
454                trace!(
455                    "SSH blob length differs. \
456                    Expected {} bytes, got {} remaining {}, {}",
457                    len,
458                    used_len,
459                    rem1,
460                    rem2
461                );
462                return Err(WireError::SSHProto);
463            }
464        }
465        Ok(Blob(inner))
466    }
467}
468
469///////////////////////////////////////////////
470
471impl SSHEncode for u8 {
472    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
473        s.push(&[*self])
474    }
475}
476
477impl SSHEncode for bool {
478    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
479        (*self as u8).enc(s)
480    }
481}
482
483impl SSHEncode for u32 {
484    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
485        s.push(&self.to_be_bytes())
486    }
487}
488
489// no length prefix
490impl SSHEncode for &[u8] {
491    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
492        // data
493        s.push(self)
494    }
495}
496
497// no length prefix
498impl<const N: usize> SSHEncode for [u8; N] {
499    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
500        s.push(self.as_slice())
501    }
502}
503
504impl SSHEncode for &str {
505    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
506        let v = self.as_bytes();
507        // length prefix
508        (v.len() as u32).enc(s)?;
509        s.push(v)
510    }
511}
512
513impl<T: SSHEncode> SSHEncode for Option<T> {
514    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
515        if let Some(t) = self.as_ref() {
516            t.enc(s)?;
517        }
518        Ok(())
519    }
520}
521
522impl SSHEncode for &AsciiStr {
523    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
524        let v = self.as_bytes();
525        BinString(v).enc(s)
526    }
527}
528
529impl<'de> SSHDecode<'de> for bool {
530    fn dec<S>(s: &mut S) -> WireResult<Self>
531    where
532        S: SSHSource<'de>,
533    {
534        Ok(u8::dec(s)? != 0)
535    }
536}
537
538impl<'de> SSHDecode<'de> for u8 {
539    fn dec<S>(s: &mut S) -> WireResult<Self>
540    where
541        S: SSHSource<'de>,
542    {
543        let t = s.take(core::mem::size_of::<u8>())?;
544        Ok(u8::from_be_bytes(t.try_into().unwrap()))
545    }
546}
547
548impl<'de> SSHDecode<'de> for u32 {
549    fn dec<S>(s: &mut S) -> WireResult<Self>
550    where
551        S: SSHSource<'de>,
552    {
553        let t = s.take(core::mem::size_of::<u32>())?;
554        Ok(u32::from_be_bytes(t.try_into().unwrap()))
555    }
556}
557
558/// Decodes a SSH name string. Must be ASCII
559/// without control characters. RFC4251 section 6.
560pub fn try_as_ascii(t: &[u8]) -> WireResult<&AsciiStr> {
561    let n = t.as_ascii_str().map_err(|_| WireError::BadName)?;
562    if n.chars().any(|ch| ch.is_ascii_control() || ch == AsciiChar::DEL) {
563        return Err(WireError::BadName);
564    }
565    Ok(n)
566}
567
568pub fn try_as_ascii_str(t: &[u8]) -> WireResult<&str> {
569    try_as_ascii(t).map(AsciiStr::as_str)
570}
571
572impl<'de: 'a, 'a> SSHDecode<'de> for &'a str {
573    fn dec<S>(s: &mut S) -> WireResult<Self>
574    where
575        S: SSHSource<'de>,
576    {
577        let len = u32::dec(s)?;
578        let t = s.take(len as usize)?;
579        try_as_ascii_str(t)
580    }
581}
582
583impl<'de: 'a, 'a> SSHDecode<'de> for &'de AsciiStr {
584    fn dec<S>(s: &mut S) -> WireResult<&'de AsciiStr>
585    where
586        S: SSHSource<'de>,
587    {
588        let b: BinString = SSHDecode::dec(s)?;
589        try_as_ascii(b.0)
590    }
591}
592
593impl<'de, const N: usize> SSHDecode<'de> for &'de [u8; N] {
594    fn dec<S>(s: &mut S) -> WireResult<Self>
595    where
596        S: SSHSource<'de>,
597    {
598        // OK unwrap: take() fails if the length is short
599        Ok(s.take(N)?.try_into().unwrap())
600    }
601}
602
603impl<'de, const N: usize> SSHDecode<'de> for [u8; N] {
604    fn dec<S>(s: &mut S) -> WireResult<Self>
605    where
606        S: SSHSource<'de>,
607    {
608        // OK unwrap: take() fails if the length is short
609        Ok(s.take(N)?.try_into().unwrap())
610    }
611}
612
613impl<'de, const N: usize> SSHDecode<'de> for heapless::String<N> {
614    fn dec<S>(s: &mut S) -> WireResult<Self>
615    where
616        S: SSHSource<'de>,
617    {
618        heapless::String::from_str(SSHDecode::dec(s)?).map_err(|_| WireError::NoRoom)
619    }
620}
621
622/// Like `digest::DynDigest` but simpler.
623///
624/// Doesn't have any optional methods that depend on `alloc`.
625pub trait SSHWireDigestUpdate {
626    fn digest_update(&mut self, data: &[u8]);
627}
628
629impl SSHWireDigestUpdate for sha2::Sha256 {
630    fn digest_update(&mut self, data: &[u8]) {
631        self.update(data)
632    }
633}
634
635impl SSHWireDigestUpdate for sha2::Sha512 {
636    fn digest_update(&mut self, data: &[u8]) {
637        self.update(data)
638    }
639}
640
641#[cfg(feature = "rsa")]
/// Returns whether the most significant bit of the first byte of `b` is set.
///
/// Returns `false` for an empty slice.
fn top_bit_set(b: &[u8]) -> bool {
    match b.first() {
        Some(first) => first & 0x80 != 0,
        None => false,
    }
}
645
/// Encoded as a RFC4251 mpint: a `u32` length prefix, followed by the
/// big endian bytes of the value.
#[cfg(feature = "rsa")]
impl SSHEncode for rsa::BigUint {
    /// Fails with `WireError::NoRoom` if the length doesn't fit in a `u32`.
    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
        let bytes = self.to_bytes_be();
        // to_bytes_be() returns a single zero byte for the value zero,
        // but rfc4251 requires zero to be stored with no data bytes.
        let b: &[u8] = match bytes.as_slice() {
            [0] => &[],
            other => other,
        };

        // rfc4251 mpint, need a leading zero byte if top bit is set
        let pad = top_bit_set(b);
        let len: u32 = (b.len() + usize::from(pad))
            .try_into()
            .map_err(|_| WireError::NoRoom)?;
        len.enc(s)?;

        if pad {
            0u8.enc(s)?;
        }

        b.enc(s)
    }
}
664
/// Decoded from a [`BinString`] holding the big endian bytes of the value.
///
/// A value with the top bit set is rejected with `WireError::BadKeyFormat`.
#[cfg(feature = "rsa")]
impl<'de> SSHDecode<'de> for rsa::BigUint {
    fn dec<S>(s: &mut S) -> WireResult<Self>
    where
        S: SSHSource<'de>,
    {
        let BinString(bytes) = BinString::dec(s)?;
        if !top_bit_set(bytes) {
            return Ok(rsa::BigUint::from_bytes_be(bytes));
        }
        trace!("received negative mpint");
        Err(WireError::BadKeyFormat)
    }
}
679
680// TODO: is there already something like this?
681pub enum OwnOrBorrow<'a, T> {
682    Own(T),
683    Borrow(&'a T),
684}
685
686impl<T: SSHEncode> SSHEncode for OwnOrBorrow<'_, T> {
687    fn enc(&self, s: &mut dyn SSHSink) -> WireResult<()> {
688        match self {
689            Self::Own(t) => t.enc(s),
690            Self::Borrow(t) => t.enc(s),
691        }
692    }
693}
694
695impl<'de, T: SSHDecode<'de>> SSHDecode<'de> for OwnOrBorrow<'_, T> {
696    fn dec<S>(s: &mut S) -> WireResult<Self>
697    where
698        S: SSHSource<'de>,
699    {
700        Ok(Self::Own(T::dec(s)?))
701    }
702}
703
704impl<'a, T> core::borrow::Borrow<T> for OwnOrBorrow<'a, T> {
705    fn borrow(&self) -> &T {
706        match self {
707            Self::Own(t) => t,
708            Self::Borrow(t) => t,
709        }
710    }
711}
712
#[cfg(test)]
pub(crate) mod tests {
    use crate::*;
    use error::Error;
    use packets::*;
    use pretty_hex::PrettyHex;
    use sshwire::*;
    use sunsetlog::init_test_log;

    /// Asserts that `p1` and `p2` have the same serialized form.
    pub fn assert_serialize_equal<'de, T: SSHEncode>(p1: &T, p2: &T) {
        let mut first = vec![99; 2000];
        let mut second = vec![88; 1000];
        let first_len = write_ssh(&mut first, p1).unwrap();
        let second_len = write_ssh(&mut second, p2).unwrap();
        first.truncate(first_len);
        second.truncate(second_len);
        assert_eq!(first, second);
    }

    /// Checks that hash_ser_length() and hash_ser() match hashing
    /// a serialized message.
    #[test]
    fn test_hash_packet() {
        use digest::Digest;
        use sha2::Sha256;
        let input = "hello";
        let mut buf = vec![99; 20];
        let written = write_ssh(&mut buf, &input).unwrap();
        buf.truncate(written);
        let len_prefix = (written as u32).to_be_bytes();

        // reference: length prefix then the serialized bytes
        let mut reference = Sha256::new();
        reference.update(&len_prefix);
        reference.update(&buf);
        let expected = reference.finalize();

        // hash_ser_length
        let mut with_length = Sha256::new();
        hash_ser_length(&mut with_length, &input).unwrap();
        assert_eq!(with_length.finalize(), expected);

        // hash_ser
        let mut without_length = Sha256::new();
        without_length.update(&len_prefix);
        hash_ser(&mut without_length, &input).unwrap();
        assert_eq!(without_length.finalize(), expected);
    }

    /// Serializes `p`, parses it back with `ctx`, and checks that the result
    /// serializes the same as `p`.
    pub fn test_roundtrip_context(p: &Packet, ctx: &ParseContext) {
        let mut buf = vec![99; 200];
        let written = write_ssh(&mut buf, p).unwrap();
        buf.truncate(written);
        trace!("wrote packet {:?}", buf.hex_dump());

        let parsed = packet_from_bytes(&buf, ctx).unwrap();
        trace!("returned packet {:#?}", parsed);
        assert_serialize_equal(p, &parsed);
    }

    /// With default context
    pub fn test_roundtrip(p: &Packet) {
        let ctx = ParseContext::default();
        test_roundtrip_context(p, &ctx);
    }

    /// Tests parsing a packet with a ParseContext.
    #[test]
    fn test_parse_context() {
        init_test_log();
        let mut ctx = ParseContext::new();

        let pw_change = Userauth60::PwChangeReq(UserauthPwChangeReq {
            prompt: "change the password".into(),
            lang: "".into(),
        })
        .into();
        ctx.cli_auth_type = Some(auth::AuthType::Password);
        test_roundtrip_context(&pw_change, &ctx);

        // PkOk is a more interesting case because the PubKey inside it is also
        // an enum but that can identify its own enum variant.
        let pk_ok = Userauth60::PkOk(UserauthPkOk {
            algo: "ed25519",
            key: Blob(PubKey::Ed25519(Ed25519PubKey { key: Blob([0x11; 32]) })),
        })
        .into();
        ctx.cli_auth_type = Some(auth::AuthType::PubKey);
        test_roundtrip_context(&pk_ok, &ctx);
    }

    // Some other blob decoding tests are in packets module

    /// A blob whose length prefix doesn't match its contents is rejected.
    #[test]
    fn wrong_blob_size() {
        let blob = Blob(BinString(b"hello"));

        let mut long = vec![88; 1000];
        let written = write_ssh(&mut long, &blob).unwrap();
        // some leeway
        long.truncate(written + 5);
        // make the length one extra
        long[3] += 1;
        let r: Result<Blob<BinString>, _> = read_ssh(&long, None);
        assert!(matches!(r.unwrap_err(), Error::SSHProto { .. }));

        let mut short = vec![88; 1000];
        let written = write_ssh(&mut short, &blob).unwrap();
        // some leeway
        short.truncate(written + 5);
        // make the length one short
        short[3] -= 1;
        let r: Result<Blob<BinString>, _> = read_ssh(&short, None);
        assert!(matches!(r.unwrap_err(), Error::SSHProto { .. }));
    }

    /// A packet buffer that is too long or too short is rejected.
    #[test]
    fn wrong_packet_size() {
        let packet: Packet = packets::NewKeys {}.into();
        let ctx = ParseContext::new();

        let mut buf = vec![88; 1000];
        let written = write_ssh(&mut buf, &packet).unwrap();

        // too long
        buf.truncate(written + 1);
        let r = packet_from_bytes(&buf, &ctx);
        assert!(matches!(r.unwrap_err(), Error::WrongPacketLength));

        // success
        buf.truncate(written);
        packet_from_bytes(&buf, &ctx).unwrap();

        // short
        buf.truncate(written - 1);
        let r = packet_from_bytes(&buf, &ctx);
        assert!(matches!(r.unwrap_err(), Error::RanOut { .. }));
    }

    /// Encoding succeeds up to the size of the buffer and fails beyond it.
    #[test]
    fn overflow_encode() {
        let mut buf = vec![22; 7];

        // each string has a 4 byte length prefix
        assert_eq!(write_ssh(&mut buf, &"").unwrap(), 4);
        assert_eq!(write_ssh(&mut buf, &"a").unwrap(), 5);
        assert_eq!(write_ssh(&mut buf, &"aa").unwrap(), 6);
        assert_eq!(write_ssh(&mut buf, &"aaa").unwrap(), 7);
        let overflow = write_ssh(&mut buf, &"aaaa").unwrap_err();
        assert!(matches!(overflow, Error::NoRoom { .. }));
    }
}
863            Error::NoRoom { .. }
864        ));
865    }
866}