1mod session_broker;
3mod session_local;
4
5pub use session_broker::SessionBroker;
6pub use session_local::LocalSession;
7use tracing::*;
8
9use std::{
10 io::{Read, Write},
11 net::{TcpStream, ToSocketAddrs},
12 path::Path,
13 time::Duration,
14};
15
16use crate::{
17 algorithm::{Compress, Digest, Enc, Kex, Mac, PubKey},
18 client::Client,
19 config::{algorithm::AlgList, Config},
20 error::SshResult,
21 model::{Packet, SecPacket},
22};
23
24enum SessionState<S>
25where
26 S: Read + Write,
27{
28 Init(Config, S),
29 Version(Config, S),
30 Auth(Client, S),
31 Connected(Client, S),
32}
33
34pub struct SessionConnector<S>
35where
36 S: Read + Write,
37{
38 inner: SessionState<S>,
39}
40
41impl<S> SessionConnector<S>
42where
43 S: Read + Write,
44{
45 fn connect(self) -> SshResult<Self> {
46 match self.inner {
47 SessionState::Init(config, stream) => Self {
48 inner: SessionState::Version(config, stream),
49 }
50 .connect(),
51 SessionState::Version(mut config, mut stream) => {
52 info!("start for version negotiation.");
53 config.ver.send_our_version(&mut stream)?;
55
56 config
58 .ver
59 .read_server_version(&mut stream, config.timeout)?;
60 config.ver.validate()?;
62
63 let client = Client::new(config);
67
68 Self {
69 inner: SessionState::Auth(client, stream),
70 }
71 .connect()
72 }
73 SessionState::Auth(mut client, mut stream) => {
74 let mut digest = Digest::new();
77 let server_algs = SecPacket::from_stream(&mut stream, &mut client)?;
78 digest.hash_ctx.set_i_s(server_algs.get_inner());
79 let server_algs = AlgList::unpack(server_algs)?;
80 client.key_agreement(&mut stream, server_algs, &mut digest)?;
81 client.do_auth(&mut stream, &digest)?;
82 Ok(Self {
83 inner: SessionState::Connected(client, stream),
84 })
85 }
86 _ => unreachable!(),
87 }
88 }
89
90 pub fn run_local(self) -> LocalSession<S> {
95 if let SessionState::Connected(client, stream) = self.inner {
96 LocalSession::new(client, stream)
97 } else {
98 unreachable!("Why you here?")
99 }
100 }
101
102 pub fn close(self) {
105 drop(self)
106 }
107}
108
109impl<S> SessionConnector<S>
110where
111 S: Read + Write + Send + 'static,
112{
113 pub fn run_backend(self) -> SessionBroker {
118 if let SessionState::Connected(client, stream) = self.inner {
119 SessionBroker::new(client, stream)
120 } else {
121 unreachable!("Why you here?")
122 }
123 }
124}
125
126#[derive(Default)]
127pub struct SessionBuilder {
128 config: Config,
129}
130
131impl SessionBuilder {
132 pub fn new() -> Self {
133 Self {
134 ..Default::default()
135 }
136 }
137
138 pub fn disable_default() -> Self {
139 Self {
140 config: Config::disable_default(),
141 }
142 }
143
144 pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
150 self.config.timeout = timeout;
151 self
152 }
153
154 pub fn username(mut self, username: &str) -> Self {
155 self.config.auth.username(username).unwrap();
156 self
157 }
158
159 pub fn password(mut self, password: &str) -> Self {
160 self.config.auth.password(password).unwrap();
161 self
162 }
163
164 pub fn private_key<K>(mut self, private_key: K) -> Self
165 where
166 K: ToString,
167 {
168 match self.config.auth.private_key(private_key) {
169 Ok(_) => (),
170 Err(e) => error!(
171 "Parse private key from string: {}, will fallback to password authentication",
172 e
173 ),
174 }
175 self
176 }
177
178 pub fn private_key_path<P>(mut self, key_path: P) -> Self
179 where
180 P: AsRef<Path>,
181 {
182 match self.config.auth.private_key_path(key_path) {
183 Ok(_) => (),
184 Err(e) => error!(
185 "Parse private key from file: {}, will fallback to password authentication",
186 e
187 ),
188 }
189 self
190 }
191
192 pub fn add_kex_algorithms(mut self, alg: Kex) -> Self {
193 self.config.algs.key_exchange.push(alg);
194 self
195 }
196
197 pub fn del_kex_algorithms(mut self, alg: Kex) -> Self {
198 self.config.algs.key_exchange.retain(|x| *x != alg);
199 self
200 }
201
202 pub fn add_pubkey_algorithms(mut self, alg: PubKey) -> Self {
203 self.config.algs.public_key.push(alg);
204 self
205 }
206
207 pub fn del_pubkey_algorithms(mut self, alg: PubKey) -> Self {
208 self.config.algs.public_key.retain(|x| *x != alg);
209 self
210 }
211
212 pub fn add_enc_algorithms(mut self, alg: Enc) -> Self {
213 self.config.algs.c_encryption.push(alg);
214 self.config.algs.s_encryption.push(alg);
215 self
216 }
217
218 pub fn del_enc_algorithms(mut self, alg: Enc) -> Self {
219 self.config.algs.c_encryption.retain(|x| *x != alg);
220 self.config.algs.s_encryption.retain(|x| *x != alg);
221 self
222 }
223
224 pub fn add_mac_algortihms(mut self, alg: Mac) -> Self {
225 self.config.algs.c_mac.push(alg);
226 self.config.algs.s_mac.push(alg);
227 self
228 }
229
230 pub fn del_mac_algortihms(mut self, alg: Mac) -> Self {
231 self.config.algs.c_mac.retain(|x| *x != alg);
232 self.config.algs.s_mac.retain(|x| *x != alg);
233 self
234 }
235
236 pub fn add_compress_algorithms(mut self, alg: Compress) -> Self {
237 self.config.algs.c_compress.push(alg);
238 self.config.algs.s_compress.push(alg);
239 self
240 }
241
242 pub fn del_compress_algorithms(mut self, alg: Compress) -> Self {
243 self.config.algs.c_compress.retain(|x| *x != alg);
244 self.config.algs.s_compress.retain(|x| *x != alg);
245 self
246 }
247
248 pub fn connect<A>(self, addr: A) -> SshResult<SessionConnector<TcpStream>>
251 where
252 A: ToSocketAddrs,
253 {
254 let tcp = if let Some(ref to) = self.config.timeout {
256 TcpStream::connect_timeout(&addr.to_socket_addrs()?.next().unwrap(), *to)?
257 } else {
258 TcpStream::connect(addr)?
259 };
260
261 tcp.set_nonblocking(true).unwrap();
263 self.connect_bio(tcp)
264 }
265
266 pub fn connect_with_timeout<A>(
269 self,
270 addr: A,
271 timeout: Option<Duration>,
272 ) -> SshResult<SessionConnector<TcpStream>>
273 where
274 A: ToSocketAddrs,
275 {
276 let tcp = if let Some(ref to) = timeout {
278 TcpStream::connect_timeout(&addr.to_socket_addrs()?.next().unwrap(), *to)?
279 } else {
280 TcpStream::connect(addr)?
281 };
282
283 tcp.set_nonblocking(true).unwrap();
285 self.connect_bio(tcp)
286 }
287
288 pub fn connect_bio<S>(mut self, stream: S) -> SshResult<SessionConnector<S>>
293 where
294 S: Read + Write,
295 {
296 self.config.tune_alglist_on_private_key();
297 SessionConnector {
298 inner: SessionState::Init(self.config, stream),
299 }
300 .connect()
301 }
302}