// russh_extra_core/config.rs

//! Client and server configuration.

use std::time::Duration;

use crate::{Credential, Endpoint, Error, HostKeyErrorKind, Identity, Result, Username};

7/// Host-key verification policy.
8#[non_exhaustive]
9#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
10#[derive(Clone, Debug, Default, Eq, PartialEq)]
11pub enum HostKeyPolicy {
12    /// Reject host keys unless a future persistent store or verifier accepts
13    /// them. This is the default.
14    #[default]
15    Strict,
16    /// Accept every host key.
17    ///
18    /// **Insecure**: this disables host-key verification entirely. Only use
19    /// this policy in tests or controlled environments.
20    InsecureAcceptAny,
21    /// Accept only pinned SHA256 host-key fingerprints.
22    PinnedSha256(Vec<HostKeyFingerprint>),
23}
24
25impl HostKeyPolicy {
26    /// Creates a pinned SHA256 host-key policy.
27    pub fn pinned_sha256(fingerprint: impl Into<String>) -> Result<Self> {
28        Ok(Self::PinnedSha256(vec![HostKeyFingerprint::sha256(
29            fingerprint,
30        )?]))
31    }
32
33    /// Returns whether this policy accepts any host key.
34    pub fn accepts_any(&self) -> bool {
35        matches!(self, Self::InsecureAcceptAny)
36    }
37}
38
39/// Pinned host-key fingerprint.
40#[non_exhaustive]
41#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
42#[derive(Clone, Debug, Eq, Hash, PartialEq)]
43pub struct HostKeyFingerprint {
44    algorithm: HostKeyFingerprintAlgorithm,
45    value: String,
46}
47
48impl HostKeyFingerprint {
49    /// Creates a SHA256 host-key fingerprint.
50    pub fn sha256(value: impl Into<String>) -> Result<Self> {
51        let value = value.into();
52        validate_sha256_fingerprint(&value)?;
53        Ok(Self {
54            algorithm: HostKeyFingerprintAlgorithm::Sha256,
55            value,
56        })
57    }
58
59    /// Returns the fingerprint algorithm.
60    pub fn algorithm(&self) -> HostKeyFingerprintAlgorithm {
61        self.algorithm
62    }
63
64    /// Returns the fingerprint value.
65    pub fn value(&self) -> &str {
66        &self.value
67    }
68}
69
/// Host-key fingerprint algorithm.
///
/// Marked `#[non_exhaustive]`, so further algorithms can be added later
/// without breaking downstream `match` expressions.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HostKeyFingerprintAlgorithm {
    /// OpenSSH-style SHA256 host-key fingerprint.
    Sha256,
}

