Skip to main content

ssh_commander_core/
connection_manager.rs

1use crate::desktop_protocol::{DesktopConnectRequest, DesktopProtocol, FrameUpdate};
2use crate::ftp_client::FtpClient;
3use crate::postgres::{PgConfig, PgPool};
4use crate::rdp_client::RdpClient;
5use crate::sftp_client::StandaloneSftpClient;
6use crate::ssh::{HostKeyStore, PtySession, SshClient, SshConfig};
7use crate::vnc_client::VncClient;
8use anyhow::Result;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::sync::mpsc;
13use tokio_util::sync::CancellationToken;
14
/// Protocol tag attached to every managed connection.
///
/// Modelled as an enum rather than a free-form string so that every `match`
/// over it is checked for exhaustiveness and a misspelled tag is a compile
/// error instead of a runtime surprise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolKind {
    Ssh,
    Sftp,
    Ftp,
    Rdp,
    Vnc,
    Postgres,
}

impl ProtocolKind {
    /// Upper-case label for this protocol (`"SSH"`, `"SFTP"`, `"FTP"`,
    /// `"RDP"`, `"VNC"` or `"POSTGRES"`). This is the string handed out by
    /// `ConnectionManager::get_connection_type`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Ssh => "SSH",
            Self::Sftp => "SFTP",
            Self::Ftp => "FTP",
            Self::Rdp => "RDP",
            Self::Vnc => "VNC",
            Self::Postgres => "POSTGRES",
        }
    }
}
41
42/// A single managed connection, tagged by protocol.
43///
44/// Each variant owns its own `Arc<RwLock<_>>` — giving per-connection locking
45/// granularity, instead of a global map-level RwLock that would serialise
46/// every operation across unrelated connections.
47pub enum ManagedConnection {
48    Ssh(Arc<RwLock<SshClient>>),
49    Sftp(Arc<RwLock<StandaloneSftpClient>>),
50    Ftp(Arc<RwLock<FtpClient>>),
51    Desktop {
52        kind: ProtocolKind, // Rdp or Vnc
53        client: Arc<RwLock<Box<dyn DesktopProtocol>>>,
54    },
55    /// `PgPool` is internally `Sync` (manages its own locks), so no
56    /// outer `RwLock` is needed here — multiple sessions / tabs can
57    /// hit the pool concurrently from independent tasks.
58    Postgres(Arc<PgPool>),
59}
60
61impl ManagedConnection {
62    pub fn kind(&self) -> ProtocolKind {
63        match self {
64            ManagedConnection::Ssh(_) => ProtocolKind::Ssh,
65            ManagedConnection::Sftp(_) => ProtocolKind::Sftp,
66            ManagedConnection::Ftp(_) => ProtocolKind::Ftp,
67            ManagedConnection::Desktop { kind, .. } => *kind,
68            ManagedConnection::Postgres(_) => ProtocolKind::Postgres,
69        }
70    }
71}
72
73#[derive(Debug, Clone, Copy)]
74enum ConnectionSlotKind {
75    Ssh,
76    Sftp,
77    Ftp,
78    Desktop,
79    Postgres,
80}
81
82impl ConnectionSlotKind {
83    fn label(self) -> &'static str {
84        match self {
85            ConnectionSlotKind::Ssh => "SSH",
86            ConnectionSlotKind::Sftp => "SFTP",
87            ConnectionSlotKind::Ftp => "FTP",
88            ConnectionSlotKind::Desktop => "desktop",
89            ConnectionSlotKind::Postgres => "postgres",
90        }
91    }
92
93    fn matches(self, connection: &ManagedConnection) -> bool {
94        match self {
95            ConnectionSlotKind::Ssh => matches!(connection, ManagedConnection::Ssh(_)),
96            ConnectionSlotKind::Sftp => matches!(connection, ManagedConnection::Sftp(_)),
97            ConnectionSlotKind::Ftp => matches!(connection, ManagedConnection::Ftp(_)),
98            ConnectionSlotKind::Desktop => {
99                matches!(connection, ManagedConnection::Desktop { .. })
100            }
101            ConnectionSlotKind::Postgres => matches!(connection, ManagedConnection::Postgres(_)),
102        }
103    }
104}
105
106/// The connection manager owns the mapping from connection_id → its backing
107/// protocol state. Previously this was eight parallel hashmaps held together
108/// by convention; invariants (e.g. "if connection_types says SFTP, the sftp
109/// hashmap contains the id") are now enforced by the variant tag itself.
110pub struct ConnectionManager {
111    connections: Arc<RwLock<HashMap<String, ManagedConnection>>>,
112    pty_sessions: Arc<RwLock<HashMap<String, Arc<PtySession>>>>,
113    /// Generation counter per connection_id — incremented on each StartPty.
114    /// Used to prevent a stale Close from killing a newly created session.
115    pty_generations: Arc<RwLock<HashMap<String, u64>>>,
116    pending_connections: Arc<RwLock<HashMap<String, CancellationToken>>>,
117    /// Shared TOFU host-key store used by every SSH/SFTP connection.
118    host_keys: Arc<HostKeyStore>,
119}
120
121impl Default for ConnectionManager {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127impl ConnectionManager {
128    pub fn new() -> Self {
129        Self::with_host_keys(Arc::new(HostKeyStore::new(HostKeyStore::default_path())))
130    }
131
132    pub fn with_host_keys(host_keys: Arc<HostKeyStore>) -> Self {
133        Self {
134            connections: Arc::new(RwLock::new(HashMap::new())),
135            pty_sessions: Arc::new(RwLock::new(HashMap::new())),
136            pty_generations: Arc::new(RwLock::new(HashMap::new())),
137            pending_connections: Arc::new(RwLock::new(HashMap::new())),
138            host_keys,
139        }
140    }
141
142    /// Access the shared host-key store. Used by the macOS bridge to
143    /// expose `forget` over FFI for the "Trust new key" flow on a
144    /// `HostKeyMismatch`.
145    pub fn host_keys(&self) -> Arc<HostKeyStore> {
146        self.host_keys.clone()
147    }
148
149    // =========================================================================
150    // Inspection
151    // =========================================================================
152
153    /// Protocol of an existing connection, or None if not registered.
154    pub async fn connection_kind(&self, id: &str) -> Option<ProtocolKind> {
155        let connections = self.connections.read().await;
156        connections.get(id).map(|c| c.kind())
157    }
158
159    /// Backward-compatible string form of `connection_kind`. Returns "SSH",
160    /// "SFTP", "FTP", "RDP", or "VNC". Prefer `connection_kind` in new code.
161    pub async fn get_connection_type(&self, id: &str) -> Option<String> {
162        self.connection_kind(id)
163            .await
164            .map(|k| k.as_str().to_string())
165    }
166
167    pub async fn list_connections(&self) -> Vec<String> {
168        let connections = self.connections.read().await;
169        connections.keys().cloned().collect()
170    }
171
172    /// Return the SSH client for a connection if it is an SSH connection.
173    pub async fn get_connection(&self, id: &str) -> Option<Arc<RwLock<SshClient>>> {
174        let connections = self.connections.read().await;
175        match connections.get(id) {
176            Some(ManagedConnection::Ssh(c)) => Some(c.clone()),
177            _ => None,
178        }
179    }
180
181    pub async fn get_sftp_client(&self, id: &str) -> Option<Arc<RwLock<StandaloneSftpClient>>> {
182        let connections = self.connections.read().await;
183        match connections.get(id) {
184            Some(ManagedConnection::Sftp(c)) => Some(c.clone()),
185            _ => None,
186        }
187    }
188
189    pub async fn get_ftp_client(&self, id: &str) -> Option<Arc<RwLock<FtpClient>>> {
190        let connections = self.connections.read().await;
191        match connections.get(id) {
192            Some(ManagedConnection::Ftp(c)) => Some(c.clone()),
193            _ => None,
194        }
195    }
196
197    pub async fn get_desktop_connection(
198        &self,
199        id: &str,
200    ) -> Option<Arc<RwLock<Box<dyn DesktopProtocol>>>> {
201        let connections = self.connections.read().await;
202        match connections.get(id) {
203            Some(ManagedConnection::Desktop { client, .. }) => Some(client.clone()),
204            _ => None,
205        }
206    }
207
208    pub async fn get_postgres_pool(&self, id: &str) -> Option<Arc<PgPool>> {
209        let connections = self.connections.read().await;
210        match connections.get(id) {
211            Some(ManagedConnection::Postgres(c)) => Some(c.clone()),
212            _ => None,
213        }
214    }
215
216    // =========================================================================
217    // SSH connection lifecycle (supports cancellation of a pending connect)
218    // =========================================================================
219
220    pub async fn create_connection(&self, connection_id: String, config: SshConfig) -> Result<()> {
221        let mut client = SshClient::new(self.host_keys.clone());
222        let cancel_token = self.register_pending_connection(&connection_id).await;
223
224        let connect_result = tokio::select! {
225            res = client.connect(&config) => res,
226            _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Connection cancelled by user")),
227        };
228
229        self.clear_pending_connection(&connection_id).await;
230
231        connect_result?;
232        self.replace_managed_connection(
233            connection_id,
234            ManagedConnection::Ssh(Arc::new(RwLock::new(client))),
235        )
236        .await
237    }
238
239    async fn register_pending_connection(&self, connection_id: &str) -> CancellationToken {
240        let token = CancellationToken::new();
241        let mut pending = self.pending_connections.write().await;
242        pending.insert(connection_id.to_string(), token.clone());
243        token
244    }
245
246    async fn clear_pending_connection(&self, connection_id: &str) {
247        let mut pending = self.pending_connections.write().await;
248        pending.remove(connection_id);
249    }
250
251    async fn disconnect_managed_connection(
252        &self,
253        connection_id: &str,
254        connection: ManagedConnection,
255    ) -> Result<()> {
256        match connection {
257            ManagedConnection::Ssh(client) => {
258                {
259                    let mut pty_sessions = self.pty_sessions.write().await;
260                    if let Some(session) = pty_sessions.remove(connection_id) {
261                        session.cancel.cancel();
262                    }
263                }
264                {
265                    let mut generations = self.pty_generations.write().await;
266                    generations.remove(connection_id);
267                }
268                let mut client = client.write().await;
269                client.disconnect().await?;
270            }
271            ManagedConnection::Sftp(client) => {
272                let mut client = client.write().await;
273                client.disconnect().await?;
274            }
275            ManagedConnection::Ftp(client) => {
276                let mut client = client.write().await;
277                client.disconnect().await?;
278            }
279            ManagedConnection::Desktop { client, .. } => {
280                let mut client = client.write().await;
281                client.disconnect().await?;
282            }
283            ManagedConnection::Postgres(pool) => {
284                pool.shutdown().await;
285            }
286        }
287        Ok(())
288    }
289
290    async fn replace_managed_connection(
291        &self,
292        connection_id: String,
293        replacement: ManagedConnection,
294    ) -> Result<()> {
295        let previous = {
296            let mut connections = self.connections.write().await;
297            connections.remove(&connection_id)
298        };
299
300        if let Some(previous) = previous {
301            self.disconnect_managed_connection(&connection_id, previous)
302                .await?;
303        }
304
305        let mut connections = self.connections.write().await;
306        connections.insert(connection_id, replacement);
307        Ok(())
308    }
309
310    async fn take_connection_if_kind(
311        &self,
312        connection_id: &str,
313        expected: ConnectionSlotKind,
314    ) -> Result<Option<ManagedConnection>> {
315        let mut connections = self.connections.write().await;
316        let Some(current) = connections.get(connection_id) else {
317            return Ok(None);
318        };
319
320        if !expected.matches(current) {
321            return Err(anyhow::anyhow!(
322                "Connection '{}' is {}, not {}",
323                connection_id,
324                current.kind().as_str(),
325                expected.label()
326            ));
327        }
328
329        Ok(connections.remove(connection_id))
330    }
331
332    pub async fn cancel_pending_connection(&self, connection_id: &str) -> bool {
333        let mut pending = self.pending_connections.write().await;
334        if let Some(token) = pending.remove(connection_id) {
335            token.cancel();
336            true
337        } else {
338            false
339        }
340    }
341
342    /// Close the SSH connection for `connection_id` (if it is SSH). Also tears
343    /// down any associated PTY session and prunes the generation counter so it
344    /// cannot leak across reconnects.
345    pub async fn close_connection(&self, connection_id: &str) -> Result<()> {
346        if let Some(connection) = self
347            .take_connection_if_kind(connection_id, ConnectionSlotKind::Ssh)
348            .await?
349        {
350            self.disconnect_managed_connection(connection_id, connection)
351                .await?;
352        }
353        Ok(())
354    }
355
356    // =========================================================================
357    // PTY (interactive shell) management — only valid on SSH connections.
358    // =========================================================================
359
360    /// Start a PTY shell connection (like ttyd does).
361    /// Enables interactive commands: vim, less, more, top, htop, etc.
362    pub async fn start_pty_connection(
363        &self,
364        connection_id: &str,
365        cols: u32,
366        rows: u32,
367    ) -> Result<u64> {
368        let client = self
369            .get_connection(connection_id)
370            .await
371            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
372
373        // Cancel and remove any existing PTY session for this connection first.
374        // This ensures the old SSH channel and reader task are torn down before
375        // we create a new one, preventing orphaned sessions.
376        {
377            let mut pty_sessions = self.pty_sessions.write().await;
378            if let Some(old_session) = pty_sessions.remove(connection_id) {
379                old_session.cancel.cancel();
380                tracing::info!("Cancelled old PTY session for {}", connection_id);
381            }
382        }
383
384        let pty = {
385            let client = client.read().await;
386            client.create_pty_session(cols, rows).await?
387        };
388
389        // Bump generation so any in-flight Close for the old session is ignored.
390        let mut generations = self.pty_generations.write().await;
391        let generation_entry = generations.entry(connection_id.to_string()).or_insert(0);
392        *generation_entry += 1;
393        let current_gen = *generation_entry;
394        drop(generations);
395
396        let mut pty_sessions = self.pty_sessions.write().await;
397        pty_sessions.insert(connection_id.to_string(), Arc::new(pty));
398
399        Ok(current_gen)
400    }
401
402    /// Send data to PTY (user input).
403    ///
404    /// Backpressure: if the input channel is full we await `send`, preserving
405    /// keystroke order.
406    pub async fn write_to_pty(&self, connection_id: &str, data: Vec<u8>) -> Result<()> {
407        let tx = {
408            let pty_sessions = self.pty_sessions.read().await;
409            let pty = pty_sessions
410                .get(connection_id)
411                .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?;
412            pty.input_tx.clone()
413        };
414
415        tx.send(data)
416            .await
417            .map_err(|_| anyhow::anyhow!("PTY channel closed"))
418    }
419
420    /// Capture the active `PtySession` for a connection. Used by the macOS
421    /// bridge to spawn an output-forwarder task that holds a stable handle
422    /// to the session's `output_rx` for the lifetime of that PTY, even if
423    /// `start_pty_connection` is later called again for the same connection
424    /// (which would replace the entry in `pty_sessions`).
425    pub async fn get_pty_session(&self, connection_id: &str) -> Option<Arc<PtySession>> {
426        self.pty_sessions.read().await.get(connection_id).cloned()
427    }
428
429    /// Read a burst of PTY output — blocks until data arrives, then drains any
430    /// additional already-queued chunks up to `max_bytes`.
431    pub async fn read_pty_burst(&self, connection_id: &str, max_bytes: usize) -> Result<Vec<u8>> {
432        let pty = {
433            let pty_sessions = self.pty_sessions.read().await;
434            pty_sessions
435                .get(connection_id)
436                .cloned()
437                .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?
438        };
439
440        let mut rx = pty.output_rx.lock().await;
441
442        let mut out = match rx.recv().await {
443            Some(data) => data,
444            None => return Err(anyhow::anyhow!("PTY connection closed")),
445        };
446
447        while out.len() < max_bytes {
448            match rx.try_recv() {
449                Ok(more) => out.extend_from_slice(&more),
450                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
451                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
452            }
453        }
454
455        Ok(out)
456    }
457
458    /// Close PTY connection, but only if the generation matches.
459    pub async fn close_pty_connection(
460        &self,
461        connection_id: &str,
462        expected_gen: Option<u64>,
463    ) -> Result<()> {
464        if let Some(expected_generation) = expected_gen {
465            let generations = self.pty_generations.read().await;
466            let current_gen = generations.get(connection_id).copied().unwrap_or(0);
467            if current_gen != expected_generation {
468                tracing::info!(
469                    "Ignoring stale Close for {} (gen {} != current {})",
470                    connection_id,
471                    expected_generation,
472                    current_gen
473                );
474                return Ok(());
475            }
476        }
477        let mut pty_sessions = self.pty_sessions.write().await;
478        if let Some(session) = pty_sessions.remove(connection_id) {
479            session.cancel.cancel();
480        }
481        Ok(())
482    }
483
484    /// Get the cancellation token for a PTY session (used by WebSocket reader tasks).
485    pub async fn get_pty_cancel_token(&self, connection_id: &str) -> Option<CancellationToken> {
486        let sessions = self.pty_sessions.read().await;
487        sessions.get(connection_id).map(|s| s.cancel.clone())
488    }
489
490    /// Resize PTY terminal (send window-change to remote SSH channel)
491    pub async fn resize_pty(&self, connection_id: &str, cols: u32, rows: u32) -> Result<()> {
492        let pty_sessions = self.pty_sessions.read().await;
493        let pty = pty_sessions
494            .get(connection_id)
495            .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?;
496
497        pty.resize_tx
498            .send((cols, rows))
499            .await
500            .map_err(|_| anyhow::anyhow!("PTY resize channel closed"))
501    }
502
503    // =========================================================================
504    // Standalone SFTP
505    // =========================================================================
506
507    pub async fn create_sftp_connection(
508        &self,
509        connection_id: String,
510        config: crate::sftp_client::SftpConfig,
511    ) -> Result<()> {
512        let client = StandaloneSftpClient::connect(&config, self.host_keys.clone()).await?;
513        self.replace_managed_connection(
514            connection_id,
515            ManagedConnection::Sftp(Arc::new(RwLock::new(client))),
516        )
517        .await
518    }
519
520    pub async fn close_sftp_connection(&self, connection_id: &str) -> Result<()> {
521        if let Some(connection) = self
522            .take_connection_if_kind(connection_id, ConnectionSlotKind::Sftp)
523            .await?
524        {
525            self.disconnect_managed_connection(connection_id, connection)
526                .await?;
527        }
528        Ok(())
529    }
530
531    // =========================================================================
532    // FTP / FTPS
533    // =========================================================================
534
535    pub async fn create_ftp_connection(
536        &self,
537        connection_id: String,
538        config: crate::ftp_client::FtpConfig,
539    ) -> Result<()> {
540        let client = FtpClient::connect(&config).await?;
541        self.replace_managed_connection(
542            connection_id,
543            ManagedConnection::Ftp(Arc::new(RwLock::new(client))),
544        )
545        .await
546    }
547
548    pub async fn close_ftp_connection(&self, connection_id: &str) -> Result<()> {
549        if let Some(connection) = self
550            .take_connection_if_kind(connection_id, ConnectionSlotKind::Ftp)
551            .await?
552        {
553            self.disconnect_managed_connection(connection_id, connection)
554                .await?;
555        }
556        Ok(())
557    }
558
559    // =========================================================================
560    // Remote desktop (RDP / VNC)
561    // =========================================================================
562
563    pub async fn create_desktop_connection(
564        &self,
565        connection_id: String,
566        request: &DesktopConnectRequest,
567    ) -> Result<(u16, u16)> {
568        use crate::desktop_protocol::DesktopKind;
569        let (kind, client): (ProtocolKind, Box<dyn DesktopProtocol>) = match request.protocol {
570            DesktopKind::Rdp => {
571                let config = request.to_rdp_config();
572                (
573                    ProtocolKind::Rdp,
574                    Box::new(RdpClient::connect(&config).await?),
575                )
576            }
577            DesktopKind::Vnc => {
578                let config = request.to_vnc_config();
579                (
580                    ProtocolKind::Vnc,
581                    Box::new(VncClient::connect(&config).await?),
582                )
583            }
584        };
585
586        let (w, h) = client.desktop_size();
587
588        self.replace_managed_connection(
589            connection_id,
590            ManagedConnection::Desktop {
591                kind,
592                client: Arc::new(RwLock::new(client)),
593            },
594        )
595        .await?;
596
597        Ok((w, h))
598    }
599
600    pub async fn close_desktop_connection(&self, connection_id: &str) -> Result<()> {
601        if let Some(connection) = self
602            .take_connection_if_kind(connection_id, ConnectionSlotKind::Desktop)
603            .await?
604        {
605            self.disconnect_managed_connection(connection_id, connection)
606                .await?;
607        }
608        Ok(())
609    }
610
611    // =========================================================================
612    // Postgres
613    // =========================================================================
614
615    pub async fn create_postgres_connection(
616        &self,
617        connection_id: String,
618        config: PgConfig,
619    ) -> Result<()> {
620        // Tunneled configs need an already-open SSH connection in this
621        // same manager. Resolve it up front so a missing source is a
622        // single typed error instead of failing partway through connect.
623        let ssh_client = if let Some(tunnel) = config.ssh_tunnel.as_ref() {
624            match self.get_connection(&tunnel.ssh_connection_id).await {
625                Some(c) => Some(c),
626                None => {
627                    return Err(anyhow::Error::from(
628                        crate::postgres::PgError::TunnelSourceMissing(format!(
629                            "ssh connection '{}' is not registered or has been closed",
630                            tunnel.ssh_connection_id
631                        )),
632                    ));
633                }
634            }
635        } else {
636            None
637        };
638
639        let cancel_token = self.register_pending_connection(&connection_id).await;
640        let connect_result = tokio::select! {
641            res = PgPool::connect(config, ssh_client) => res.map_err(anyhow::Error::from),
642            _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Connection cancelled by user")),
643        };
644        self.clear_pending_connection(&connection_id).await;
645
646        let pool = connect_result?;
647        self.replace_managed_connection(connection_id, ManagedConnection::Postgres(pool))
648            .await
649    }
650
651    pub async fn close_postgres_connection(&self, connection_id: &str) -> Result<()> {
652        if let Some(connection) = self
653            .take_connection_if_kind(connection_id, ConnectionSlotKind::Postgres)
654            .await?
655        {
656            self.disconnect_managed_connection(connection_id, connection)
657                .await?;
658        }
659        Ok(())
660    }
661
662    /// Start the frame update loop for a desktop connection.
663    ///
664    /// Not yet wired up to the WebSocket server — kept here so the RDP/VNC
665    /// stubs have a concrete dispatch point once the protocol clients gain
666    /// real implementations. Remove the allow once a caller appears.
667    #[allow(dead_code)]
668    pub async fn start_desktop_stream(
669        &self,
670        connection_id: &str,
671        frame_tx: mpsc::UnboundedSender<FrameUpdate>,
672        cancel: CancellationToken,
673    ) -> Result<()> {
674        let client = self
675            .get_desktop_connection(connection_id)
676            .await
677            .ok_or_else(|| anyhow::anyhow!("Desktop connection not found: {}", connection_id))?;
678        let client = client.read().await;
679        client.start_frame_loop(frame_tx, cancel).await
680    }
681}
682
// =============================================================================
// Unit tests
// =============================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Desktop client that accepts every call and does nothing, so tests can
    /// register a `Desktop` connection without a real RDP/VNC server.
    struct TestDesktopClient;

    #[async_trait]
    impl DesktopProtocol for TestDesktopClient {
        async fn start_frame_loop(
            &self,
            _frame_tx: mpsc::UnboundedSender<FrameUpdate>,
            _cancel: CancellationToken,
        ) -> Result<()> {
            Ok(())
        }

        async fn send_key(&self, _key_code: u32, _down: bool) -> Result<()> {
            Ok(())
        }

        async fn send_pointer(&self, _x: u16, _y: u16, _button_mask: u8) -> Result<()> {
            Ok(())
        }

        async fn request_full_frame(&self) -> Result<()> {
            Ok(())
        }

        async fn set_clipboard(&self, _text: String) -> Result<()> {
            Ok(())
        }

        fn desktop_size(&self) -> (u16, u16) {
            (1024, 768)
        }

        async fn resize(&mut self, _width: u16, _height: u16) -> Result<()> {
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Host-key store under the system temp dir. Tests then do not depend on
    /// whatever `HostKeyStore::default_path()` resolves to on the machine
    /// running them.
    fn test_host_keys() -> Arc<HostKeyStore> {
        Arc::new(HostKeyStore::new(
            std::env::temp_dir().join("r-shell-test-known-hosts"),
        ))
    }

    /// A fresh manager wired to the temp host-key store.
    fn test_manager() -> ConnectionManager {
        ConnectionManager::with_host_keys(test_host_keys())
    }

    /// An `SshClient` that has never been connected.
    fn disconnected_ssh_client() -> SshClient {
        SshClient::new(test_host_keys())
    }

    /// Register a no-op desktop connection of `kind` directly in the map.
    async fn insert_test_desktop(mgr: &ConnectionManager, id: &str, kind: ProtocolKind) {
        let mut connections = mgr.connections.write().await;
        connections.insert(
            id.to_string(),
            ManagedConnection::Desktop {
                kind,
                client: Arc::new(RwLock::new(Box::new(TestDesktopClient))),
            },
        );
    }

    #[tokio::test]
    async fn test_new_manager_has_no_connections() {
        let mgr = test_manager();
        assert!(mgr.list_connections().await.is_empty());
    }

    #[tokio::test]
    async fn test_connection_kind_returns_none_for_unknown() {
        let mgr = test_manager();
        assert!(mgr.connection_kind("unknown-id").await.is_none());
        assert!(mgr.get_connection_type("unknown-id").await.is_none());
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_pending_connection() {
        let mgr = test_manager();
        assert!(!mgr.cancel_pending_connection("ghost").await);
    }

    #[test]
    fn test_protocol_kind_as_str() {
        assert_eq!(ProtocolKind::Ssh.as_str(), "SSH");
        assert_eq!(ProtocolKind::Sftp.as_str(), "SFTP");
        assert_eq!(ProtocolKind::Ftp.as_str(), "FTP");
        assert_eq!(ProtocolKind::Rdp.as_str(), "RDP");
        assert_eq!(ProtocolKind::Vnc.as_str(), "VNC");
        assert_eq!(ProtocolKind::Postgres.as_str(), "POSTGRES");
    }

    #[tokio::test]
    async fn test_close_postgres_of_unknown_id_is_noop() {
        let mgr = test_manager();
        let result = mgr.close_postgres_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_sftp_of_unknown_id_is_noop() {
        let mgr = test_manager();
        let result = mgr.close_sftp_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_ftp_of_unknown_id_is_noop() {
        let mgr = test_manager();
        let result = mgr.close_ftp_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_ssh_and_desktop_of_unknown_id_is_noop() {
        let mgr = test_manager();
        assert!(mgr.close_connection("ghost").await.is_ok());
        assert!(mgr.close_desktop_connection("ghost").await.is_ok());
    }

    #[tokio::test]
    async fn test_close_connection_rejects_non_ssh_without_removing_it() {
        let mgr = test_manager();
        insert_test_desktop(&mgr, "desktop", ProtocolKind::Rdp).await;

        let err = mgr
            .close_connection("desktop")
            .await
            .expect_err("closing an RDP connection through the SSH API must fail");
        assert!(err.to_string().contains("not SSH"));
        assert_eq!(
            mgr.connection_kind("desktop").await,
            Some(ProtocolKind::Rdp)
        );
    }

    #[tokio::test]
    async fn test_close_desktop_connection_rejects_ssh_without_removing_it() {
        let mgr = test_manager();
        {
            let mut connections = mgr.connections.write().await;
            connections.insert(
                "ssh".to_string(),
                ManagedConnection::Ssh(Arc::new(RwLock::new(disconnected_ssh_client()))),
            );
        }

        let err = mgr
            .close_desktop_connection("ssh")
            .await
            .expect_err("closing an SSH connection through the desktop API must fail");
        assert!(err.to_string().contains("not desktop"));
        assert_eq!(mgr.connection_kind("ssh").await, Some(ProtocolKind::Ssh));
    }

    #[tokio::test]
    async fn test_close_desktop_connection_removes_it() {
        let mgr = test_manager();
        insert_test_desktop(&mgr, "desk", ProtocolKind::Vnc).await;

        mgr.close_desktop_connection("desk")
            .await
            .expect("closing a desktop connection through the desktop API must succeed");
        assert!(mgr.connection_kind("desk").await.is_none());
        assert!(mgr.list_connections().await.is_empty());
    }

    #[tokio::test]
    async fn test_desktop_connection_is_reported_by_its_protocol() {
        let mgr = test_manager();
        insert_test_desktop(&mgr, "vnc", ProtocolKind::Vnc).await;

        assert_eq!(
            mgr.get_connection_type("vnc").await,
            Some("VNC".to_string())
        );
        assert!(mgr.get_desktop_connection("vnc").await.is_some());
        assert!(mgr.get_connection("vnc").await.is_none());
        assert!(mgr.get_sftp_client("vnc").await.is_none());
        assert!(mgr.get_ftp_client("vnc").await.is_none());
        assert!(mgr.get_postgres_pool("vnc").await.is_none());
    }

    #[tokio::test]
    async fn test_pty_operations_on_unknown_id() {
        let mgr = test_manager();
        assert!(mgr.write_to_pty("ghost", b"ls\n".to_vec()).await.is_err());
        assert!(mgr.resize_pty("ghost", 80, 24).await.is_err());
        assert!(mgr.read_pty_burst("ghost", 4096).await.is_err());
        assert!(mgr.get_pty_session("ghost").await.is_none());
        assert!(mgr.get_pty_cancel_token("ghost").await.is_none());
        assert!(mgr.close_pty_connection("ghost", None).await.is_ok());
        assert!(mgr.close_pty_connection("ghost", Some(7)).await.is_ok());
    }

    #[tokio::test]
    async fn test_start_pty_requires_ssh_connection() {
        let mgr = test_manager();
        insert_test_desktop(&mgr, "desk", ProtocolKind::Rdp).await;
        assert!(mgr.start_pty_connection("ghost", 80, 24).await.is_err());
        assert!(mgr.start_pty_connection("desk", 80, 24).await.is_err());
    }
}