1use std::time::Duration;
4
5use crate::{Credential, Endpoint, Error, HostKeyErrorKind, Identity, Result, Username};
6
7#[non_exhaustive]
9#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
10#[derive(Clone, Debug, Default, Eq, PartialEq)]
11pub enum HostKeyPolicy {
12 #[default]
15 Strict,
16 InsecureAcceptAny,
21 PinnedSha256(Vec<HostKeyFingerprint>),
23}
24
25impl HostKeyPolicy {
26 pub fn pinned_sha256(fingerprint: impl Into<String>) -> Result<Self> {
28 Ok(Self::PinnedSha256(vec![HostKeyFingerprint::sha256(
29 fingerprint,
30 )?]))
31 }
32
33 pub fn accepts_any(&self) -> bool {
35 matches!(self, Self::InsecureAcceptAny)
36 }
37}
38
39#[non_exhaustive]
41#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
42#[derive(Clone, Debug, Eq, Hash, PartialEq)]
43pub struct HostKeyFingerprint {
44 algorithm: HostKeyFingerprintAlgorithm,
45 value: String,
46}
47
48impl HostKeyFingerprint {
49 pub fn sha256(value: impl Into<String>) -> Result<Self> {
51 let value = value.into();
52 validate_sha256_fingerprint(&value)?;
53 Ok(Self {
54 algorithm: HostKeyFingerprintAlgorithm::Sha256,
55 value,
56 })
57 }
58
59 pub fn algorithm(&self) -> HostKeyFingerprintAlgorithm {
61 self.algorithm
62 }
63
64 pub fn value(&self) -> &str {
66 &self.value
67 }
68}
69
/// Hash algorithms a [`HostKeyFingerprint`] can be computed with.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum HostKeyFingerprintAlgorithm {
    /// SHA-256; fingerprints are written as `SHA256:<digest>`.
    Sha256,
}
78
79#[non_exhaustive]
81#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
82#[derive(Clone, Debug, Eq, PartialEq)]
83pub struct ClientConfig {
84 endpoint: Endpoint,
85 username: Option<Username>,
86 #[cfg_attr(feature = "serde", serde(skip))]
87 credentials: Vec<Credential>,
88 timeouts: Timeouts,
89 keepalive: Keepalive,
90 host_key_policy: HostKeyPolicy,
91}
92
93impl ClientConfig {
94 pub fn new(endpoint: impl Into<Endpoint>) -> Self {
96 Self {
97 endpoint: endpoint.into(),
98 username: None,
99 credentials: Vec::new(),
100 timeouts: Timeouts::default(),
101 keepalive: Keepalive::default(),
102 host_key_policy: HostKeyPolicy::default(),
103 }
104 }
105
106 pub fn endpoint(&self) -> &Endpoint {
108 &self.endpoint
109 }
110
111 pub fn set_endpoint(&mut self, endpoint: impl Into<Endpoint>) {
113 self.endpoint = endpoint.into();
114 }
115
116 pub fn username(&self) -> Option<&Username> {
118 self.username.as_ref()
119 }
120
121 pub fn set_username(&mut self, username: impl Into<Username>) {
123 self.username = Some(username.into());
124 }
125
126 pub fn credentials(&self) -> &[Credential] {
128 &self.credentials
129 }
130
131 pub fn add_credential(&mut self, credential: Credential) {
133 self.credentials.push(credential);
134 }
135
136 pub fn use_agent(&mut self) {
138 self.add_credential(Credential::identity(Identity::agent()));
139 }
140
141 pub fn timeouts(&self) -> &Timeouts {
143 &self.timeouts
144 }
145
146 pub fn set_timeouts(&mut self, timeouts: Timeouts) {
148 self.timeouts = timeouts;
149 }
150
151 pub fn keepalive(&self) -> &Keepalive {
153 &self.keepalive
154 }
155
156 pub fn set_keepalive(&mut self, keepalive: Keepalive) {
158 self.keepalive = keepalive;
159 }
160
161 pub fn strict_host_key_checking(&self) -> bool {
163 !self.host_key_policy.accepts_any()
164 }
165
166 #[deprecated = "use set_host_key_policy instead"]
168 pub fn set_strict_host_key_checking(&mut self, enabled: bool) {
169 self.host_key_policy = if enabled {
170 HostKeyPolicy::Strict
171 } else {
172 HostKeyPolicy::InsecureAcceptAny
173 };
174 }
175
176 pub fn host_key_policy(&self) -> &HostKeyPolicy {
178 &self.host_key_policy
179 }
180
181 pub fn set_host_key_policy(&mut self, policy: HostKeyPolicy) {
183 self.host_key_policy = policy;
184 }
185}
186
187impl Default for ClientConfig {
188 fn default() -> Self {
189 Self::new(Endpoint::default())
190 }
191}
192
193#[non_exhaustive]
195#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
196#[derive(Clone, Debug, Eq, PartialEq)]
197pub struct ServerConfig {
198 listen: Endpoint,
199 server_id: String,
200 max_sessions: usize,
201}
202
203impl ServerConfig {
204 pub fn new(listen: impl Into<Endpoint>) -> Self {
206 Self {
207 listen: listen.into(),
208 server_id: "SSH-2.0-russh-extra".to_owned(),
209 max_sessions: 1024,
210 }
211 }
212
213 pub fn listen(&self) -> &Endpoint {
215 &self.listen
216 }
217
218 pub fn set_listen(&mut self, listen: impl Into<Endpoint>) {
220 self.listen = listen.into();
221 }
222
223 pub fn server_id(&self) -> &str {
225 &self.server_id
226 }
227
228 pub fn set_server_id(&mut self, server_id: impl Into<String>) {
230 self.server_id = server_id.into();
231 }
232
233 pub fn max_sessions(&self) -> usize {
235 self.max_sessions
236 }
237
238 pub fn set_max_sessions(&mut self, max_sessions: usize) {
240 self.max_sessions = max_sessions;
241 }
242}
243
244impl Default for ServerConfig {
245 fn default() -> Self {
246 Self::new(("127.0.0.1", 0))
247 }
248}
249
/// Time limits for the connection phases.
///
/// Durations are stored exactly as given; zero is not rejected.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Timeouts {
    /// Limit for establishing the connection.
    connect: Duration,
    /// Limit for authentication.
    auth: Duration,
    /// Limit for opening a channel.
    channel_open: Duration,
}
259
260impl Default for Timeouts {
261 fn default() -> Self {
262 Self {
263 connect: Duration::from_secs(30),
264 auth: Duration::from_secs(30),
265 channel_open: Duration::from_secs(10),
266 }
267 }
268}
269
270impl Timeouts {
271 pub fn new(connect: Duration, auth: Duration, channel_open: Duration) -> Self {
273 Self {
274 connect,
275 auth,
276 channel_open,
277 }
278 }
279
280 pub fn connect(&self) -> Duration {
282 self.connect
283 }
284
285 pub fn auth(&self) -> Duration {
287 self.auth
288 }
289
290 pub fn channel_open(&self) -> Duration {
292 self.channel_open
293 }
294}
295
/// Keepalive settings for a connection.
///
/// `interval` and `max_missed` are kept even when `enabled` is `false`.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Keepalive {
    /// Whether keepalives are turned on.
    enabled: bool,
    /// Configured keepalive interval.
    interval: Duration,
    /// Configured `max_missed` count.
    max_missed: u32,
}
305
306impl Keepalive {
307 pub fn new(enabled: bool, interval: Duration, max_missed: u32) -> Self {
309 Self {
310 enabled,
311 interval,
312 max_missed,
313 }
314 }
315
316 pub fn enabled(&self) -> bool {
318 self.enabled
319 }
320
321 pub fn interval(&self) -> Duration {
323 self.interval
324 }
325
326 pub fn max_missed(&self) -> u32 {
328 self.max_missed
329 }
330}
331
332impl Default for Keepalive {
333 fn default() -> Self {
334 Self {
335 enabled: true,
336 interval: Duration::from_secs(30),
337 max_missed: 3,
338 }
339 }
340}
341
342fn validate_sha256_fingerprint(value: &str) -> Result<()> {
343 let Some(rest) = value.strip_prefix("SHA256:") else {
344 return Err(Error::host_key(
345 HostKeyErrorKind::Unsupported,
346 "host-key fingerprint must start with SHA256:",
347 ));
348 };
349
350 if rest.is_empty() {
351 return Err(Error::host_key(
352 HostKeyErrorKind::Unavailable,
353 "host-key fingerprint must not be empty",
354 ));
355 }
356
357 if rest.bytes().any(|byte| byte.is_ascii_whitespace()) {
358 return Err(Error::host_key(
359 HostKeyErrorKind::Rejected,
360 "host-key fingerprint must not contain whitespace",
361 ));
362 }
363
364 Ok(())
365}
366
#[cfg(test)]
mod tests {
    use crate::{
        ClientConfig, Endpoint, Error, HostKeyFingerprint, HostKeyFingerprintAlgorithm,
        HostKeyPolicy,
    };

