ssh_packet/
id.rs

1use crate::Error;
2
/// The SSH protocol version used by [`Id::v2`].
const VERSION: &str = "2.0";

/// The SSH identification string as defined in the SSH protocol.
///
/// The format must match the following pattern:
/// `SSH-<protoversion>-<softwareversion>[ <comments>]`.
///
/// See <https://datatracker.ietf.org/doc/html/rfc4253#section-4.2>.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Id {
    /// The SSH protocol version, e.g. `2.0` (as produced by [`Id::v2`]).
    pub protoversion: String,

    /// A string identifying the software currently used, for example `billsSSH_3.6.3q3`.
    ///
    /// RFC 4253 requires printable US-ASCII characters excluding whitespace and
    /// the minus sign (`-`); since the field is public, this type does not
    /// validate it.
    pub softwareversion: String,

    /// Optional comments with additional information about the software.
    pub comments: Option<String>,
}
22
23impl Id {
24    /// Convenience method to create an `SSH-2.0` identifier string.
25    pub fn v2(softwareversion: impl Into<String>, comments: Option<impl Into<String>>) -> Self {
26        Self {
27            protoversion: VERSION.into(),
28            softwareversion: softwareversion.into(),
29            comments: comments.map(Into::into),
30        }
31    }
32
33    #[cfg(feature = "futures")]
34    #[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
35    /// Read an [`Id`], discarding any _extra lines_ sent by the server
36    /// from the provided asynchronous `reader`.
37    pub async fn from_reader<R>(reader: &mut R) -> Result<Self, Error>
38    where
39        R: futures::io::AsyncBufRead + Unpin,
40    {
41        use futures::TryStreamExt;
42
43        let text = futures::io::AsyncBufReadExt::lines(reader)
44            // Skip extra lines the server can send before identifying
45            .try_skip_while(|line| futures::future::ok(!line.starts_with("SSH")))
46            .try_next()
47            .await?
48            .ok_or(Error::UnexpectedEof)?;
49
50        text.parse()
51    }
52
53    #[cfg(feature = "futures")]
54    #[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
55    /// Write the [`Id`] to the provided asynchronous `writer`.
56    pub async fn to_writer<W>(&self, writer: &mut W) -> Result<(), Error>
57    where
58        W: futures::io::AsyncWrite + Unpin,
59    {
60        use futures::io::AsyncWriteExt;
61
62        writer.write_all(self.to_string().as_bytes()).await?;
63        writer.write_all(b"\r\n").await?;
64
65        Ok(())
66    }
67}
68
69impl std::fmt::Display for Id {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "SSH-{}-{}", self.protoversion, self.softwareversion)?;
72
73        if let Some(comments) = &self.comments {
74            write!(f, " {comments}")?;
75        }
76
77        Ok(())
78    }
79}
80
81impl std::str::FromStr for Id {
82    type Err = Error;
83
84    fn from_str(s: &str) -> Result<Self, Self::Err> {
85        let (id, comments) = s
86            .split_once(' ')
87            .map_or_else(|| (s, None), |(id, comments)| (id, Some(comments)));
88
89        match id.splitn(3, '-').collect::<Vec<_>>()[..] {
90            ["SSH", protoversion, softwareversion]
91                if !protoversion.is_empty() && !softwareversion.is_empty() =>
92            {
93                Ok(Self {
94                    protoversion: protoversion.to_string(),
95                    softwareversion: softwareversion.to_string(),
96                    comments: comments.map(str::to_string),
97                })
98            }
99            _ => Err(Error::BadIdentifer(s.into())),
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::unimplemented)]
    use rstest::rstest;
    use std::str::FromStr;

    use super::*;

    /// Test-only equality for [`Error`]: I/O errors compare by their
    /// [`std::io::ErrorKind`], every other variant by discriminant only.
    impl PartialEq for Error {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (Self::Io(l0), Self::Io(r0)) => l0.kind() == r0.kind(),
                _ => core::mem::discriminant(self) == core::mem::discriminant(other),
            }
        }
    }

    #[rstest]
    #[case("SSH-2.0-billsSSH_3.6.3q3")]
    #[case("SSH-1.99-billsSSH_3.6.3q3")]
    #[case("SSH-2.0-billsSSH_3.6.3q3 with-comment")]
    #[case("SSH-2.0-billsSSH_3.6.3q3 utf∞-comment")]
    #[case("SSH-2.0-billsSSH_3.6.3q3 ")] // empty comment
    fn it_parses_valid(#[case] text: &str) {
        Id::from_str(text).expect(text);
    }

    #[rstest]
    #[case("SSH-2.0-billsSSH_3.6.3q3", "2.0", "billsSSH_3.6.3q3", None)]
    #[case("SSH-1.99-billsSSH_3.6.3q3", "1.99", "billsSSH_3.6.3q3", None)]
    #[case(
        "SSH-2.0-billsSSH_3.6.3q3 with-comment",
        "2.0",
        "billsSSH_3.6.3q3",
        Some("with-comment")
    )]
    #[case(
        "SSH-2.0-billsSSH_3.6.3q3 two words",
        "2.0",
        "billsSSH_3.6.3q3",
        Some("two words")
    )]
    #[case("SSH-2.0-billsSSH_3.6.3q3 ", "2.0", "billsSSH_3.6.3q3", Some(""))] // empty comment
    fn it_parses_fields(
        #[case] text: &str,
        #[case] protoversion: &str,
        #[case] softwareversion: &str,
        #[case] comments: Option<&str>,
    ) {
        let id = Id::from_str(text).expect(text);

        assert_eq!(id.protoversion, protoversion);
        assert_eq!(id.softwareversion, softwareversion);
        assert_eq!(id.comments.as_deref(), comments);
    }

    #[rstest]
    #[case("")]
    #[case("SSH")]
    #[case("SSH-2.0")]
    #[case("FOO-2.0-billsSSH_3.6.3q3")]
    #[case("-2.0-billsSSH_3.6.3q3")]
    #[case("SSH--billsSSH_3.6.3q3")]
    #[case("SSH-2.0-")]
    #[case("SSH-2.0- comment")] // empty software version before the comment
    fn it_rejects_invalid(#[case] text: &str) {
        let result = Id::from_str(text);

        assert!(
            matches!(result, Err(Error::BadIdentifer(_))),
            "{text:?}: {result:?}"
        );
    }

    #[rstest]
    #[case(Id::v2("billsSSH_3.6.3q3", None::<String>), "SSH-2.0-billsSSH_3.6.3q3")]
    #[case(
        Id::v2("billsSSH_3.6.3q3", Some("with-comment")),
        "SSH-2.0-billsSSH_3.6.3q3 with-comment"
    )]
    #[case(Id::v2("billsSSH_3.6.3q3", Some("")), "SSH-2.0-billsSSH_3.6.3q3 ")] // empty comment
    fn it_displays(#[case] id: Id, #[case] expected: &str) {
        assert_eq!(id.to_string(), expected);
    }

    #[rstest]
    #[case(Id::v2("billsSSH_3.6.3q3", None::<String>))]
    #[case(Id::v2("billsSSH_utf∞", None::<String>))]
    #[case(Id::v2("billsSSH_3.6.3q3", Some("with-comment")))]
    #[case(Id::v2("billsSSH_3.6.3q3", Some("utf∞-comment")))]
    #[case(Id::v2("billsSSH_3.6.3q3", Some("")))] // empty comment
    fn it_reparses_consistently(#[case] id: Id) {
        assert_eq!(id, id.to_string().parse().unwrap());
    }
}
150}