1use std::collections::HashMap;
6
7use super::{Duration, PathBuf};
8
/// Connection parameters collected for a single host.
///
/// Every option is `None` until something sets it (see the derived
/// `Default`). This lets [`HostParams::merge`] tell "not configured" apart
/// from an explicit value: a `None` in the merged-in params never clears a
/// value that is already set.
///
/// The algorithm-list fields are combined by `merge` through
/// `resolve_algorithms`, so a leading `+` or `-` on the first entry edits the
/// inherited list instead of replacing it. These fields are
/// `ca_signature_algorithms`, `ciphers`, `host_key_algorithms`,
/// `kex_algorithms`, `mac` and `pubkey_accepted_algorithms`.
///
/// Field names appear to mirror ssh_config(5) keywords. The keyword noted on
/// each field is inferred from its name only.
/// NOTE(review): confirm against the parser that populates this struct.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct HostParams {
    /// `BindAddress`.
    pub bind_address: Option<String>,
    /// `BindInterface`.
    pub bind_interface: Option<String>,
    /// `CASignatureAlgorithms` (algorithm list).
    pub ca_signature_algorithms: Option<Vec<String>>,
    /// `CertificateFile`.
    pub certificate_file: Option<PathBuf>,
    /// `Ciphers` (algorithm list).
    pub ciphers: Option<Vec<String>>,
    /// `Compression`.
    pub compression: Option<bool>,
    /// `ConnectionAttempts`.
    pub connection_attempts: Option<usize>,
    /// `ConnectTimeout`.
    pub connect_timeout: Option<Duration>,
    /// `HostKeyAlgorithms` (algorithm list).
    pub host_key_algorithms: Option<Vec<String>>,
    /// `HostName`.
    pub host_name: Option<String>,
    /// `IdentityFile`. On merge the whole list is replaced, not appended to.
    pub identity_file: Option<Vec<PathBuf>>,
    /// `IgnoreUnknown`. `ignored` tests option names against this list.
    pub ignore_unknown: Option<Vec<String>>,
    /// `KexAlgorithms` (algorithm list).
    pub kex_algorithms: Option<Vec<String>>,
    /// `MACs` (algorithm list).
    pub mac: Option<Vec<String>>,
    /// `Port`.
    pub port: Option<u16>,
    /// `PubkeyAcceptedAlgorithms` (algorithm list).
    pub pubkey_accepted_algorithms: Option<Vec<String>>,
    /// `PubkeyAuthentication`.
    pub pubkey_authentication: Option<bool>,
    /// `RemoteForward`. Only a port number is stored here.
    pub remote_forward: Option<u16>,
    /// `ServerAliveInterval`.
    pub server_alive_interval: Option<Duration>,
    /// `TCPKeepAlive`.
    pub tcp_keep_alive: Option<bool>,
    /// `UseKeychain`. This field only exists when compiling for macOS.
    #[cfg(target_os = "macos")]
    pub use_keychain: Option<bool>,
    /// `User`.
    pub user: Option<String>,
    /// Options set aside instead of applied, presumably keyed by option name
    /// with their raw arguments. On merge, keys already present are kept.
    /// NOTE(review): confirm with the parser what makes an option "ignored".
    pub ignored_fields: HashMap<String, Vec<String>>,
    /// Options the crate does not support, presumably keyed by option name
    /// with their raw arguments. On merge, keys already present are kept.
    pub unsupported_fields: HashMap<String, Vec<String>>,
}
66
67impl HostParams {
68 pub(crate) fn ignored(&self, param: &str) -> bool {
70 self.ignore_unknown
71 .as_ref()
72 .map(|x| x.iter().any(|x| x.as_str() == param))
73 .unwrap_or(false)
74 }
75
76 pub fn merge(&mut self, b: &Self) {
78 if let Some(bind_address) = b.bind_address.as_deref() {
79 self.bind_address = Some(bind_address.to_owned());
80 }
81 if let Some(bind_interface) = b.bind_interface.as_deref() {
82 self.bind_interface = Some(bind_interface.to_owned());
83 }
84 if let Some(ca_signature_algorithms) = b.ca_signature_algorithms.as_deref() {
85 if self.ca_signature_algorithms.is_none() {
86 self.ca_signature_algorithms = Some(Vec::new());
87 }
88 Self::resolve_algorithms(
89 self.ca_signature_algorithms.as_mut().unwrap(),
90 ca_signature_algorithms,
91 );
92 }
93 if let Some(certificate_file) = b.certificate_file.as_deref() {
94 self.certificate_file = Some(certificate_file.to_owned());
95 }
96 if let Some(ciphers) = b.ciphers.as_deref() {
97 if self.ciphers.is_none() {
98 self.ciphers = Some(Vec::new());
99 }
100 Self::resolve_algorithms(self.ciphers.as_mut().unwrap(), ciphers);
101 }
102 if let Some(compression) = b.compression {
103 self.compression = Some(compression);
104 }
105 if let Some(connection_attempts) = b.connection_attempts {
106 self.connection_attempts = Some(connection_attempts);
107 }
108 if let Some(connect_timeout) = b.connect_timeout {
109 self.connect_timeout = Some(connect_timeout);
110 }
111 if let Some(host_key_algorithms) = b.host_key_algorithms.as_deref() {
112 if self.host_key_algorithms.is_none() {
113 self.host_key_algorithms = Some(Vec::new());
114 }
115 Self::resolve_algorithms(
116 self.host_key_algorithms.as_mut().unwrap(),
117 host_key_algorithms,
118 );
119 }
120 if let Some(host_name) = b.host_name.as_deref() {
121 self.host_name = Some(host_name.to_owned());
122 }
123 if let Some(identity_file) = b.identity_file.as_deref() {
124 self.identity_file = Some(identity_file.to_owned());
125 }
126 if let Some(ignore_unknown) = b.ignore_unknown.as_deref() {
127 self.ignore_unknown = Some(ignore_unknown.to_owned());
128 }
129 if let Some(kex_algorithms) = b.kex_algorithms.as_deref() {
130 if self.kex_algorithms.is_none() {
131 self.kex_algorithms = Some(Vec::new());
132 }
133 Self::resolve_algorithms(self.kex_algorithms.as_mut().unwrap(), kex_algorithms);
134 }
135 if let Some(mac) = b.mac.as_deref() {
136 if self.mac.is_none() {
137 self.mac = Some(Vec::new());
138 }
139 Self::resolve_algorithms(self.mac.as_mut().unwrap(), mac);
140 }
141 if let Some(port) = b.port {
142 self.port = Some(port);
143 }
144 if let Some(pubkey_accepted_algorithms) = b.pubkey_accepted_algorithms.as_deref() {
145 if self.pubkey_accepted_algorithms.is_none() {
146 self.pubkey_accepted_algorithms = Some(Vec::new());
147 }
148 Self::resolve_algorithms(
149 self.pubkey_accepted_algorithms.as_mut().unwrap(),
150 pubkey_accepted_algorithms,
151 );
152 }
153 if let Some(pubkey_authentication) = b.pubkey_authentication {
154 self.pubkey_authentication = Some(pubkey_authentication);
155 }
156 if let Some(remote_forward) = b.remote_forward {
157 self.remote_forward = Some(remote_forward);
158 }
159 if let Some(server_alive_interval) = b.server_alive_interval {
160 self.server_alive_interval = Some(server_alive_interval);
161 }
162 if let Some(tcp_keep_alive) = b.tcp_keep_alive {
163 self.tcp_keep_alive = Some(tcp_keep_alive);
164 }
165 #[cfg(target_os = "macos")]
166 if let Some(use_keychain) = b.use_keychain {
167 self.use_keychain = Some(use_keychain);
168 }
169 if let Some(user) = b.user.as_deref() {
170 self.user = Some(user.to_owned());
171 }
172 for (ignored_field, args) in &b.ignored_fields {
173 if !self.ignored_fields.contains_key(ignored_field) {
174 self.ignored_fields
175 .insert(ignored_field.to_owned(), args.to_owned());
176 }
177 }
178
179 for (unsupported_field, args) in &b.unsupported_fields {
180 if !self.unsupported_fields.contains_key(unsupported_field) {
181 self.unsupported_fields
182 .insert(unsupported_field.to_owned(), args.to_owned());
183 }
184 }
185 }
186
187 fn resolve_algorithms(current_list: &mut Vec<String>, algos: &[String]) {
192 if algos.is_empty() {
193 return;
194 }
195 let first = algos.first().unwrap();
196 if first.starts_with('+') {
197 for algo in [first.replacen('+', "", 1)].iter().chain(algos[1..].iter()) {
199 if !current_list.contains(algo) {
200 current_list.push(algo.to_owned());
201 }
202 }
203 } else if first.starts_with('-') {
204 let new_first = [first.replacen('-', "", 1)];
206 current_list.retain(|algo| {
208 !new_first
209 .iter()
210 .chain(algos[1..].iter())
211 .any(|remove| remove == algo)
212 });
213 } else {
214 *current_list = algos.to_vec();
215 }
216 }
217}
218
#[cfg(test)]
mod test {

    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds an owned `Vec<String>` from string slices to keep fixtures short.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn should_initialize_params() {
        let params = HostParams::default();
        assert!(params.bind_address.is_none());
        assert!(params.bind_interface.is_none());
        assert!(params.ca_signature_algorithms.is_none());
        assert!(params.certificate_file.is_none());
        assert!(params.ciphers.is_none());
        assert!(params.compression.is_none());
        assert!(params.connection_attempts.is_none());
        assert!(params.connect_timeout.is_none());
        assert!(params.host_key_algorithms.is_none());
        assert!(params.host_name.is_none());
        assert!(params.identity_file.is_none());
        assert!(params.ignore_unknown.is_none());
        assert!(params.kex_algorithms.is_none());
        assert!(params.mac.is_none());
        assert!(params.port.is_none());
        assert!(params.pubkey_accepted_algorithms.is_none());
        assert!(params.pubkey_authentication.is_none());
        assert!(params.remote_forward.is_none());
        assert!(params.server_alive_interval.is_none());
        #[cfg(target_os = "macos")]
        assert!(params.use_keychain.is_none());
        assert!(params.tcp_keep_alive.is_none());
        assert!(params.user.is_none());
        assert!(params.ignored_fields.is_empty());
        assert!(params.unsupported_fields.is_empty());
    }