    #[test]
    fn server_config_defaults_to_loopback_ephemeral_port() {
        let config = crate::ServerConfig::default();

        assert_eq!(config.listen(), &Endpoint::new("127.0.0.1", 0));
    }

    #[test]
    fn client_config_defaults_to_strict_host_key_policy() {
        let config = ClientConfig::default();

        assert_eq!(config.host_key_policy(), &HostKeyPolicy::Strict);
        assert!(config.strict_host_key_checking());
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated bool setter on purpose
    fn disabling_strict_host_key_checking_sets_accept_any_policy() {
        let mut config = ClientConfig::default();

        config.set_strict_host_key_checking(false);

        assert_eq!(config.host_key_policy(), &HostKeyPolicy::InsecureAcceptAny);
        assert!(!config.strict_host_key_checking());
    }

    #[test]
    fn pinned_policy_still_counts_as_strict_host_key_checking() {
        let mut config = ClientConfig::default();

        config.set_host_key_policy(HostKeyPolicy::pinned_sha256("SHA256:abc123").unwrap());

        assert!(config.strict_host_key_checking());
    }

    #[test]
    fn validates_sha256_host_key_fingerprints() {
        let fingerprint = HostKeyFingerprint::sha256("SHA256:abc123+/=").unwrap();

        assert_eq!(fingerprint.algorithm(), HostKeyFingerprintAlgorithm::Sha256);
        assert_eq!(fingerprint.value(), "SHA256:abc123+/=");
    }

    #[test]
    fn rejects_invalid_sha256_host_key_fingerprints() {
        let error = HostKeyFingerprint::sha256("MD5:abc").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));

