ssh_test_server/builder.rs
1use crate::session::SshConnection;
2use crate::user::User;
3use crate::{SshExecuteHandler, SshServer};
4use anyhow::Result;
5use rand::Rng;
6use random_port::{PortPicker, Protocol};
7use russh::{server, MethodSet};
8use russh_keys::key;
9use russh_keys::key::KeyPair;
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::net::TcpListener;
15use tracing::debug;
16
17/// Builder for the ssh server.
18///
19/// # Example
20///
21/// ```
22/// use ssh_test_server::{SshServerBuilder, User};
23///
24/// # #[tokio::main(flavor = "current_thread")]
25/// # async fn main() {
26/// let _ssh = SshServerBuilder::default()
27/// .add_user(User::new_admin("root", "pass123"))
28/// .run()
29/// .await
30/// .unwrap();
31/// # }
32/// ```
33#[derive(Default)]
34pub struct SshServerBuilder {
35 port: Option<u16>,
36 bind_addr: Option<String>,
37 users: Vec<User>,
38 programs: HashMap<String, Box<SshExecuteHandler>>,
39}
40
41impl SshServerBuilder {
42 /// Add a user to ssh server.
43 ///
44 /// # Example
45 ///
46 /// ```
47 /// # use ssh_test_server::{SshServerBuilder, User};
48 /// # #[tokio::main(flavor = "current_thread")]
49 /// # async fn main() {
50 /// let _ssh = SshServerBuilder::default()
51 /// .add_user(User::new_admin("root", "pass"))
52 /// .add_user(User::new("luiza", "obrazy"))
53 /// .run()
54 /// .await
55 /// .unwrap();
56 /// # }
57 /// ```
58 pub fn add_user(mut self, user: User) -> Self {
59 self.users.push(user);
60 self
61 }
62
63 /// Add list of users.
64 ///
65 ///
66 /// ```
67 /// # use ssh_test_server::{SshServerBuilder, User};
68 /// # #[tokio::main(flavor = "current_thread")]
69 /// # async fn main() {
70 /// let users = vec![User::new("a", "p"), User::new("b", "p")];
71 /// let _ssh = SshServerBuilder::default()
72 /// .add_users(&users)
73 /// .run()
74 /// .await
75 /// .unwrap();
76 /// # }
77 /// ```
78 pub fn add_users(mut self, users: &[User]) -> Self {
79 for u in users {
80 self.users.push(u.clone());
81 }
82 self
83 }
84
85 /// Add custom command/program.
86 ///
87 /// # Example
88 ///
89 /// ```
90 /// # use ssh_test_server::{SshExecuteContext, SshServerBuilder, SshExecuteResult, User};
91 /// fn cmd_print_message(
92 /// context: &SshExecuteContext,
93 /// program: &str,
94 /// args: &[&str],
95 /// ) -> SshExecuteResult {
96 /// let stdout = format!(
97 /// "Program {program} run by {} has {} args.",
98 /// context.current_user,
99 /// args.len()
100 /// );
101 /// SshExecuteResult::stdout(0, stdout)
102 /// }
103 ///
104 /// # #[tokio::main(flavor = "current_thread")]
105 /// # async fn main() {
106 /// let _ssh = SshServerBuilder::default()
107 /// .add_program("print_message", Box::new(cmd_print_message))
108 /// .run()
109 /// .await
110 /// .unwrap();
111 /// # }
112 /// ```
113 pub fn add_program(mut self, program: &str, handler: Box<SshExecuteHandler>) -> Self {
114 self.programs.insert(program.to_string(), handler);
115 self
116 }
117
118 /// Listen on address.
119 ///
120 /// # Example
121 ///
122 /// ```
123 /// # use ssh_test_server::SshServerBuilder;
124 /// # #[tokio::main(flavor = "current_thread")]
125 /// # async fn main() {
126 /// let _ssh = SshServerBuilder::default()
127 /// .bind_addr("127.0.0.1")
128 /// .run()
129 /// .await
130 /// .unwrap();
131 /// # }
132 pub fn bind_addr(mut self, bind_addr: &str) -> Self {
133 self.bind_addr = Some(bind_addr.to_string());
134 self
135 }
136
137 /// Listen on port.
138 ///
139 /// # Example
140 ///
141 /// ```no_run
142 /// # use ssh_test_server::SshServerBuilder;
143 /// # #[tokio::main(flavor = "current_thread")]
144 /// # async fn main() {
145 /// let _ssh = SshServerBuilder::default()
146 /// .port(9992)
147 /// .run()
148 /// .await
149 /// .unwrap();
150 /// # }
151 pub fn port(mut self, port: u16) -> Self {
152 self.port = Some(port);
153 self
154 }
155
156 /// Build and run the ssh server.
157 ///
158 /// Server stops when [SshServer] is dropped.
159 pub async fn run(self) -> Result<SshServer> {
160 let host = self
161 .bind_addr
162 .clone()
163 .unwrap_or_else(|| "127.0.0.1".to_string());
164
165 let port_picker = PortPicker::new()
166 .host(host.clone())
167 .protocol(Protocol::Tcp)
168 .random(true);
169 let port = self.port.unwrap_or_else(|| {
170 port_picker.pick().unwrap_or_else(|_| {
171 let mut rng = rand::thread_rng();
172 rng.gen_range(15000..55000)
173 })
174 });
175 let addr = format!("{host}:{port}");
176
177 let server_keys = KeyPair::generate_ed25519();
178 let server_public_key = server_keys.clone_public_key()?;
179
180 let mut config = server::Config {
181 methods: MethodSet::PASSWORD,
182 auth_rejection_time: Duration::from_secs(0),
183 ..Default::default()
184 };
185 config.preferred.key = Cow::Borrowed(&[key::ED25519]);
186 config.keys.push(server_keys);
187 let config = Arc::new(config);
188 let users: Arc<Mutex<HashMap<String, User>>> = Arc::new(Mutex::new(
189 self.users
190 .into_iter()
191 .map(|u| (u.login().to_string(), u))
192 .collect(),
193 ));
194
195 let socket = TcpListener::bind(addr).await?;
196 let users2 = users.clone();
197
198 let listener = tokio::spawn(async move {
199 let programs = Arc::new(self.programs);
200 let mut id = 0;
201 while let Ok((socket, addr)) = socket.accept().await {
202 let config = config.clone();
203 debug!("New connection from {addr:?}");
204 let s = SshConnection::new(id, users2.clone(), programs.clone());
205 tokio::spawn(server::run_stream(config, socket, s));
206 id += 1;
207 }
208 debug!("ssh server stopped");
209 });
210
211 Ok(SshServer {
212 listener,
213 users,
214 port,
215 host,
216 server_public_key,
217 })
218 }
219}