    #[test]
    fn should_merge_params() {
        let mut params = HostParams::default();
        let mut b = HostParams {
            bind_address: Some(String::from("pippo")),
            bind_interface: Some(String::from("tun0")),
            ca_signature_algorithms: Some(vec![]),
            certificate_file: Some(PathBuf::default()),
            ciphers: Some(vec![]),
            compression: Some(true),
            connect_timeout: Some(Duration::from_secs(1)),
            connection_attempts: Some(3),
            host_key_algorithms: Some(vec![]),
            host_name: Some(String::from("192.168.1.2")),
            identity_file: Some(vec![PathBuf::default()]),
            ignore_unknown: Some(vec![]),
            kex_algorithms: Some(vec![]),
            mac: Some(vec![]),
            port: Some(22),
            pubkey_accepted_algorithms: Some(vec![]),
            pubkey_authentication: Some(true),
            remote_forward: Some(32),
            server_alive_interval: Some(Duration::from_secs(10)),
            #[cfg(target_os = "macos")]
            use_keychain: Some(true),
            tcp_keep_alive: Some(true),
            user: Some(String::from("omar")),
            ..Default::default()
        };
        params.merge(&b);
        assert!(params.bind_address.is_some());
        assert!(params.bind_interface.is_some());
        assert!(params.ca_signature_algorithms.is_some());
        assert!(params.certificate_file.is_some());
        assert!(params.ciphers.is_some());
        assert!(params.compression.is_some());
        assert!(params.connection_attempts.is_some());
        assert!(params.connect_timeout.is_some());
        assert!(params.host_key_algorithms.is_some());
        assert!(params.host_name.is_some());
        assert!(params.identity_file.is_some());
        assert!(params.ignore_unknown.is_some());
        assert!(params.kex_algorithms.is_some());
        assert!(params.mac.is_some());
        assert!(params.port.is_some());
        assert!(params.pubkey_accepted_algorithms.is_some());
        assert!(params.pubkey_authentication.is_some());
        assert!(params.remote_forward.is_some());
        assert!(params.server_alive_interval.is_some());
        #[cfg(target_os = "macos")]
        assert!(params.use_keychain.is_some());
        assert!(params.tcp_keep_alive.is_some());
        assert!(params.user.is_some());
        // A `None` in the merged-in params must not clear an existing value.
        b.tcp_keep_alive = None;
        params.merge(&b);
        assert_eq!(params.tcp_keep_alive, Some(true));
    }