        let error = HostKeyFingerprint::sha256("SHA256:").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));
    }

    #[test]
    fn rejects_sha256_host_key_fingerprints_with_whitespace() {
        let error = HostKeyFingerprint::sha256("SHA256:abc 123").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));

        let error = HostKeyFingerprint::sha256("SHA256:abc123\n").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));
    }

    #[test]
    fn pinned_sha256_policy_pins_exactly_one_fingerprint() {
        let policy = HostKeyPolicy::pinned_sha256("SHA256:abc123").unwrap();
        let expected = HostKeyFingerprint::sha256("SHA256:abc123").unwrap();

        assert_eq!(policy, HostKeyPolicy::PinnedSha256(vec![expected]));
        assert!(!policy.accepts_any());
    }

    #[test]
    fn pinned_sha256_policy_rejects_invalid_fingerprint() {
        let error = HostKeyPolicy::pinned_sha256("SHA256:").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));
    }

    #[test]
    #[cfg(feature = "serde")]
    fn client_config_serialization_skips_credentials() {
        let mut config = ClientConfig::new(Endpoint::new("example.com", 2222));
        config.add_credential(crate::Credential::password("secret"));

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: ClientConfig = serde_json::from_str(&serialized).unwrap();

        assert!(!serialized.contains("secret"));
        assert!(!serialized.contains("credentials"));
        assert!(deserialized.credentials().is_empty());
    }

    #[test]
    fn client_config_debug_does_not_expose_credential_content() {
        let mut config = ClientConfig::new(Endpoint::new("example.com", 2222));
        config.add_credential(crate::Credential::password("my-secret-password"));
        let debug = format!("{:?}", config);
        assert!(!debug.contains("my-secret-password"));
        assert!(debug.contains("Password(***)"));
    }

    #[test]
    fn keepalive_defaults_enabled_with_30s_interval() {
        let k = crate::Keepalive::default();
        assert!(k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(30));
        assert_eq!(k.max_missed(), 3);
    }

    #[test]
    fn keepalive_new_stores_fields() {
        let k = crate::Keepalive::new(true, std::time::Duration::from_secs(15), 5);
        assert!(k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(15));
        assert_eq!(k.max_missed(), 5);
    }

    #[test]
    fn keepalive_disabled_still_stores_interval() {
        let k = crate::Keepalive::new(false, std::time::Duration::from_secs(5), 1);
        assert!(!k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(5));
    }

    #[test]
    fn timeouts_new_stores_fields() {
        use std::time::Duration;
        let t = crate::Timeouts::new(
            Duration::from_secs(5),
            Duration::from_secs(10),
            Duration::from_secs(2),
        );
        assert_eq!(t.connect(), Duration::from_secs(5));
        assert_eq!(t.auth(), Duration::from_secs(10));
        assert_eq!(t.channel_open(), Duration::from_secs(2));
    }

    #[test]
    fn timeouts_defaults_are_reasonable() {
        let t = crate::Timeouts::default();
        assert!(t.connect() > std::time::Duration::ZERO);
        assert!(t.auth() > std::time::Duration::ZERO);
        assert!(t.channel_open() > std::time::Duration::ZERO);
    }

    #[test]
    fn timeouts_with_zero_durations_stores_them() {
        use std::time::Duration;
        let t = crate::Timeouts::new(Duration::ZERO, Duration::ZERO, Duration::ZERO);
        assert_eq!(t.connect(), Duration::ZERO);
        assert_eq!(t.auth(), Duration::ZERO);
        assert_eq!(t.channel_open(), Duration::ZERO);
    }
}