79/// Client connection configuration.
80#[non_exhaustive]
81#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
82#[derive(Clone, Debug, Eq, PartialEq)]
83pub struct ClientConfig {
84    endpoint: Endpoint,
85    username: Option<Username>,
86    #[cfg_attr(feature = "serde", serde(skip))]
87    credentials: Vec<Credential>,
88    timeouts: Timeouts,
89    keepalive: Keepalive,
90    host_key_policy: HostKeyPolicy,
91}
92
93impl ClientConfig {
94    /// Creates a config for the given endpoint.
95    pub fn new(endpoint: impl Into<Endpoint>) -> Self {
96        Self {
97            endpoint: endpoint.into(),
98            username: None,
99            credentials: Vec::new(),
100            timeouts: Timeouts::default(),
101            keepalive: Keepalive::default(),
102            host_key_policy: HostKeyPolicy::default(),
103        }
104    }
105
106    /// Returns the configured endpoint.
107    pub fn endpoint(&self) -> &Endpoint {
108        &self.endpoint
109    }
110
111    /// Sets the endpoint.
112    pub fn set_endpoint(&mut self, endpoint: impl Into<Endpoint>) {
113        self.endpoint = endpoint.into();
114    }
115
116    /// Returns the optional username.
117    pub fn username(&self) -> Option<&Username> {
118        self.username.as_ref()
119    }
120
121    /// Sets the username.
122    pub fn set_username(&mut self, username: impl Into<Username>) {
123        self.username = Some(username.into());
124    }
125
126    /// Returns configured credentials in preference order.
127    pub fn credentials(&self) -> &[Credential] {
128        &self.credentials
129    }
130
131    /// Adds a credential.
132    pub fn add_credential(&mut self, credential: Credential) {
133        self.credentials.push(credential);
134    }
135
136    /// Adds an SSH agent credential.
137    pub fn use_agent(&mut self) {
138        self.add_credential(Credential::identity(Identity::agent()));
139    }
140
141    /// Returns timeout settings.
142    pub fn timeouts(&self) -> &Timeouts {
143        &self.timeouts
144    }
145
146    /// Sets timeout settings.
147    pub fn set_timeouts(&mut self, timeouts: Timeouts) {
148        self.timeouts = timeouts;
149    }
150
151    /// Returns keepalive settings.
152    pub fn keepalive(&self) -> &Keepalive {
153        &self.keepalive
154    }
155
156    /// Sets keepalive settings.
157    pub fn set_keepalive(&mut self, keepalive: Keepalive) {
158        self.keepalive = keepalive;
159    }
160
161    /// Returns whether strict host key checking is enabled.
162    pub fn strict_host_key_checking(&self) -> bool {
163        !self.host_key_policy.accepts_any()
164    }
165
166    /// Sets strict host key checking.
167    #[deprecated = "use set_host_key_policy instead"]
168    pub fn set_strict_host_key_checking(&mut self, enabled: bool) {
169        self.host_key_policy = if enabled {
170            HostKeyPolicy::Strict
171        } else {
172            HostKeyPolicy::InsecureAcceptAny
173        };
174    }
175
176    /// Returns the configured host-key policy.
177    pub fn host_key_policy(&self) -> &HostKeyPolicy {
178        &self.host_key_policy
179    }
180
181    /// Sets the host-key policy.
182    pub fn set_host_key_policy(&mut self, policy: HostKeyPolicy) {
183        self.host_key_policy = policy;
184    }
185}
186
187impl Default for ClientConfig {
188    fn default() -> Self {
189        Self::new(Endpoint::default())
190    }
191}
192
193/// Server configuration.
194#[non_exhaustive]
195#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
196#[derive(Clone, Debug, Eq, PartialEq)]
197pub struct ServerConfig {
198    listen: Endpoint,
199    server_id: String,
200    max_sessions: usize,
201}
202
203impl ServerConfig {
204    /// Creates server configuration for a listen endpoint.
205    pub fn new(listen: impl Into<Endpoint>) -> Self {
206        Self {
207            listen: listen.into(),
208            server_id: "SSH-2.0-russh-extra".to_owned(),
209            max_sessions: 1024,
210        }
211    }
212
213    /// Returns the listen endpoint.
214    pub fn listen(&self) -> &Endpoint {
215        &self.listen
216    }
217
218    /// Sets the listen endpoint.
219    pub fn set_listen(&mut self, listen: impl Into<Endpoint>) {
220        self.listen = listen.into();
221    }
222
223    /// Returns the SSH identification string.
224    pub fn server_id(&self) -> &str {
225        &self.server_id
226    }
227
228    /// Sets the SSH identification string.
229    pub fn set_server_id(&mut self, server_id: impl Into<String>) {
230        self.server_id = server_id.into();
231    }
232
233    /// Returns the configured maximum session count.
234    pub fn max_sessions(&self) -> usize {
235        self.max_sessions
236    }
237
238    /// Sets the maximum session count.
239    pub fn set_max_sessions(&mut self, max_sessions: usize) {
240        self.max_sessions = max_sessions;
241    }
242}
243
244impl Default for ServerConfig {
245    fn default() -> Self {
246        Self::new(("127.0.0.1", 0))
247    }
248}
249
/// Timeout configuration.
///
/// Defaults: 30 seconds to connect, 30 seconds to authenticate and
/// 10 seconds to open a channel.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Timeouts {
    /// Connection timeout.
    connect: Duration,
    /// Authentication timeout.
    auth: Duration,
    /// Channel-open timeout.
    channel_open: Duration,
}