    #[test]
    fn should_resolve_algorithms_against_existing_list_on_merge() {
        let mut params = HostParams {
            ciphers: Some(strings(&["a", "b", "c"])),
            ..Default::default()
        };
        let b = HostParams {
            ciphers: Some(strings(&["-b"])),
            kex_algorithms: Some(strings(&["+x"])),
            ..Default::default()
        };
        params.merge(&b);
        assert_eq!(params.ciphers, Some(strings(&["a", "c"])));
        // An unset list is treated as empty before resolving.
        assert_eq!(params.kex_algorithms, Some(strings(&["x"])));
    }

    #[test]
    fn should_keep_existing_ignored_and_unsupported_fields_on_merge() {
        let mut params = HostParams::default();
        params
            .ignored_fields
            .insert(String::from("Foo"), strings(&["old"]));
        let mut b = HostParams::default();
        b.ignored_fields
            .insert(String::from("Foo"), strings(&["new"]));
        b.ignored_fields.insert(String::from("Bar"), strings(&["1"]));
        b.unsupported_fields
            .insert(String::from("Baz"), strings(&["2"]));
        params.merge(&b);
        assert_eq!(params.ignored_fields.get("Foo"), Some(&strings(&["old"])));
        assert_eq!(params.ignored_fields.get("Bar"), Some(&strings(&["1"])));
        assert_eq!(params.unsupported_fields.get("Baz"), Some(&strings(&["2"])));
    }

