1use crate::desktop_protocol::{DesktopConnectRequest, DesktopProtocol, FrameUpdate};
2use crate::ftp_client::FtpClient;
3use crate::postgres::{PgConfig, PgPool};
4use crate::rdp_client::RdpClient;
5use crate::sftp_client::StandaloneSftpClient;
6use crate::ssh::{HostKeyStore, PtySession, SshClient, SshConfig};
7use crate::vnc_client::VncClient;
8use anyhow::Result;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::sync::mpsc;
13use tokio_util::sync::CancellationToken;
14
/// Protocol spoken by a [`ManagedConnection`].
///
/// Desktop connections report their concrete protocol (`Rdp` or `Vnc`) even
/// though both are stored in the single `ManagedConnection::Desktop` variant.
/// `Hash` is derived so the kind can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProtocolKind {
    Ssh,
    Sftp,
    Ftp,
    Rdp,
    Vnc,
    Postgres,
}
28
29impl ProtocolKind {
30 pub fn as_str(self) -> &'static str {
31 match self {
32 ProtocolKind::Ssh => "SSH",
33 ProtocolKind::Sftp => "SFTP",
34 ProtocolKind::Ftp => "FTP",
35 ProtocolKind::Rdp => "RDP",
36 ProtocolKind::Vnc => "VNC",
37 ProtocolKind::Postgres => "POSTGRES",
38 }
39 }
40}
41
42pub enum ManagedConnection {
48 Ssh(Arc<RwLock<SshClient>>),
49 Sftp(Arc<RwLock<StandaloneSftpClient>>),
50 Ftp(Arc<RwLock<FtpClient>>),
51 Desktop {
52 kind: ProtocolKind, client: Arc<RwLock<Box<dyn DesktopProtocol>>>,
54 },
55 Postgres(Arc<PgPool>),
59}
60
61impl ManagedConnection {
62 pub fn kind(&self) -> ProtocolKind {
63 match self {
64 ManagedConnection::Ssh(_) => ProtocolKind::Ssh,
65 ManagedConnection::Sftp(_) => ProtocolKind::Sftp,
66 ManagedConnection::Ftp(_) => ProtocolKind::Ftp,
67 ManagedConnection::Desktop { kind, .. } => *kind,
68 ManagedConnection::Postgres(_) => ProtocolKind::Postgres,
69 }
70 }
71}
72
/// Which `ManagedConnection` variant a kind-specific close API operates on.
///
/// Unlike `ProtocolKind`, RDP and VNC share the single `Desktop` slot.
#[derive(Debug, Clone, Copy)]
enum ConnectionSlotKind {
    Ssh,
    Sftp,
    Ftp,
    Desktop,
    Postgres,
}
81
82impl ConnectionSlotKind {
83 fn label(self) -> &'static str {
84 match self {
85 ConnectionSlotKind::Ssh => "SSH",
86 ConnectionSlotKind::Sftp => "SFTP",
87 ConnectionSlotKind::Ftp => "FTP",
88 ConnectionSlotKind::Desktop => "desktop",
89 ConnectionSlotKind::Postgres => "postgres",
90 }
91 }
92
93 fn matches(self, connection: &ManagedConnection) -> bool {
94 match self {
95 ConnectionSlotKind::Ssh => matches!(connection, ManagedConnection::Ssh(_)),
96 ConnectionSlotKind::Sftp => matches!(connection, ManagedConnection::Sftp(_)),
97 ConnectionSlotKind::Ftp => matches!(connection, ManagedConnection::Ftp(_)),
98 ConnectionSlotKind::Desktop => {
99 matches!(connection, ManagedConnection::Desktop { .. })
100 }
101 ConnectionSlotKind::Postgres => matches!(connection, ManagedConnection::Postgres(_)),
102 }
103 }
104}
105
106pub struct ConnectionManager {
111 connections: Arc<RwLock<HashMap<String, ManagedConnection>>>,
112 pty_sessions: Arc<RwLock<HashMap<String, Arc<PtySession>>>>,
113 pty_generations: Arc<RwLock<HashMap<String, u64>>>,
116 pending_connections: Arc<RwLock<HashMap<String, CancellationToken>>>,
117 host_keys: Arc<HostKeyStore>,
119}
120
121impl Default for ConnectionManager {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl ConnectionManager {
128 pub fn new() -> Self {
129 Self::with_host_keys(Arc::new(HostKeyStore::new(HostKeyStore::default_path())))
130 }
131
132 pub fn with_host_keys(host_keys: Arc<HostKeyStore>) -> Self {
133 Self {
134 connections: Arc::new(RwLock::new(HashMap::new())),
135 pty_sessions: Arc::new(RwLock::new(HashMap::new())),
136 pty_generations: Arc::new(RwLock::new(HashMap::new())),
137 pending_connections: Arc::new(RwLock::new(HashMap::new())),
138 host_keys,
139 }
140 }
141
142 pub fn host_keys(&self) -> Arc<HostKeyStore> {
146 self.host_keys.clone()
147 }
148
149 pub async fn connection_kind(&self, id: &str) -> Option<ProtocolKind> {
155 let connections = self.connections.read().await;
156 connections.get(id).map(|c| c.kind())
157 }
158
159 pub async fn get_connection_type(&self, id: &str) -> Option<String> {
162 self.connection_kind(id)
163 .await
164 .map(|k| k.as_str().to_string())
165 }
166
167 pub async fn list_connections(&self) -> Vec<String> {
168 let connections = self.connections.read().await;
169 connections.keys().cloned().collect()
170 }
171
172 pub async fn get_connection(&self, id: &str) -> Option<Arc<RwLock<SshClient>>> {
174 let connections = self.connections.read().await;
175 match connections.get(id) {
176 Some(ManagedConnection::Ssh(c)) => Some(c.clone()),
177 _ => None,
178 }
179 }
180
181 pub async fn get_sftp_client(&self, id: &str) -> Option<Arc<RwLock<StandaloneSftpClient>>> {
182 let connections = self.connections.read().await;
183 match connections.get(id) {
184 Some(ManagedConnection::Sftp(c)) => Some(c.clone()),
185 _ => None,
186 }
187 }
188
189 pub async fn get_ftp_client(&self, id: &str) -> Option<Arc<RwLock<FtpClient>>> {
190 let connections = self.connections.read().await;
191 match connections.get(id) {
192 Some(ManagedConnection::Ftp(c)) => Some(c.clone()),
193 _ => None,
194 }
195 }
196
197 pub async fn get_desktop_connection(
198 &self,
199 id: &str,
200 ) -> Option<Arc<RwLock<Box<dyn DesktopProtocol>>>> {
201 let connections = self.connections.read().await;
202 match connections.get(id) {
203 Some(ManagedConnection::Desktop { client, .. }) => Some(client.clone()),
204 _ => None,
205 }
206 }
207
208 pub async fn get_postgres_pool(&self, id: &str) -> Option<Arc<PgPool>> {
209 let connections = self.connections.read().await;
210 match connections.get(id) {
211 Some(ManagedConnection::Postgres(c)) => Some(c.clone()),
212 _ => None,
213 }
214 }
215
216 pub async fn create_connection(&self, connection_id: String, config: SshConfig) -> Result<()> {
221 let mut client = SshClient::new(self.host_keys.clone());
222 let cancel_token = self.register_pending_connection(&connection_id).await;
223
224 let connect_result = tokio::select! {
225 res = client.connect(&config) => res,
226 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Connection cancelled by user")),
227 };
228
229 self.clear_pending_connection(&connection_id).await;
230
231 connect_result?;
232 self.replace_managed_connection(
233 connection_id,
234 ManagedConnection::Ssh(Arc::new(RwLock::new(client))),
235 )
236 .await
237 }
238
239 async fn register_pending_connection(&self, connection_id: &str) -> CancellationToken {
240 let token = CancellationToken::new();
241 let mut pending = self.pending_connections.write().await;
242 pending.insert(connection_id.to_string(), token.clone());
243 token
244 }
245
246 async fn clear_pending_connection(&self, connection_id: &str) {
247 let mut pending = self.pending_connections.write().await;
248 pending.remove(connection_id);
249 }
250
251 async fn disconnect_managed_connection(
252 &self,
253 connection_id: &str,
254 connection: ManagedConnection,
255 ) -> Result<()> {
256 match connection {
257 ManagedConnection::Ssh(client) => {
258 {
259 let mut pty_sessions = self.pty_sessions.write().await;
260 if let Some(session) = pty_sessions.remove(connection_id) {
261 session.cancel.cancel();
262 }
263 }
264 {
265 let mut generations = self.pty_generations.write().await;
266 generations.remove(connection_id);
267 }
268 let mut client = client.write().await;
269 client.disconnect().await?;
270 }
271 ManagedConnection::Sftp(client) => {
272 let mut client = client.write().await;
273 client.disconnect().await?;
274 }
275 ManagedConnection::Ftp(client) => {
276 let mut client = client.write().await;
277 client.disconnect().await?;
278 }
279 ManagedConnection::Desktop { client, .. } => {
280 let mut client = client.write().await;
281 client.disconnect().await?;
282 }
283 ManagedConnection::Postgres(pool) => {
284 pool.shutdown().await;
285 }
286 }
287 Ok(())
288 }
289
290 async fn replace_managed_connection(
291 &self,
292 connection_id: String,
293 replacement: ManagedConnection,
294 ) -> Result<()> {
295 let previous = {
296 let mut connections = self.connections.write().await;
297 connections.remove(&connection_id)
298 };
299
300 if let Some(previous) = previous {
301 self.disconnect_managed_connection(&connection_id, previous)
302 .await?;
303 }
304
305 let mut connections = self.connections.write().await;
306 connections.insert(connection_id, replacement);
307 Ok(())
308 }
309
310 async fn take_connection_if_kind(
311 &self,
312 connection_id: &str,
313 expected: ConnectionSlotKind,
314 ) -> Result<Option<ManagedConnection>> {
315 let mut connections = self.connections.write().await;
316 let Some(current) = connections.get(connection_id) else {
317 return Ok(None);
318 };
319
320 if !expected.matches(current) {
321 return Err(anyhow::anyhow!(
322 "Connection '{}' is {}, not {}",
323 connection_id,
324 current.kind().as_str(),
325 expected.label()
326 ));
327 }
328
329 Ok(connections.remove(connection_id))
330 }
331
332 pub async fn cancel_pending_connection(&self, connection_id: &str) -> bool {
333 let mut pending = self.pending_connections.write().await;
334 if let Some(token) = pending.remove(connection_id) {
335 token.cancel();
336 true
337 } else {
338 false
339 }
340 }
341
342 pub async fn close_connection(&self, connection_id: &str) -> Result<()> {
346 if let Some(connection) = self
347 .take_connection_if_kind(connection_id, ConnectionSlotKind::Ssh)
348 .await?
349 {
350 self.disconnect_managed_connection(connection_id, connection)
351 .await?;
352 }
353 Ok(())
354 }
355
356 pub async fn start_pty_connection(
363 &self,
364 connection_id: &str,
365 cols: u32,
366 rows: u32,
367 ) -> Result<u64> {
368 let client = self
369 .get_connection(connection_id)
370 .await
371 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
372
373 {
377 let mut pty_sessions = self.pty_sessions.write().await;
378 if let Some(old_session) = pty_sessions.remove(connection_id) {
379 old_session.cancel.cancel();
380 tracing::info!("Cancelled old PTY session for {}", connection_id);
381 }
382 }
383
384 let pty = {
385 let client = client.read().await;
386 client.create_pty_session(cols, rows).await?
387 };
388
389 let mut generations = self.pty_generations.write().await;
391 let generation_entry = generations.entry(connection_id.to_string()).or_insert(0);
392 *generation_entry += 1;
393 let current_gen = *generation_entry;
394 drop(generations);
395
396 let mut pty_sessions = self.pty_sessions.write().await;
397 pty_sessions.insert(connection_id.to_string(), Arc::new(pty));
398
399 Ok(current_gen)
400 }
401
402 pub async fn write_to_pty(&self, connection_id: &str, data: Vec<u8>) -> Result<()> {
407 let tx = {
408 let pty_sessions = self.pty_sessions.read().await;
409 let pty = pty_sessions
410 .get(connection_id)
411 .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?;
412 pty.input_tx.clone()
413 };
414
415 tx.send(data)
416 .await
417 .map_err(|_| anyhow::anyhow!("PTY channel closed"))
418 }
419
420 pub async fn get_pty_session(&self, connection_id: &str) -> Option<Arc<PtySession>> {
426 self.pty_sessions.read().await.get(connection_id).cloned()
427 }
428
429 pub async fn read_pty_burst(&self, connection_id: &str, max_bytes: usize) -> Result<Vec<u8>> {
432 let pty = {
433 let pty_sessions = self.pty_sessions.read().await;
434 pty_sessions
435 .get(connection_id)
436 .cloned()
437 .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?
438 };
439
440 let mut rx = pty.output_rx.lock().await;
441
442 let mut out = match rx.recv().await {
443 Some(data) => data,
444 None => return Err(anyhow::anyhow!("PTY connection closed")),
445 };
446
447 while out.len() < max_bytes {
448 match rx.try_recv() {
449 Ok(more) => out.extend_from_slice(&more),
450 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
451 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
452 }
453 }
454
455 Ok(out)
456 }
457
458 pub async fn close_pty_connection(
460 &self,
461 connection_id: &str,
462 expected_gen: Option<u64>,
463 ) -> Result<()> {
464 if let Some(expected_generation) = expected_gen {
465 let generations = self.pty_generations.read().await;
466 let current_gen = generations.get(connection_id).copied().unwrap_or(0);
467 if current_gen != expected_generation {
468 tracing::info!(
469 "Ignoring stale Close for {} (gen {} != current {})",
470 connection_id,
471 expected_generation,
472 current_gen
473 );
474 return Ok(());
475 }
476 }
477 let mut pty_sessions = self.pty_sessions.write().await;
478 if let Some(session) = pty_sessions.remove(connection_id) {
479 session.cancel.cancel();
480 }
481 Ok(())
482 }
483
484 pub async fn get_pty_cancel_token(&self, connection_id: &str) -> Option<CancellationToken> {
486 let sessions = self.pty_sessions.read().await;
487 sessions.get(connection_id).map(|s| s.cancel.clone())
488 }
489
490 pub async fn resize_pty(&self, connection_id: &str, cols: u32, rows: u32) -> Result<()> {
492 let pty_sessions = self.pty_sessions.read().await;
493 let pty = pty_sessions
494 .get(connection_id)
495 .ok_or_else(|| anyhow::anyhow!("PTY connection not found"))?;
496
497 pty.resize_tx
498 .send((cols, rows))
499 .await
500 .map_err(|_| anyhow::anyhow!("PTY resize channel closed"))
501 }
502
503 pub async fn create_sftp_connection(
508 &self,
509 connection_id: String,
510 config: crate::sftp_client::SftpConfig,
511 ) -> Result<()> {
512 let client = StandaloneSftpClient::connect(&config, self.host_keys.clone()).await?;
513 self.replace_managed_connection(
514 connection_id,
515 ManagedConnection::Sftp(Arc::new(RwLock::new(client))),
516 )
517 .await
518 }
519
520 pub async fn close_sftp_connection(&self, connection_id: &str) -> Result<()> {
521 if let Some(connection) = self
522 .take_connection_if_kind(connection_id, ConnectionSlotKind::Sftp)
523 .await?
524 {
525 self.disconnect_managed_connection(connection_id, connection)
526 .await?;
527 }
528 Ok(())
529 }
530
531 pub async fn create_ftp_connection(
536 &self,
537 connection_id: String,
538 config: crate::ftp_client::FtpConfig,
539 ) -> Result<()> {
540 let client = FtpClient::connect(&config).await?;
541 self.replace_managed_connection(
542 connection_id,
543 ManagedConnection::Ftp(Arc::new(RwLock::new(client))),
544 )
545 .await
546 }
547
548 pub async fn close_ftp_connection(&self, connection_id: &str) -> Result<()> {
549 if let Some(connection) = self
550 .take_connection_if_kind(connection_id, ConnectionSlotKind::Ftp)
551 .await?
552 {
553 self.disconnect_managed_connection(connection_id, connection)
554 .await?;
555 }
556 Ok(())
557 }
558
559 pub async fn create_desktop_connection(
564 &self,
565 connection_id: String,
566 request: &DesktopConnectRequest,
567 ) -> Result<(u16, u16)> {
568 use crate::desktop_protocol::DesktopKind;
569 let (kind, client): (ProtocolKind, Box<dyn DesktopProtocol>) = match request.protocol {
570 DesktopKind::Rdp => {
571 let config = request.to_rdp_config();
572 (
573 ProtocolKind::Rdp,
574 Box::new(RdpClient::connect(&config).await?),
575 )
576 }
577 DesktopKind::Vnc => {
578 let config = request.to_vnc_config();
579 (
580 ProtocolKind::Vnc,
581 Box::new(VncClient::connect(&config).await?),
582 )
583 }
584 };
585
586 let (w, h) = client.desktop_size();
587
588 self.replace_managed_connection(
589 connection_id,
590 ManagedConnection::Desktop {
591 kind,
592 client: Arc::new(RwLock::new(client)),
593 },
594 )
595 .await?;
596
597 Ok((w, h))
598 }
599
600 pub async fn close_desktop_connection(&self, connection_id: &str) -> Result<()> {
601 if let Some(connection) = self
602 .take_connection_if_kind(connection_id, ConnectionSlotKind::Desktop)
603 .await?
604 {
605 self.disconnect_managed_connection(connection_id, connection)
606 .await?;
607 }
608 Ok(())
609 }
610
611 pub async fn create_postgres_connection(
616 &self,
617 connection_id: String,
618 config: PgConfig,
619 ) -> Result<()> {
620 let ssh_client = if let Some(tunnel) = config.ssh_tunnel.as_ref() {
624 match self.get_connection(&tunnel.ssh_connection_id).await {
625 Some(c) => Some(c),
626 None => {
627 return Err(anyhow::Error::from(
628 crate::postgres::PgError::TunnelSourceMissing(format!(
629 "ssh connection '{}' is not registered or has been closed",
630 tunnel.ssh_connection_id
631 )),
632 ));
633 }
634 }
635 } else {
636 None
637 };
638
639 let cancel_token = self.register_pending_connection(&connection_id).await;
640 let connect_result = tokio::select! {
641 res = PgPool::connect(config, ssh_client) => res.map_err(anyhow::Error::from),
642 _ = cancel_token.cancelled() => Err(anyhow::anyhow!("Connection cancelled by user")),
643 };
644 self.clear_pending_connection(&connection_id).await;
645
646 let pool = connect_result?;
647 self.replace_managed_connection(connection_id, ManagedConnection::Postgres(pool))
648 .await
649 }
650
651 pub async fn close_postgres_connection(&self, connection_id: &str) -> Result<()> {
652 if let Some(connection) = self
653 .take_connection_if_kind(connection_id, ConnectionSlotKind::Postgres)
654 .await?
655 {
656 self.disconnect_managed_connection(connection_id, connection)
657 .await?;
658 }
659 Ok(())
660 }
661
662 #[allow(dead_code)]
668 pub async fn start_desktop_stream(
669 &self,
670 connection_id: &str,
671 frame_tx: mpsc::UnboundedSender<FrameUpdate>,
672 cancel: CancellationToken,
673 ) -> Result<()> {
674 let client = self
675 .get_desktop_connection(connection_id)
676 .await
677 .ok_or_else(|| anyhow::anyhow!("Desktop connection not found: {}", connection_id))?;
678 let client = client.read().await;
679 client.start_frame_loop(frame_tx, cancel).await
680 }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Desktop client stub whose operations all succeed immediately.
    struct TestDesktopClient;

    #[async_trait]
    impl DesktopProtocol for TestDesktopClient {
        async fn start_frame_loop(
            &self,
            _frame_tx: mpsc::UnboundedSender<FrameUpdate>,
            _cancel: CancellationToken,
        ) -> Result<()> {
            Ok(())
        }

        async fn send_key(&self, _key_code: u32, _down: bool) -> Result<()> {
            Ok(())
        }

        async fn send_pointer(&self, _x: u16, _y: u16, _button_mask: u8) -> Result<()> {
            Ok(())
        }

        async fn request_full_frame(&self) -> Result<()> {
            Ok(())
        }

        async fn set_clipboard(&self, _text: String) -> Result<()> {
            Ok(())
        }

        fn desktop_size(&self) -> (u16, u16) {
            (1024, 768)
        }

        async fn resize(&mut self, _width: u16, _height: u16) -> Result<()> {
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Host-key store under the temp dir, so tests never touch the user's file.
    fn test_host_keys() -> Arc<HostKeyStore> {
        Arc::new(HostKeyStore::new(
            std::env::temp_dir().join("r-shell-test-known-hosts"),
        ))
    }

    fn disconnected_ssh_client() -> SshClient {
        SshClient::new(test_host_keys())
    }

    /// Wraps a [`TestDesktopClient`] as a desktop connection reporting `kind`.
    fn test_desktop_connection(kind: ProtocolKind) -> ManagedConnection {
        ManagedConnection::Desktop {
            kind,
            client: Arc::new(RwLock::new(Box::new(TestDesktopClient))),
        }
    }

    /// Registers `connection` directly, bypassing the network-bound `create_*` APIs.
    async fn insert_connection(mgr: &ConnectionManager, id: &str, connection: ManagedConnection) {
        mgr.connections
            .write()
            .await
            .insert(id.to_string(), connection);
    }

    #[tokio::test]
    async fn test_new_manager_has_no_connections() {
        let mgr = ConnectionManager::new();
        assert!(mgr.list_connections().await.is_empty());
    }

    #[test]
    fn test_with_host_keys_shares_the_store() {
        let keys = test_host_keys();
        let mgr = ConnectionManager::with_host_keys(keys.clone());
        assert!(Arc::ptr_eq(&mgr.host_keys(), &keys));
    }

    #[tokio::test]
    async fn test_connection_kind_returns_none_for_unknown() {
        let mgr = ConnectionManager::new();
        assert!(mgr.connection_kind("unknown-id").await.is_none());
        assert!(mgr.get_connection_type("unknown-id").await.is_none());
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_pending_connection() {
        let mgr = ConnectionManager::new();
        assert!(!mgr.cancel_pending_connection("ghost").await);
    }

    #[tokio::test]
    async fn test_pending_connection_can_be_cancelled_once() {
        let mgr = ConnectionManager::new();
        let token = mgr.register_pending_connection("pending").await;
        assert!(!token.is_cancelled());
        assert!(mgr.cancel_pending_connection("pending").await);
        assert!(token.is_cancelled());
        assert!(!mgr.cancel_pending_connection("pending").await);
    }

    #[tokio::test]
    async fn test_cleared_pending_connection_is_not_cancellable() {
        let mgr = ConnectionManager::new();
        let token = mgr.register_pending_connection("pending").await;
        mgr.clear_pending_connection("pending").await;
        assert!(!mgr.cancel_pending_connection("pending").await);
        assert!(!token.is_cancelled());
    }

    #[test]
    fn test_protocol_kind_round_trip() {
        assert_eq!(ProtocolKind::Ssh.as_str(), "SSH");
        assert_eq!(ProtocolKind::Sftp.as_str(), "SFTP");
        assert_eq!(ProtocolKind::Ftp.as_str(), "FTP");
        assert_eq!(ProtocolKind::Rdp.as_str(), "RDP");
        assert_eq!(ProtocolKind::Vnc.as_str(), "VNC");
        assert_eq!(ProtocolKind::Postgres.as_str(), "POSTGRES");
    }

    #[tokio::test]
    async fn test_close_postgres_of_unknown_id_is_noop() {
        let mgr = ConnectionManager::new();
        let result = mgr.close_postgres_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_sftp_of_unknown_id_is_noop() {
        let mgr = ConnectionManager::new();
        let result = mgr.close_sftp_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_ftp_of_unknown_id_is_noop() {
        let mgr = ConnectionManager::new();
        let result = mgr.close_ftp_connection("ghost").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_ssh_and_desktop_of_unknown_id_is_noop() {
        let mgr = ConnectionManager::new();
        assert!(mgr.close_connection("ghost").await.is_ok());
        assert!(mgr.close_desktop_connection("ghost").await.is_ok());
    }

    #[tokio::test]
    async fn test_accessors_only_return_matching_kind() {
        let mgr = ConnectionManager::new();
        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Vnc)).await;

        assert!(mgr.get_desktop_connection("desktop").await.is_some());
        assert!(mgr.get_connection("desktop").await.is_none());
        assert!(mgr.get_sftp_client("desktop").await.is_none());
        assert!(mgr.get_ftp_client("desktop").await.is_none());
        assert!(mgr.get_postgres_pool("desktop").await.is_none());
        assert_eq!(
            mgr.get_connection_type("desktop").await.as_deref(),
            Some("VNC")
        );
        assert_eq!(mgr.list_connections().await, vec!["desktop".to_string()]);
    }

    #[tokio::test]
    async fn test_close_desktop_connection_removes_it() {
        let mgr = ConnectionManager::new();
        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Rdp)).await;

        mgr.close_desktop_connection("desktop")
            .await
            .expect("closing a desktop connection through the desktop API must succeed");
        assert!(mgr.connection_kind("desktop").await.is_none());
    }

    #[tokio::test]
    async fn test_replace_managed_connection_swaps_entry() {
        let mgr = ConnectionManager::new();
        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Rdp)).await;

        mgr.replace_managed_connection(
            "desktop".to_string(),
            test_desktop_connection(ProtocolKind::Vnc),
        )
        .await
        .expect("replacing a cleanly disconnecting connection must succeed");

        assert_eq!(
            mgr.connection_kind("desktop").await,
            Some(ProtocolKind::Vnc)
        );
        assert_eq!(mgr.list_connections().await.len(), 1);
    }

    #[tokio::test]
    async fn test_close_connection_rejects_non_ssh_without_removing_it() {
        let mgr = ConnectionManager::new();
        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Rdp)).await;

        let err = mgr
            .close_connection("desktop")
            .await
            .expect_err("closing an RDP connection through the SSH API must fail");
        assert!(err.to_string().contains("not SSH"));
        assert_eq!(
            mgr.connection_kind("desktop").await,
            Some(ProtocolKind::Rdp)
        );
    }

    #[tokio::test]
    async fn test_close_desktop_connection_rejects_ssh_without_removing_it() {
        let mgr = ConnectionManager::new();
        insert_connection(
            &mgr,
            "ssh",
            ManagedConnection::Ssh(Arc::new(RwLock::new(disconnected_ssh_client()))),
        )
        .await;

        let err = mgr
            .close_desktop_connection("ssh")
            .await
            .expect_err("closing an SSH connection through the desktop API must fail");
        assert!(err.to_string().contains("not desktop"));
        assert_eq!(mgr.connection_kind("ssh").await, Some(ProtocolKind::Ssh));
    }

    #[tokio::test]
    async fn test_pty_operations_on_unknown_id() {
        let mgr = ConnectionManager::new();
        assert!(mgr.get_pty_session("ghost").await.is_none());
        assert!(mgr.get_pty_cancel_token("ghost").await.is_none());

        let err = mgr
            .write_to_pty("ghost", b"ls\n".to_vec())
            .await
            .expect_err("writing to a missing PTY must fail");
        assert!(err.to_string().contains("PTY connection not found"));

        let err = mgr
            .resize_pty("ghost", 80, 24)
            .await
            .expect_err("resizing a missing PTY must fail");
        assert!(err.to_string().contains("PTY connection not found"));

        let err = mgr
            .read_pty_burst("ghost", 1024)
            .await
            .expect_err("reading a missing PTY must fail");
        assert!(err.to_string().contains("PTY connection not found"));

        mgr.close_pty_connection("ghost", None)
            .await
            .expect("closing a missing PTY is a no-op");
        mgr.close_pty_connection("ghost", Some(7))
            .await
            .expect("a stale close for a missing PTY is a no-op");
    }

    #[tokio::test]
    async fn test_start_pty_requires_ssh_connection() {
        let mgr = ConnectionManager::new();
        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Rdp)).await;

        let err = mgr
            .start_pty_connection("desktop", 80, 24)
            .await
            .expect_err("a PTY cannot be opened on a desktop connection");
        assert!(err.to_string().contains("Connection not found"));
        assert!(mgr.get_pty_session("desktop").await.is_none());
    }

    #[tokio::test]
    async fn test_start_desktop_stream() {
        let mgr = ConnectionManager::new();
        let (frame_tx, _frame_rx) = mpsc::unbounded_channel();

        let err = mgr
            .start_desktop_stream("ghost", frame_tx.clone(), CancellationToken::new())
            .await
            .expect_err("streaming an unknown desktop must fail");
        assert!(err.to_string().contains("Desktop connection not found: ghost"));

        insert_connection(&mgr, "desktop", test_desktop_connection(ProtocolKind::Vnc)).await;
        mgr.start_desktop_stream("desktop", frame_tx, CancellationToken::new())
            .await
            .expect("the stub frame loop returns immediately");
    }
}