impl Default for Timeouts {
    /// Returns the default timeouts (30 s / 30 s / 10 s).
    fn default() -> Self {
        Self {
            connect: Duration::from_secs(30),
            auth: Duration::from_secs(30),
            channel_open: Duration::from_secs(10),
        }
    }
}

impl Timeouts {
    /// Creates new timeout configuration.
    ///
    /// The durations are stored as given; zero is not rejected.
    pub fn new(connect: Duration, auth: Duration, channel_open: Duration) -> Self {
        Self {
            connect,
            auth,
            channel_open,
        }
    }

    /// Returns the TCP connection timeout.
    pub fn connect(&self) -> Duration {
        self.connect
    }

    /// Returns the authentication timeout.
    pub fn auth(&self) -> Duration {
        self.auth
    }

    /// Returns the channel-open timeout.
    pub fn channel_open(&self) -> Duration {
        self.channel_open
    }
}

/// Keepalive configuration.
///
/// Defaults: enabled, one message every 30 seconds, and up to 3 unanswered
/// keepalives.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keepalive {
    /// Whether keepalives are enabled.
    enabled: bool,
    /// Interval between keepalive messages.
    interval: Duration,
    /// Number of unanswered keepalives before disconnecting.
    max_missed: u32,
}

impl Keepalive {
    /// Creates a `Keepalive` with the given configuration.
    ///
    /// `interval` and `max_missed` are stored even when `enabled` is `false`.
    pub fn new(enabled: bool, interval: Duration, max_missed: u32) -> Self {
        Self {
            enabled,
            interval,
            max_missed,
        }
    }

    /// Whether keepalives are enabled.
    pub fn enabled(&self) -> bool {
        self.enabled
    }

    /// Interval between keepalive messages.
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// Number of unanswered keepalives before disconnecting.
    pub fn max_missed(&self) -> u32 {
        self.max_missed
    }
}

impl Default for Keepalive {
    /// Returns the default keepalive settings (enabled, 30 s, 3 missed).
    fn default() -> Self {
        Self {
            enabled: true,
            interval: Duration::from_secs(30),
            max_missed: 3,
        }
    }
}

342fn validate_sha256_fingerprint(value: &str) -> Result<()> {
343    let Some(rest) = value.strip_prefix("SHA256:") else {
344        return Err(Error::host_key(
345            HostKeyErrorKind::Unsupported,
346            "host-key fingerprint must start with SHA256:",
347        ));
348    };
349
350    if rest.is_empty() {
351        return Err(Error::host_key(
352            HostKeyErrorKind::Unavailable,
353            "host-key fingerprint must not be empty",
354        ));
355    }
356
357    if rest.bytes().any(|byte| byte.is_ascii_whitespace()) {
358        return Err(Error::host_key(
359            HostKeyErrorKind::Rejected,
360            "host-key fingerprint must not contain whitespace",
361        ));
362    }
363
364    Ok(())
365}
366
#[cfg(test)]
mod tests {
    //! Unit tests for the configuration types in this module.

    use crate::{
        ClientConfig, Endpoint, Error, HostKeyFingerprint, HostKeyFingerprintAlgorithm,
        HostKeyPolicy,
    };

    #[test]
    fn server_config_defaults_to_loopback_ephemeral_port() {
        let config = crate::ServerConfig::default();

        assert_eq!(config.listen(), &Endpoint::new("127.0.0.1", 0));
    }

    #[test]
    fn client_config_defaults_to_strict_host_key_policy() {
        let config = ClientConfig::default();

        assert_eq!(config.host_key_policy(), &HostKeyPolicy::Strict);
        assert!(config.strict_host_key_checking());
    }

    // The deprecated setter is exercised on purpose to pin its behaviour.
    #[test]
    #[allow(deprecated)]
    fn disabling_strict_host_key_checking_sets_accept_any_policy() {
        let mut config = ClientConfig::default();

        config.set_strict_host_key_checking(false);

        assert_eq!(config.host_key_policy(), &HostKeyPolicy::InsecureAcceptAny);
        assert!(!config.strict_host_key_checking());
    }