    #[test]
    fn should_tell_whether_param_is_ignored() {
        let mut params = HostParams::default();
        assert!(!params.ignored("Foo"));
        params.ignore_unknown = Some(strings(&["Foo", "Bar"]));
        assert!(params.ignored("Foo"));
        assert!(params.ignored("Bar"));
        assert!(!params.ignored("foo"));
        assert!(!params.ignored("Baz"));
    }

    #[test]
    fn should_resolve_algorithms_list_when_preceded_by_plus() {
        let mut list = strings(&["a", "b", "c", "d", "e"]);
        let algos = strings(&["+1", "a", "b", "3", "d"]);
        HostParams::resolve_algorithms(&mut list, &algos);
        assert_eq!(list, strings(&["a", "b", "c", "d", "e", "1", "3"]));
    }

    #[test]
    fn should_resolve_algorithms_list_when_preceded_by_minus() {
        let mut list = strings(&["a", "b", "c", "d", "e"]);
        let algos = strings(&["-a", "b", "3"]);
        HostParams::resolve_algorithms(&mut list, &algos);
        assert_eq!(list, strings(&["c", "d", "e"]));
    }

    #[test]
    fn should_resolve_algorithm_list_when_replacing() {
        let mut list = strings(&["a", "b", "c", "d", "e"]);
        let algos = strings(&["1", "a", "b", "3", "d"]);
        HostParams::resolve_algorithms(&mut list, &algos);
        assert_eq!(list, strings(&["1", "a", "b", "3", "d"]));
    }

    #[test]
    fn should_not_change_algorithms_list_when_empty() {
        let mut list = strings(&["a", "b"]);
        HostParams::resolve_algorithms(&mut list, &[]);
        assert_eq!(list, strings(&["a", "b"]));
    }
}