1use std::io::{Read, Write};
2use std::time::Duration;
3
4use crate::error::SshResult;
5use crate::{client::Client, model::Data};
6
7use super::timeout::Timeout;
8
9fn read_with_timeout<S>(stream: &mut S, tm: Option<Duration>, buf: &mut [u8]) -> SshResult<()>
53where
54 S: Read,
55{
56 let want_len = buf.len();
57 let mut offset = 0;
58 let mut timeout = Timeout::new(tm);
59
60 loop {
61 match stream.read(&mut buf[offset..]) {
62 Ok(i) => {
63 offset += i;
64 if offset == want_len {
65 return Ok(());
66 } else {
67 timeout.renew();
68 continue;
69 }
70 }
71 Err(e) => {
72 if let std::io::ErrorKind::WouldBlock = e.kind() {
73 timeout.till_next_tick()?;
74 continue;
75 } else {
76 return Err(e.into());
77 }
78 }
79 };
80 }
81}
82
83fn try_read<S>(stream: &mut S, _tm: Option<Duration>, buf: &mut [u8]) -> SshResult<usize>
84where
85 S: Read,
86{
87 match stream.read(buf) {
88 Ok(i) => Ok(i),
89 Err(e) => {
90 if let std::io::ErrorKind::WouldBlock = e.kind() {
91 Ok(0)
92 } else {
93 Err(e.into())
94 }
95 }
96 }
97}
98
99fn write_with_timeout<S>(stream: &mut S, tm: Option<Duration>, buf: &[u8]) -> SshResult<()>
100where
101 S: Write,
102{
103 let want_len = buf.len();
104 let mut offset = 0;
105 let mut timeout = Timeout::new(tm);
106
107 loop {
108 match stream.write(&buf[offset..]) {
109 Ok(i) => {
110 offset += i;
111 if offset == want_len {
112 return Ok(());
113 } else {
114 timeout.renew();
115 continue;
116 }
117 }
118 Err(e) => {
119 if let std::io::ErrorKind::WouldBlock = e.kind() {
120 timeout.till_next_tick()?;
121 continue;
122 } else {
123 return Err(e.into());
124 }
125 }
126 };
127 }
128}
129
130pub(crate) trait Packet<'a> {
131 fn pack(self, client: &'a mut Client) -> SecPacket<'a>;
132 fn unpack(pkt: SecPacket) -> SshResult<Self>
133 where
134 Self: Sized;
135}
136
137pub(crate) struct SecPacket<'a> {
138 payload: Data,
139 client: &'a mut Client,
140}
141
142impl<'a> SecPacket<'a> {
143 fn get_align(bsize: usize) -> i32 {
144 let bsize = bsize as i32;
145 if bsize > 8 {
146 bsize
147 } else {
148 8
149 }
150 }
151
152 pub fn write_stream<S>(self, stream: &mut S) -> SshResult<()>
153 where
154 S: Write,
155 {
156 let tm = self.client.get_timeout();
157 let payload = self.client.get_compressor().compress(&self.payload)?;
158 let payload_len = payload.len() as u32;
159 let pad_len = {
160 let mut pad = payload_len as i32 + 1;
161 let block_size = Self::get_align(self.client.get_encryptor().bsize());
162 if !self.client.get_encryptor().no_pad() {
163 pad += 4
164 }
165 (((-pad) & (block_size - 1)) + block_size) as u32
166 } as u8;
167 let packet_len = 1 + pad_len as u32 + payload_len;
168 let mut buf = vec![];
169 buf.extend(packet_len.to_be_bytes());
170 buf.extend([pad_len]);
171 buf.extend(payload);
172 buf.extend(vec![0; pad_len as usize]);
173 let seq = self.client.get_seq().get_client();
174 self.client.get_encryptor().encrypt(seq, &mut buf);
175 write_with_timeout(stream, tm, &buf)
176 }
177
178 pub fn from_stream<S>(stream: &mut S, client: &'a mut Client) -> SshResult<Self>
179 where
180 S: Read,
181 {
182 let tm = client.get_timeout();
183 let bsize = Self::get_align(client.get_encryptor().bsize()) as usize;
184
185 let mut first_block = vec![0; bsize];
187 read_with_timeout(stream, tm, &mut first_block)?;
188
189 let seq = client.get_seq().get_server();
191 let data_len = client.get_encryptor().data_len(seq, &first_block);
192
193 let mut data = Data::uninit_new(data_len);
195 data[0..bsize].clone_from_slice(&first_block);
196 read_with_timeout(stream, tm, &mut data[bsize..])?;
197
198 let data = client.get_encryptor().decrypt(seq, &mut data)?;
200
201 let pkt_len = u32::from_be_bytes(data[0..4].try_into().unwrap());
203 let pad_len = data[4];
204 let payload_len = pkt_len - pad_len as u32 - 1;
205
206 let payload = data[5..payload_len as usize + 5].into();
207 let payload = client.get_compressor().decompress(payload)?.into();
208
209 Ok(Self { payload, client })
210 }
211
212 pub fn try_from_stream<S>(stream: &mut S, client: &'a mut Client) -> SshResult<Option<Self>>
213 where
214 S: Read,
215 {
216 let tm = client.get_timeout();
217 let bsize = Self::get_align(client.get_encryptor().bsize()) as usize;
218
219 let mut first_block = vec![0; bsize];
221 let read = try_read(stream, tm, &mut first_block)?;
222 if read == 0 {
223 return Ok(None);
224 }
225
226 let seq = client.get_seq().get_server();
228 let data_len = client.get_encryptor().data_len(seq, &first_block);
229
230 let mut data = Data::uninit_new(data_len);
232 data[0..bsize].clone_from_slice(&first_block);
233 read_with_timeout(stream, tm, &mut data[bsize..])?;
234
235 let data = client.get_encryptor().decrypt(seq, &mut data)?;
237
238 let pkt_len = u32::from_be_bytes(data[0..4].try_into().unwrap());
240 let pad_len = data[4];
241 let payload_len = pkt_len - pad_len as u32 - 1;
242
243 let payload = data[5..payload_len as usize + 5].into();
244
245 Ok(Some(Self { payload, client }))
246 }
247
248 pub fn get_inner(&self) -> &[u8] {
249 &self.payload
250 }
251
252 pub fn into_inner(self) -> Data {
253 self.payload
254 }
255}
256
257impl<'a> From<(Data, &'a mut Client)> for SecPacket<'a> {
258 fn from((d, c): (Data, &'a mut Client)) -> Self {
259 Self {
260 payload: d,
261 client: c,
262 }
263 }
264}