ssh/session/session_broker.rs

use std::{
    collections::HashMap,
    io::{Read, Write},
    sync::{
        mpsc::{self, Receiver, Sender, TryRecvError},
        Arc, Mutex,
    },
    thread::spawn,
};

use tracing::*;

use crate::{
    algorithm::Digest,
    channel::{BackendChannel, ExecBroker},
    client::Client,
    config::algorithm::AlgList,
    constant::{size, ssh_channel_fail_code, ssh_connection_code, ssh_str, ssh_transport_code},
    error::{SshError, SshResult},
    model::{ArcMut, BackendResp, BackendRqst, Data, Packet, SecPacket, U32Iter},
    ChannelBroker, ShellBrocker, TerminalSize,
};

#[cfg(feature = "scp")]
use crate::ScpBroker;

27pub struct SessionBroker {
28    channel_num: ArcMut<U32Iter>,
29    snd: Sender<BackendRqst>,
30}
31
32impl SessionBroker {
33    pub(crate) fn new<S>(client: Client, stream: S) -> Self
34    where
35        S: Read + Write + Send + 'static,
36    {
37        let (rqst_snd, rqst_rcv) = mpsc::channel();
38        spawn(move || {
39            if let Err(e) = client_loop(client, stream, rqst_rcv) {
40                error!("Error {:?} occurred when running backend task", e)
41            }
42        });
43        Self {
44            channel_num: Arc::new(Mutex::new(U32Iter::default())),
45            snd: rqst_snd,
46        }
47    }
48
49    /// close the backend session and consume the session broker itself
50    ///
51    pub fn close(self) {
52        info!("Client close");
53        drop(self)
54    }
55
56    /// open a [ExecBroker] channel which can excute commands
57    ///
58    pub fn open_exec(&mut self) -> SshResult<ExecBroker> {
59        let channel = self.open_channel()?;
60        channel.exec()
61    }
62
63    /// open a [ScpBroker] channel which can download/upload files/directories
64    ///
65    #[cfg(feature = "scp")]
66    pub fn open_scp(&mut self) -> SshResult<ScpBroker> {
67        let channel = self.open_channel()?;
68        channel.scp()
69    }
70
71    /// open a [ShellBrocker] channel which  can be used as a pseudo terminal (AKA PTY)
72    ///
73    pub fn open_shell(&mut self) -> SshResult<ShellBrocker> {
74        self.open_shell_terminal(TerminalSize::from(80, 24))
75    }
76
77    /// open a [ShellBrocker] channel
78    ///
79    /// custom terminal dimensions
80    ///
81    pub fn open_shell_terminal(&mut self, tv: TerminalSize) -> SshResult<ShellBrocker> {
82        let channel = self.open_channel()?;
83        channel.shell(tv)
84    }
85
86    /// open a raw channel
87    ///
88    /// need call `.exec()`, `.shell()`, `.scp()` and so on to convert it to a specific channel
89    ///
90    pub fn open_channel(&mut self) -> SshResult<ChannelBroker> {
91        let (resp_send, resp_recv) = mpsc::channel();
92        let client_id = self.channel_num.lock().unwrap().next().unwrap();
93
94        // open channel request
95        let mut data = Data::new();
96        data.put_u8(ssh_connection_code::CHANNEL_OPEN)
97            .put_str(ssh_str::SESSION)
98            .put_u32(client_id)
99            .put_u32(size::LOCAL_WINDOW_SIZE)
100            .put_u32(size::BUF_SIZE as u32);
101
102        self.snd
103            .send(BackendRqst::OpenChannel(client_id, data, resp_send))?;
104
105        // get the response
106        match resp_recv.recv() {
107            Ok(resp) => match resp {
108                BackendResp::Ok(server_id) => Ok(ChannelBroker::new(
109                    client_id,
110                    server_id,
111                    resp_recv,
112                    self.snd.clone(),
113                )),
114                BackendResp::Fail(msg) => Err(SshError::GeneralError(msg)),
115                _ => unreachable!(),
116            },
117            Err(e) => Err(e.into()),
118        }
119    }
120}
121
122fn client_loop<S>(mut client: Client, mut stream: S, rcv: Receiver<BackendRqst>) -> SshResult<()>
123where
124    S: Read + Write,
125{
126    let mut channels = HashMap::<u32, BackendChannel>::new();
127    let mut pendings = HashMap::<u32, Sender<BackendResp>>::new();
128    client.set_timeout(None);
129    loop {
130        let try_recv = rcv.try_recv();
131        if try_recv.is_err() {
132            if let Err(TryRecvError::Disconnected) = try_recv {
133                info!("Session backend Closed");
134                return Ok(());
135            }
136        } else if let Ok(rqst) = try_recv {
137            match rqst {
138                BackendRqst::OpenChannel(id, data, sender) => {
139                    info!("try open channel {}.", id);
140
141                    data.pack(&mut client).write_stream(&mut stream)?;
142
143                    // add to pending open list
144                    assert!(pendings.insert(id, sender).is_none());
145                }
146                BackendRqst::Data(id, data) => {
147                    let channel = channels.get_mut(&id).unwrap();
148
149                    trace!("Channel {} send {} data", id, data.len());
150                    channel.send_data(data, &mut client, &mut stream)?;
151                }
152                BackendRqst::Command(id, data) => {
153                    let channel = channels.get_mut(&id).unwrap();
154
155                    trace!("Channel {} send control data", id);
156                    channel.send(data, &mut client, &mut stream)?;
157                }
158                BackendRqst::CloseChannel(id, data) => {
159                    info!("try close channel {}.", id);
160
161                    let channel = channels.get_mut(&id).unwrap();
162                    channel.send(data, &mut client, &mut stream)?;
163                    channel.local_close()?;
164                    if channel.closed() {
165                        channels.remove(&id);
166                    }
167                }
168            }
169        }
170
171        if let Some(pkt) = SecPacket::try_from_stream(&mut stream, &mut client)? {
172            let mut data = Data::unpack(pkt)?;
173            let message_code = data.get_u8();
174
175            match message_code {
176                // Successfully open a channel
177                ssh_connection_code::CHANNEL_OPEN_CONFIRMATION => {
178                    let client_channel_no = data.get_u32();
179                    let server_channel_no = data.get_u32();
180                    let remote_window_size = data.get_u32();
181                    // remote packet size, currently don't need it
182                    data.get_u32();
183
184                    // remove from pending open list
185                    let sender = pendings.remove(&client_channel_no);
186                    assert!(sender.is_some());
187
188                    // add to opened list
189                    assert!(channels
190                        .insert(
191                            client_channel_no,
192                            BackendChannel::new(
193                                server_channel_no,
194                                client_channel_no,
195                                remote_window_size,
196                                sender.unwrap()
197                            )?
198                        )
199                        .is_none())
200                }
201                /*
202                    byte CHANNEL_OPEN_FAILURE
203                    uint32 recipient channel
204                    uint32 reason code
205                    string description,ISO-10646 UTF-8 [RFC3629]
206                    string language tag,[RFC3066]
207                */
208                // Fail to open a channel
209                ssh_connection_code::CHANNEL_OPEN_FAILURE => {
210                    //  client channel number
211                    let id = data.get_u32();
212
213                    let sender = pendings.remove(&id);
214                    assert!(sender.is_some());
215                    // error code
216                    let code = data.get_u32();
217                    // error detail: By default is utf-8
218                    let description =
219                        String::from_utf8(data.get_u8s()).unwrap_or_else(|_| String::from("error"));
220                    // language tag, assume to be en-US
221                    data.get_u8s();
222
223                    let err_msg = match code {
224                        ssh_channel_fail_code::ADMINISTRATIVELY_PROHIBITED => {
225                            format!("ADMINISTRATIVELY_PROHIBITED: {description}")
226                        }
227                        ssh_channel_fail_code::CONNECT_FAILED => {
228                            format!("CONNECT_FAILED: {description}")
229                        }
230                        ssh_channel_fail_code::UNKNOWN_CHANNEL_TYPE => {
231                            format!("UNKNOWN_CHANNEL_TYPE: {description}")
232                        }
233                        ssh_channel_fail_code::RESOURCE_SHORTAGE => {
234                            format!("RESOURCE_SHORTAGE: {description}")
235                        }
236                        _ => description,
237                    };
238                    sender.unwrap().send(BackendResp::Fail(err_msg))?;
239                }
240                ssh_transport_code::KEXINIT => {
241                    data.insert(0, message_code);
242                    let mut digest = Digest::new();
243                    digest.hash_ctx.set_i_s(&data);
244                    let server_algs = AlgList::unpack((data, &mut client).into())?;
245                    client.key_agreement(&mut stream, server_algs, &mut digest)?;
246                }
247                ssh_connection_code::CHANNEL_DATA => {
248                    let id = data.get_u32();
249                    trace!("Channel {id} get {} data", data.len());
250                    let channel = channels.get_mut(&id).unwrap();
251                    channel.recv(data, &mut client, &mut stream)?;
252                }
253                ssh_connection_code::CHANNEL_EXTENDED_DATA => {
254                    let id = data.get_u32();
255                    let data_type = data.get_u32();
256                    trace!(
257                        "Channel {id} get {} extended data, type {data_type}",
258                        data.len(),
259                    );
260                    let channel = channels.get_mut(&id).unwrap();
261                    channel.recv(data, &mut client, &mut stream)?;
262                }
263                // flow_control msg
264                ssh_connection_code::CHANNEL_WINDOW_ADJUST => {
265                    // client channel number
266                    let id = data.get_u32();
267                    // to_add
268                    let rws = data.get_u32();
269                    let channel = channels.get_mut(&id).unwrap();
270                    channel.recv_window_adjust(rws, &mut client, &mut stream)?;
271                }
272                ssh_connection_code::CHANNEL_CLOSE => {
273                    let id = data.get_u32();
274                    info!("Channel {} recv close", id);
275                    let channel = channels.get_mut(&id).unwrap();
276                    channel.remote_close()?;
277                    if channel.closed() {
278                        channels.remove(&id);
279                    }
280                }
281                ssh_connection_code::GLOBAL_REQUEST => {
282                    let mut data = Data::new();
283                    data.put_u8(ssh_connection_code::REQUEST_FAILURE);
284                    data.pack(&mut client).write_stream(&mut stream)?;
285                    continue;
286                }
287
288                x @ ssh_connection_code::CHANNEL_EOF => {
289                    debug!("Currently ignore message {}", x);
290                }
291                ssh_connection_code::CHANNEL_REQUEST => {
292                    let id = data.get_u32();
293                    let channel = channels.get_mut(&id).unwrap();
294                    let _ = channel.recv_rqst(data);
295                }
296                _x @ ssh_connection_code::CHANNEL_SUCCESS => {
297                    let id = data.get_u32();
298                    trace!("Channel {} control success", id);
299                    let channel = channels.get_mut(&id).unwrap();
300                    channel.success()?
301                }
302                ssh_connection_code::CHANNEL_FAILURE => {
303                    let id = data.get_u32();
304                    trace!("Channel {} control failed", id);
305                    let channel = channels.get_mut(&id).unwrap();
306                    channel.failed()?
307                }
308
309                x => {
310                    debug!("Currently ignore message {}", x);
311                }
312            }
313        }
314    }
315}