// ssh/model/packet.rs

1use std::io::{Read, Write};
2use std::time::Duration;
3
4use crate::error::SshResult;
5use crate::{client::Client, model::Data};
6
7use super::timeout::Timeout;
8
/// ## Binary Packet Protocol
///
/// <https://www.rfc-editor.org/rfc/rfc4253#section-6>
///
/// uint32 `packet_length`
///
/// byte `padding_length`
///
/// `byte[n1]` `payload`; n1 = packet_length - padding_length - 1
///
/// `byte[n2]` `random padding`; n2 = padding_length
///
/// `byte[m]` `mac` (Message Authentication Code - MAC); m = mac_length
///
/// ---
///
/// **packet_length**
/// The length of the packet in bytes, not including 'mac' or the 'packet_length' field itself.
///
/// **padding_length**
/// Length of 'random padding' (bytes).
///
/// **payload**
/// The useful contents of the packet. If compression has been negotiated, this field is compressed.
/// Initially, compression MUST be "none".
///
/// **random padding**
/// Arbitrary-length padding, such that the total length of
/// (packet_length || padding_length || payload || random padding)
/// is a multiple of the cipher block size or 8, whichever is
/// larger. There MUST be at least four bytes of padding. The
/// padding SHOULD consist of random bytes. The maximum amount of
/// padding is 255 bytes.
///
/// **mac**
/// Message Authentication Code. If message authentication has
/// been negotiated, this field contains the MAC bytes. Initially,
/// the MAC algorithm MUST be "none".
51
52fn read_with_timeout<S>(stream: &mut S, tm: Option<Duration>, buf: &mut [u8]) -> SshResult<()>
53where
54    S: Read,
55{
56    let want_len = buf.len();
57    let mut offset = 0;
58    let mut timeout = Timeout::new(tm);
59
60    loop {
61        match stream.read(&mut buf[offset..]) {
62            Ok(i) => {
63                offset += i;
64                if offset == want_len {
65                    return Ok(());
66                } else {
67                    timeout.renew();
68                    continue;
69                }
70            }
71            Err(e) => {
72                if let std::io::ErrorKind::WouldBlock = e.kind() {
73                    timeout.till_next_tick()?;
74                    continue;
75                } else {
76                    return Err(e.into());
77                }
78            }
79        };
80    }
81}
82
83fn try_read<S>(stream: &mut S, _tm: Option<Duration>, buf: &mut [u8]) -> SshResult<usize>
84where
85    S: Read,
86{
87    match stream.read(buf) {
88        Ok(i) => Ok(i),
89        Err(e) => {
90            if let std::io::ErrorKind::WouldBlock = e.kind() {
91                Ok(0)
92            } else {
93                Err(e.into())
94            }
95        }
96    }
97}
98
99fn write_with_timeout<S>(stream: &mut S, tm: Option<Duration>, buf: &[u8]) -> SshResult<()>
100where
101    S: Write,
102{
103    let want_len = buf.len();
104    let mut offset = 0;
105    let mut timeout = Timeout::new(tm);
106
107    loop {
108        match stream.write(&buf[offset..]) {
109            Ok(i) => {
110                offset += i;
111                if offset == want_len {
112                    return Ok(());
113                } else {
114                    timeout.renew();
115                    continue;
116                }
117            }
118            Err(e) => {
119                if let std::io::ErrorKind::WouldBlock = e.kind() {
120                    timeout.till_next_tick()?;
121                    continue;
122                } else {
123                    return Err(e.into());
124                }
125            }
126        };
127    }
128}
129
130pub(crate) trait Packet<'a> {
131    fn pack(self, client: &'a mut Client) -> SecPacket<'a>;
132    fn unpack(pkt: SecPacket) -> SshResult<Self>
133    where
134        Self: Sized;
135}
136
137pub(crate) struct SecPacket<'a> {
138    payload: Data,
139    client: &'a mut Client,
140}
141
142impl<'a> SecPacket<'a> {
143    fn get_align(bsize: usize) -> i32 {
144        let bsize = bsize as i32;
145        if bsize > 8 {
146            bsize
147        } else {
148            8
149        }
150    }
151
152    pub fn write_stream<S>(self, stream: &mut S) -> SshResult<()>
153    where
154        S: Write,
155    {
156        let tm = self.client.get_timeout();
157        let payload = self.client.get_compressor().compress(&self.payload)?;
158        let payload_len = payload.len() as u32;
159        let pad_len = {
160            let mut pad = payload_len as i32 + 1;
161            let block_size = Self::get_align(self.client.get_encryptor().bsize());
162            if !self.client.get_encryptor().no_pad() {
163                pad += 4
164            }
165            (((-pad) & (block_size - 1)) + block_size) as u32
166        } as u8;
167        let packet_len = 1 + pad_len as u32 + payload_len;
168        let mut buf = vec![];
169        buf.extend(packet_len.to_be_bytes());
170        buf.extend([pad_len]);
171        buf.extend(payload);
172        buf.extend(vec![0; pad_len as usize]);
173        let seq = self.client.get_seq().get_client();
174        self.client.get_encryptor().encrypt(seq, &mut buf);
175        write_with_timeout(stream, tm, &buf)
176    }
177
178    pub fn from_stream<S>(stream: &mut S, client: &'a mut Client) -> SshResult<Self>
179    where
180        S: Read,
181    {
182        let tm = client.get_timeout();
183        let bsize = Self::get_align(client.get_encryptor().bsize()) as usize;
184
185        // read the first block
186        let mut first_block = vec![0; bsize];
187        read_with_timeout(stream, tm, &mut first_block)?;
188
189        // detect the total len
190        let seq = client.get_seq().get_server();
191        let data_len = client.get_encryptor().data_len(seq, &first_block);
192
193        // read remain
194        let mut data = Data::uninit_new(data_len);
195        data[0..bsize].clone_from_slice(&first_block);
196        read_with_timeout(stream, tm, &mut data[bsize..])?;
197
198        // decrypt all
199        let data = client.get_encryptor().decrypt(seq, &mut data)?;
200
201        // unpacking
202        let pkt_len = u32::from_be_bytes(data[0..4].try_into().unwrap());
203        let pad_len = data[4];
204        let payload_len = pkt_len - pad_len as u32 - 1;
205
206        let payload = data[5..payload_len as usize + 5].into();
207        let payload = client.get_compressor().decompress(payload)?.into();
208
209        Ok(Self { payload, client })
210    }
211
212    pub fn try_from_stream<S>(stream: &mut S, client: &'a mut Client) -> SshResult<Option<Self>>
213    where
214        S: Read,
215    {
216        let tm = client.get_timeout();
217        let bsize = Self::get_align(client.get_encryptor().bsize()) as usize;
218
219        // read the first block
220        let mut first_block = vec![0; bsize];
221        let read = try_read(stream, tm, &mut first_block)?;
222        if read == 0 {
223            return Ok(None);
224        }
225
226        // detect the total len
227        let seq = client.get_seq().get_server();
228        let data_len = client.get_encryptor().data_len(seq, &first_block);
229
230        // read remain
231        let mut data = Data::uninit_new(data_len);
232        data[0..bsize].clone_from_slice(&first_block);
233        read_with_timeout(stream, tm, &mut data[bsize..])?;
234
235        // decrypt all
236        let data = client.get_encryptor().decrypt(seq, &mut data)?;
237
238        // unpacking
239        let pkt_len = u32::from_be_bytes(data[0..4].try_into().unwrap());
240        let pad_len = data[4];
241        let payload_len = pkt_len - pad_len as u32 - 1;
242
243        let payload = data[5..payload_len as usize + 5].into();
244
245        Ok(Some(Self { payload, client }))
246    }
247
248    pub fn get_inner(&self) -> &[u8] {
249        &self.payload
250    }
251
252    pub fn into_inner(self) -> Data {
253        self.payload
254    }
255}
256
257impl<'a> From<(Data, &'a mut Client)> for SecPacket<'a> {
258    fn from((d, c): (Data, &'a mut Client)) -> Self {
259        Self {
260            payload: d,
261            client: c,
262        }
263    }
264}