// ssh/session/mod.rs

1// pub(crate) use session_inner::SessionInner;
2mod session_broker;
3mod session_local;
4
5pub use session_broker::SessionBroker;
6pub use session_local::LocalSession;
7use tracing::*;
8
9use std::{
10    io::{Read, Write},
11    net::{TcpStream, ToSocketAddrs},
12    path::Path,
13    time::Duration,
14};
15
16use crate::{
17    algorithm::{Compress, Digest, Enc, Kex, Mac, PubKey},
18    client::Client,
19    config::{algorithm::AlgList, Config},
20    error::SshResult,
21    model::{Packet, SecPacket},
22};
23
24enum SessionState<S>
25where
26    S: Read + Write,
27{
28    Init(Config, S),
29    Version(Config, S),
30    Auth(Client, S),
31    Connected(Client, S),
32}
33
34pub struct SessionConnector<S>
35where
36    S: Read + Write,
37{
38    inner: SessionState<S>,
39}
40
41impl<S> SessionConnector<S>
42where
43    S: Read + Write,
44{
45    fn connect(self) -> SshResult<Self> {
46        match self.inner {
47            SessionState::Init(config, stream) => Self {
48                inner: SessionState::Version(config, stream),
49            }
50            .connect(),
51            SessionState::Version(mut config, mut stream) => {
52                info!("start for version negotiation.");
53                // Send Client version
54                config.ver.send_our_version(&mut stream)?;
55
56                // Receive the server version
57                config
58                    .ver
59                    .read_server_version(&mut stream, config.timeout)?;
60                // Version validate
61                config.ver.validate()?;
62
63                // from now on
64                // each step of the interaction is subject to the ssh constraints on the packet
65                // so we create a client to hide the underlay details
66                let client = Client::new(config);
67
68                Self {
69                    inner: SessionState::Auth(client, stream),
70                }
71                .connect()
72            }
73            SessionState::Auth(mut client, mut stream) => {
74                // before auth,
75                // we should have a key exchange at first
76                let mut digest = Digest::new();
77                let server_algs = SecPacket::from_stream(&mut stream, &mut client)?;
78                digest.hash_ctx.set_i_s(server_algs.get_inner());
79                let server_algs = AlgList::unpack(server_algs)?;
80                client.key_agreement(&mut stream, server_algs, &mut digest)?;
81                client.do_auth(&mut stream, &digest)?;
82                Ok(Self {
83                    inner: SessionState::Connected(client, stream),
84                })
85            }
86            _ => unreachable!(),
87        }
88    }
89
90    /// To run this ssh session on the local thread
91    ///
92    /// It will return a [LocalSession] which doesn't support multithread concurrency
93    ///
94    pub fn run_local(self) -> LocalSession<S> {
95        if let SessionState::Connected(client, stream) = self.inner {
96            LocalSession::new(client, stream)
97        } else {
98            unreachable!("Why you here?")
99        }
100    }
101
102    /// close the session and consume it
103    ///
104    pub fn close(self) {
105        drop(self)
106    }
107}
108
109impl<S> SessionConnector<S>
110where
111    S: Read + Write + Send + 'static,
112{
113    /// To spwan a new thread to run this ssh session
114    ///
115    /// It will return a [SessionBroker] which supports multithread concurrency
116    ///
117    pub fn run_backend(self) -> SessionBroker {
118        if let SessionState::Connected(client, stream) = self.inner {
119            SessionBroker::new(client, stream)
120        } else {
121            unreachable!("Why you here?")
122        }
123    }
124}
125
126#[derive(Default)]
127pub struct SessionBuilder {
128    config: Config,
129}
130
131impl SessionBuilder {
132    pub fn new() -> Self {
133        Self {
134            ..Default::default()
135        }
136    }
137
138    pub fn disable_default() -> Self {
139        Self {
140            config: Config::disable_default(),
141        }
142    }
143
144    /// Read/Write timeout for local SSH mode. Use None to disable timeout.
145    /// This is a global timeout only take effect after the session is established
146    ///
147    /// Use `connect_with_timeout` instead if you want to add timeout
148    /// when connect to the target SSH server
149    pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
150        self.config.timeout = timeout;
151        self
152    }
153
154    pub fn username(mut self, username: &str) -> Self {
155        self.config.auth.username(username).unwrap();
156        self
157    }
158
159    pub fn password(mut self, password: &str) -> Self {
160        self.config.auth.password(password).unwrap();
161        self
162    }
163
164    pub fn private_key<K>(mut self, private_key: K) -> Self
165    where
166        K: ToString,
167    {
168        match self.config.auth.private_key(private_key) {
169            Ok(_) => (),
170            Err(e) => error!(
171                "Parse private key from string: {}, will fallback to password authentication",
172                e
173            ),
174        }
175        self
176    }
177
178    pub fn private_key_path<P>(mut self, key_path: P) -> Self
179    where
180        P: AsRef<Path>,
181    {
182        match self.config.auth.private_key_path(key_path) {
183            Ok(_) => (),
184            Err(e) => error!(
185                "Parse private key from file: {}, will fallback to password authentication",
186                e
187            ),
188        }
189        self
190    }
191
192    pub fn add_kex_algorithms(mut self, alg: Kex) -> Self {
193        self.config.algs.key_exchange.push(alg);
194        self
195    }
196
197    pub fn del_kex_algorithms(mut self, alg: Kex) -> Self {
198        self.config.algs.key_exchange.retain(|x| *x != alg);
199        self
200    }
201
202    pub fn add_pubkey_algorithms(mut self, alg: PubKey) -> Self {
203        self.config.algs.public_key.push(alg);
204        self
205    }
206
207    pub fn del_pubkey_algorithms(mut self, alg: PubKey) -> Self {
208        self.config.algs.public_key.retain(|x| *x != alg);
209        self
210    }
211
212    pub fn add_enc_algorithms(mut self, alg: Enc) -> Self {
213        self.config.algs.c_encryption.push(alg);
214        self.config.algs.s_encryption.push(alg);
215        self
216    }
217
218    pub fn del_enc_algorithms(mut self, alg: Enc) -> Self {
219        self.config.algs.c_encryption.retain(|x| *x != alg);
220        self.config.algs.s_encryption.retain(|x| *x != alg);
221        self
222    }
223
224    pub fn add_mac_algortihms(mut self, alg: Mac) -> Self {
225        self.config.algs.c_mac.push(alg);
226        self.config.algs.s_mac.push(alg);
227        self
228    }
229
230    pub fn del_mac_algortihms(mut self, alg: Mac) -> Self {
231        self.config.algs.c_mac.retain(|x| *x != alg);
232        self.config.algs.s_mac.retain(|x| *x != alg);
233        self
234    }
235
236    pub fn add_compress_algorithms(mut self, alg: Compress) -> Self {
237        self.config.algs.c_compress.push(alg);
238        self.config.algs.s_compress.push(alg);
239        self
240    }
241
242    pub fn del_compress_algorithms(mut self, alg: Compress) -> Self {
243        self.config.algs.c_compress.retain(|x| *x != alg);
244        self.config.algs.s_compress.retain(|x| *x != alg);
245        self
246    }
247
248    /// Create a TCP connection to the target server
249    ///
250    pub fn connect<A>(self, addr: A) -> SshResult<SessionConnector<TcpStream>>
251    where
252        A: ToSocketAddrs,
253    {
254        // connect tcp by default
255        let tcp = if let Some(ref to) = self.config.timeout {
256            TcpStream::connect_timeout(&addr.to_socket_addrs()?.next().unwrap(), *to)?
257        } else {
258            TcpStream::connect(addr)?
259        };
260
261        // default nonblocking
262        tcp.set_nonblocking(true).unwrap();
263        self.connect_bio(tcp)
264    }
265
266    /// Create a TCP connection to the target server, with timeout provided
267    ///
268    pub fn connect_with_timeout<A>(
269        self,
270        addr: A,
271        timeout: Option<Duration>,
272    ) -> SshResult<SessionConnector<TcpStream>>
273    where
274        A: ToSocketAddrs,
275    {
276        // connect tcp with custom connection timeout
277        let tcp = if let Some(ref to) = timeout {
278            TcpStream::connect_timeout(&addr.to_socket_addrs()?.next().unwrap(), *to)?
279        } else {
280            TcpStream::connect(addr)?
281        };
282
283        // default nonblocking
284        tcp.set_nonblocking(true).unwrap();
285        self.connect_bio(tcp)
286    }
287
288    /// connect to target server w/ a bio object
289    ///
290    /// which requires to implement `std::io::{Read, Write}`
291    ///
292    pub fn connect_bio<S>(mut self, stream: S) -> SshResult<SessionConnector<S>>
293    where
294        S: Read + Write,
295    {
296        self.config.tune_alglist_on_private_key();
297        SessionConnector {
298            inner: SessionState::Init(self.config, stream),
299        }
300        .connect()
301    }
302}