ssh_test_server/
builder.rs

1use crate::session::SshConnection;
2use crate::user::User;
3use crate::{SshExecuteHandler, SshServer};
4use anyhow::Result;
5use rand::Rng;
6use random_port::{PortPicker, Protocol};
7use russh::{server, MethodSet};
8use russh_keys::key;
9use russh_keys::key::KeyPair;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::net::TcpListener;
15use tracing::debug;
16
17/// Builder for the ssh server.
18///
19/// # Example
20///
21/// ```
22/// use ssh_test_server::{SshServerBuilder, User};
23///
24/// # #[tokio::main(flavor = "current_thread")]
25/// # async fn main() {
26/// let _ssh = SshServerBuilder::default()
27///     .add_user(User::new_admin("root", "pass123"))
28///     .run()
29///     .await
30///     .unwrap();
31/// # }
32/// ```
33#[derive(Default)]
34pub struct SshServerBuilder {
35    port: Option<u16>,
36    bind_addr: Option<String>,
37    users: Vec<User>,
38    programs: HashMap<String, Box<SshExecuteHandler>>,
39}
40
41impl SshServerBuilder {
42    /// Add a user to ssh server.
43    ///
44    /// # Example
45    ///
46    /// ```
47    /// # use ssh_test_server::{SshServerBuilder, User};
48    /// # #[tokio::main(flavor = "current_thread")]
49    /// # async fn main() {
50    /// let _ssh = SshServerBuilder::default()
51    ///     .add_user(User::new_admin("root", "pass"))
52    ///     .add_user(User::new("luiza", "obrazy"))
53    ///     .run()
54    ///     .await
55    ///     .unwrap();
56    /// # }
57    /// ```
58    pub fn add_user(mut self, user: User) -> Self {
59        self.users.push(user);
60        self
61    }
62
63    /// Add list of users.
64    ///
65    ///
66    /// ```
67    /// # use ssh_test_server::{SshServerBuilder, User};
68    /// # #[tokio::main(flavor = "current_thread")]
69    /// # async fn main() {
70    /// let users = vec![User::new("a", "p"), User::new("b", "p")];
71    /// let _ssh = SshServerBuilder::default()
72    ///     .add_users(&users)
73    ///     .run()
74    ///     .await
75    ///     .unwrap();
76    /// # }
77    /// ```
78    pub fn add_users(mut self, users: &[User]) -> Self {
79        for u in users {
80            self.users.push(u.clone());
81        }
82        self
83    }
84
85    /// Add custom command/program.
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// # use ssh_test_server::{SshExecuteContext, SshServerBuilder, SshExecuteResult, User};
91    /// fn cmd_print_message(
92    ///     context: &SshExecuteContext,
93    ///     program: &str,
94    ///     args: &[&str],
95    /// ) -> SshExecuteResult {
96    ///     let stdout = format!(
97    ///         "Program {program} run by {} has {} args.",
98    ///         context.current_user,
99    ///         args.len()
100    ///     );
101    ///     SshExecuteResult::stdout(0, stdout)
102    /// }
103    ///
104    /// # #[tokio::main(flavor = "current_thread")]
105    /// # async fn main() {
106    /// let _ssh = SshServerBuilder::default()
107    ///     .add_program("print_message", Box::new(cmd_print_message))
108    ///     .run()
109    ///     .await
110    ///     .unwrap();
111    /// # }
112    /// ```
113    pub fn add_program(mut self, program: &str, handler: Box<SshExecuteHandler>) -> Self {
114        self.programs.insert(program.to_string(), handler);
115        self
116    }
117
118    /// Listen on address.
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// # use ssh_test_server::SshServerBuilder;
124    /// # #[tokio::main(flavor = "current_thread")]
125    /// # async fn main() {
126    /// let _ssh = SshServerBuilder::default()
127    ///     .bind_addr("127.0.0.1")
128    ///     .run()
129    ///     .await
130    ///     .unwrap();
131    /// # }
132    pub fn bind_addr(mut self, bind_addr: &str) -> Self {
133        self.bind_addr = Some(bind_addr.to_string());
134        self
135    }
136
137    /// Listen on port.
138    ///
139    /// # Example
140    ///
141    /// ```no_run
142    /// # use ssh_test_server::SshServerBuilder;
143    /// # #[tokio::main(flavor = "current_thread")]
144    /// # async fn main() {
145    /// let _ssh = SshServerBuilder::default()
146    ///     .port(9992)
147    ///     .run()
148    ///     .await
149    ///     .unwrap();
150    /// # }
151    pub fn port(mut self, port: u16) -> Self {
152        self.port = Some(port);
153        self
154    }
155
156    /// Build and run the ssh server.
157    ///
158    /// Server stops when [SshServer] is dropped.
159    pub async fn run(self) -> Result<SshServer> {
160        let host = self
161            .bind_addr
162            .clone()
163            .unwrap_or_else(|| "127.0.0.1".to_string());
164
165        let port_picker = PortPicker::new()
166            .host(host.clone())
167            .protocol(Protocol::Tcp)
168            .random(true);
169        let port = self.port.unwrap_or_else(|| {
170            port_picker.pick().unwrap_or_else(|_| {
171                let mut rng = rand::thread_rng();
172                rng.gen_range(15000..55000)
173            })
174        });
175        let addr = format!("{host}:{port}");
176
177        let server_keys = KeyPair::generate_ed25519();
178        let server_public_key = server_keys.clone_public_key()?;
179
180        let mut config = server::Config {
181            methods: MethodSet::PASSWORD,
182            auth_rejection_time: Duration::from_secs(0),
183            ..Default::default()
184        };
185        config.preferred.key = Cow::Borrowed(&[key::ED25519]);
186        config.keys.push(server_keys);
187        let config = Arc::new(config);
188        let users: Arc<Mutex<HashMap<String, User>>> = Arc::new(Mutex::new(
189            self.users
190                .into_iter()
191                .map(|u| (u.login().to_string(), u))
192                .collect(),
193        ));
194
195        let socket = TcpListener::bind(addr).await?;
196        let users2 = users.clone();
197
198        let listener = tokio::spawn(async move {
199            let programs = Arc::new(self.programs);
200            let mut id = 0;
201            while let Ok((socket, addr)) = socket.accept().await {
202                let config = config.clone();
203                debug!("New connection from {addr:?}");
204                let s = SshConnection::new(id, users2.clone(), programs.clone());
205                tokio::spawn(server::run_stream(config, socket, s));
206                id += 1;
207            }
208            debug!("ssh server stopped");
209        });
210
211        Ok(SshServer {
212            listener,
213            users,
214            port,
215            host,
216            server_public_key,
217        })
218    }
219}