1use std::{
2 collections::HashMap,
3 io::{Read, Write},
4 sync::{
5 mpsc::{self, Receiver, Sender, TryRecvError},
6 Arc, Mutex,
7 },
8 thread::spawn,
9};
10
11use tracing::*;
12
13use crate::{
14 algorithm::Digest,
15 channel::{BackendChannel, ExecBroker},
16 client::Client,
17 config::algorithm::AlgList,
18 constant::{size, ssh_channel_fail_code, ssh_connection_code, ssh_str, ssh_transport_code},
19 error::{SshError, SshResult},
20 model::{ArcMut, BackendResp, BackendRqst, Data, Packet, SecPacket, U32Iter},
21 ChannelBroker, ShellBrocker, TerminalSize,
22};
23
24#[cfg(feature = "scp")]
25use crate::ScpBroker;
26
/// Front-end handle of an SSH session.
///
/// The protocol work is done by a backend thread started in `SessionBroker::new`;
/// this handle only forwards requests to that thread.
pub struct SessionBroker {
    /// Shared iterator that yields the client-side number for each newly opened channel.
    channel_num: ArcMut<U32Iter>,
    /// Sending half of the request queue that the backend thread reads from.
    snd: Sender<BackendRqst>,
}
31
32impl SessionBroker {
33 pub(crate) fn new<S>(client: Client, stream: S) -> Self
34 where
35 S: Read + Write + Send + 'static,
36 {
37 let (rqst_snd, rqst_rcv) = mpsc::channel();
38 spawn(move || {
39 if let Err(e) = client_loop(client, stream, rqst_rcv) {
40 error!("Error {:?} occurred when running backend task", e)
41 }
42 });
43 Self {
44 channel_num: Arc::new(Mutex::new(U32Iter::default())),
45 snd: rqst_snd,
46 }
47 }
48
49 pub fn close(self) {
52 info!("Client close");
53 drop(self)
54 }
55
56 pub fn open_exec(&mut self) -> SshResult<ExecBroker> {
59 let channel = self.open_channel()?;
60 channel.exec()
61 }
62
63 #[cfg(feature = "scp")]
66 pub fn open_scp(&mut self) -> SshResult<ScpBroker> {
67 let channel = self.open_channel()?;
68 channel.scp()
69 }
70
71 pub fn open_shell(&mut self) -> SshResult<ShellBrocker> {
74 self.open_shell_terminal(TerminalSize::from(80, 24))
75 }
76
77 pub fn open_shell_terminal(&mut self, tv: TerminalSize) -> SshResult<ShellBrocker> {
82 let channel = self.open_channel()?;
83 channel.shell(tv)
84 }
85
86 pub fn open_channel(&mut self) -> SshResult<ChannelBroker> {
91 let (resp_send, resp_recv) = mpsc::channel();
92 let client_id = self.channel_num.lock().unwrap().next().unwrap();
93
94 let mut data = Data::new();
96 data.put_u8(ssh_connection_code::CHANNEL_OPEN)
97 .put_str(ssh_str::SESSION)
98 .put_u32(client_id)
99 .put_u32(size::LOCAL_WINDOW_SIZE)
100 .put_u32(size::BUF_SIZE as u32);
101
102 self.snd
103 .send(BackendRqst::OpenChannel(client_id, data, resp_send))?;
104
105 match resp_recv.recv() {
107 Ok(resp) => match resp {
108 BackendResp::Ok(server_id) => Ok(ChannelBroker::new(
109 client_id,
110 server_id,
111 resp_recv,
112 self.snd.clone(),
113 )),
114 BackendResp::Fail(msg) => Err(SshError::GeneralError(msg)),
115 _ => unreachable!(),
116 },
117 Err(e) => Err(e.into()),
118 }
119 }
120}
121
122fn client_loop<S>(mut client: Client, mut stream: S, rcv: Receiver<BackendRqst>) -> SshResult<()>
123where
124 S: Read + Write,
125{
126 let mut channels = HashMap::<u32, BackendChannel>::new();
127 let mut pendings = HashMap::<u32, Sender<BackendResp>>::new();
128 client.set_timeout(None);
129 loop {
130 let try_recv = rcv.try_recv();
131 if try_recv.is_err() {
132 if let Err(TryRecvError::Disconnected) = try_recv {
133 info!("Session backend Closed");
134 return Ok(());
135 }
136 } else if let Ok(rqst) = try_recv {
137 match rqst {
138 BackendRqst::OpenChannel(id, data, sender) => {
139 info!("try open channel {}.", id);
140
141 data.pack(&mut client).write_stream(&mut stream)?;
142
143 assert!(pendings.insert(id, sender).is_none());
145 }
146 BackendRqst::Data(id, data) => {
147 let channel = channels.get_mut(&id).unwrap();
148
149 trace!("Channel {} send {} data", id, data.len());
150 channel.send_data(data, &mut client, &mut stream)?;
151 }
152 BackendRqst::Command(id, data) => {
153 let channel = channels.get_mut(&id).unwrap();
154
155 trace!("Channel {} send control data", id);
156 channel.send(data, &mut client, &mut stream)?;
157 }
158 BackendRqst::CloseChannel(id, data) => {
159 info!("try close channel {}.", id);
160
161 let channel = channels.get_mut(&id).unwrap();
162 channel.send(data, &mut client, &mut stream)?;
163 channel.local_close()?;
164 if channel.closed() {
165 channels.remove(&id);
166 }
167 }
168 }
169 }
170
171 if let Some(pkt) = SecPacket::try_from_stream(&mut stream, &mut client)? {
172 let mut data = Data::unpack(pkt)?;
173 let message_code = data.get_u8();
174
175 match message_code {
176 ssh_connection_code::CHANNEL_OPEN_CONFIRMATION => {
178 let client_channel_no = data.get_u32();
179 let server_channel_no = data.get_u32();
180 let remote_window_size = data.get_u32();
181 data.get_u32();
183
184 let sender = pendings.remove(&client_channel_no);
186 assert!(sender.is_some());
187
188 assert!(channels
190 .insert(
191 client_channel_no,
192 BackendChannel::new(
193 server_channel_no,
194 client_channel_no,
195 remote_window_size,
196 sender.unwrap()
197 )?
198 )
199 .is_none())
200 }
201 ssh_connection_code::CHANNEL_OPEN_FAILURE => {
210 let id = data.get_u32();
212
213 let sender = pendings.remove(&id);
214 assert!(sender.is_some());
215 let code = data.get_u32();
217 let description =
219 String::from_utf8(data.get_u8s()).unwrap_or_else(|_| String::from("error"));
220 data.get_u8s();
222
223 let err_msg = match code {
224 ssh_channel_fail_code::ADMINISTRATIVELY_PROHIBITED => {
225 format!("ADMINISTRATIVELY_PROHIBITED: {description}")
226 }
227 ssh_channel_fail_code::CONNECT_FAILED => {
228 format!("CONNECT_FAILED: {description}")
229 }
230 ssh_channel_fail_code::UNKNOWN_CHANNEL_TYPE => {
231 format!("UNKNOWN_CHANNEL_TYPE: {description}")
232 }
233 ssh_channel_fail_code::RESOURCE_SHORTAGE => {
234 format!("RESOURCE_SHORTAGE: {description}")
235 }
236 _ => description,
237 };
238 sender.unwrap().send(BackendResp::Fail(err_msg))?;
239 }
240 ssh_transport_code::KEXINIT => {
241 data.insert(0, message_code);
242 let mut digest = Digest::new();
243 digest.hash_ctx.set_i_s(&data);
244 let server_algs = AlgList::unpack((data, &mut client).into())?;
245 client.key_agreement(&mut stream, server_algs, &mut digest)?;
246 }
247 ssh_connection_code::CHANNEL_DATA => {
248 let id = data.get_u32();
249 trace!("Channel {id} get {} data", data.len());
250 let channel = channels.get_mut(&id).unwrap();
251 channel.recv(data, &mut client, &mut stream)?;
252 }
253 ssh_connection_code::CHANNEL_EXTENDED_DATA => {
254 let id = data.get_u32();
255 let data_type = data.get_u32();
256 trace!(
257 "Channel {id} get {} extended data, type {data_type}",
258 data.len(),
259 );
260 let channel = channels.get_mut(&id).unwrap();
261 channel.recv(data, &mut client, &mut stream)?;
262 }
263 ssh_connection_code::CHANNEL_WINDOW_ADJUST => {
265 let id = data.get_u32();
267 let rws = data.get_u32();
269 let channel = channels.get_mut(&id).unwrap();
270 channel.recv_window_adjust(rws, &mut client, &mut stream)?;
271 }
272 ssh_connection_code::CHANNEL_CLOSE => {
273 let id = data.get_u32();
274 info!("Channel {} recv close", id);
275 let channel = channels.get_mut(&id).unwrap();
276 channel.remote_close()?;
277 if channel.closed() {
278 channels.remove(&id);
279 }
280 }
281 ssh_connection_code::GLOBAL_REQUEST => {
282 let mut data = Data::new();
283 data.put_u8(ssh_connection_code::REQUEST_FAILURE);
284 data.pack(&mut client).write_stream(&mut stream)?;
285 continue;
286 }
287
288 x @ ssh_connection_code::CHANNEL_EOF => {
289 debug!("Currently ignore message {}", x);
290 }
291 ssh_connection_code::CHANNEL_REQUEST => {
292 let id = data.get_u32();
293 let channel = channels.get_mut(&id).unwrap();
294 let _ = channel.recv_rqst(data);
295 }
296 _x @ ssh_connection_code::CHANNEL_SUCCESS => {
297 let id = data.get_u32();
298 trace!("Channel {} control success", id);
299 let channel = channels.get_mut(&id).unwrap();
300 channel.success()?
301 }
302 ssh_connection_code::CHANNEL_FAILURE => {
303 let id = data.get_u32();
304 trace!("Channel {} control failed", id);
305 let channel = channels.get_mut(&id).unwrap();
306 channel.failed()?
307 }
308
309 x => {
310 debug!("Currently ignore message {}", x);
311 }
312 }
313 }
314 }
315}