    #[test]
    fn pinned_sha256_policy_holds_validated_fingerprint() {
        let policy = HostKeyPolicy::pinned_sha256("SHA256:abc123").unwrap();
        let expected = HostKeyFingerprint::sha256("SHA256:abc123").unwrap();

        assert!(!policy.accepts_any());
        assert_eq!(policy, HostKeyPolicy::PinnedSha256(vec![expected]));
    }

    #[test]
    fn pinned_policy_counts_as_strict_host_key_checking() {
        let mut config = ClientConfig::default();

        config.set_host_key_policy(HostKeyPolicy::pinned_sha256("SHA256:abc123").unwrap());

        assert!(config.strict_host_key_checking());
    }

    #[test]
    fn validates_sha256_host_key_fingerprints() {
        let fingerprint = HostKeyFingerprint::sha256("SHA256:abc123+/=").unwrap();

        assert_eq!(fingerprint.algorithm(), HostKeyFingerprintAlgorithm::Sha256);
        assert_eq!(fingerprint.value(), "SHA256:abc123+/=");
    }

    #[test]
    fn rejects_invalid_sha256_host_key_fingerprints() {
        // Wrong prefix.
        let error = HostKeyFingerprint::sha256("MD5:abc").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));

        // Nothing after the prefix.
        let error = HostKeyFingerprint::sha256("SHA256:").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));

        // Whitespace inside the digest.
        let error = HostKeyFingerprint::sha256("SHA256:abc def").unwrap_err();
        assert!(matches!(error, Error::HostKey(_)));
    }

    #[test]
    #[cfg(feature = "serde")]
    fn client_config_serialization_skips_credentials() {
        let mut config = ClientConfig::new(Endpoint::new("example.com", 2222));
        config.add_credential(crate::Credential::password("secret"));

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: ClientConfig = serde_json::from_str(&serialized).unwrap();

        assert!(!serialized.contains("secret"));
        assert!(!serialized.contains("credentials"));
        assert!(deserialized.credentials().is_empty());
    }

    #[test]
    fn client_config_debug_does_not_expose_credential_content() {
        let mut config = ClientConfig::new(Endpoint::new("example.com", 2222));
        config.add_credential(crate::Credential::password("my-secret-password"));
        let debug = format!("{:?}", config);
        assert!(!debug.contains("my-secret-password"));
        assert!(debug.contains("Password(***)"));
    }

    #[test]
    fn keepalive_defaults_enabled_with_30s_interval() {
        let k = crate::Keepalive::default();
        assert!(k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(30));
        assert_eq!(k.max_missed(), 3);
    }

    #[test]
    fn keepalive_new_stores_fields() {
        let k = crate::Keepalive::new(true, std::time::Duration::from_secs(15), 5);
        assert!(k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(15));
        assert_eq!(k.max_missed(), 5);
    }

    #[test]
    fn keepalive_disabled_still_stores_interval() {
        let k = crate::Keepalive::new(false, std::time::Duration::from_secs(5), 1);
        assert!(!k.enabled());
        assert_eq!(k.interval(), std::time::Duration::from_secs(5));
    }

    #[test]
    fn timeouts_new_stores_fields() {
        use std::time::Duration;
        let t = crate::Timeouts::new(
            Duration::from_secs(5),
            Duration::from_secs(10),
            Duration::from_secs(2),
        );
        assert_eq!(t.connect(), Duration::from_secs(5));
        assert_eq!(t.auth(), Duration::from_secs(10));
        assert_eq!(t.channel_open(), Duration::from_secs(2));
    }

    #[test]
    fn timeouts_defaults_are_reasonable() {
        let t = crate::Timeouts::default();
        assert!(t.connect() > std::time::Duration::ZERO);
        assert!(t.auth() > std::time::Duration::ZERO);
        assert!(t.channel_open() > std::time::Duration::ZERO);
    }

    #[test]
    fn timeouts_with_zero_durations_stores_them() {
        use std::time::Duration;
        let t = crate::Timeouts::new(Duration::ZERO, Duration::ZERO, Duration::ZERO);
        assert_eq!(t.connect(), Duration::ZERO);
        assert_eq!(t.auth(), Duration::ZERO);
        assert_eq!(t.channel_open(), Duration::ZERO);
    }
}