Skip to main content

ssh_commander_core/postgres/
tunnel.rs

1//! SSH local-port forwarding for the Postgres explorer.
2//!
3//! Binds a local TCP listener on `127.0.0.1:<ephemeral>` and, for every
4//! inbound connection, opens a fresh `direct-tcpip` SSH channel to the
5//! configured remote endpoint and bidirectionally splices bytes between
//! the local socket and the SSH channel. This mirrors `ssh -L`.
7//!
8//! # Lifetime
9//!
//! `SshTunnel` owns a `CancellationToken` and a `JoinHandle` for the
//! accept loop. Dropping the tunnel cancels the token, which stops the
//! accept loop and releases the local listener. Every per-connection
//! forwarder task is spawned with a clone of the same token, so dropping
//! the tunnel also makes in-flight forwarders stop splicing and release
//! their sockets. Absent cancellation, a forwarder finishes once both
//! directions of its stream reach EOF or an I/O error occurs.
17//!
18//! # Concurrency
19//!
//! The accept loop holds an `Arc<RwLock<SshClient>>`. Each accepted
//! connection acquires a *read* lock only for the duration of the
//! `SshClient::open_direct_tcpip` round-trip. The lock is dropped
//! before splicing begins, so the same SSH session can host many
//! simultaneous Postgres connections without serializing channel
//! opens.
26
27use std::sync::Arc;
28
29use tokio::io::AsyncWriteExt;
30use tokio::net::TcpListener;
31use tokio::sync::RwLock;
32use tokio::task::JoinHandle;
33use tokio_util::sync::CancellationToken;
34
35use crate::ssh::SshClient;
36
37/// Live SSH local-forward to `(remote_host, remote_port)`.
38///
39/// Public surface is intentionally tiny: the local port the caller
40/// should target, and an opaque drop guard. The accept loop, channel
41/// management, and byte splicing are private.
42pub struct SshTunnel {
43    /// Loopback port the local listener is bound to. Stable for the
44    /// lifetime of the tunnel.
45    local_port: u16,
46    cancel: CancellationToken,
47    /// Held only so the accept-loop task is aborted on drop. Never
48    /// awaited externally.
49    _accept_task: JoinHandle<()>,
50}
51
52impl SshTunnel {
53    pub fn local_port(&self) -> u16 {
54        self.local_port
55    }
56
57    /// Open the listener and start the accept loop. Returns once the
58    /// listener is bound; per-connection channels open lazily on
59    /// inbound traffic.
60    ///
61    /// Errors:
62    /// - `bind` fails (vanishingly rare on `127.0.0.1:0`)
63    /// - the SshClient is not connected (caller bug — should have
64    ///   confirmed before invoking)
65    pub async fn open(
66        ssh_client: Arc<RwLock<SshClient>>,
67        remote_host: String,
68        remote_port: u16,
69    ) -> anyhow::Result<Self> {
70        // 0.0.0.0 would expose the forward to the LAN — `127.0.0.1` keeps
71        // it loopback-only, matching `ssh -L` defaults.
72        let listener = TcpListener::bind("127.0.0.1:0").await?;
73        let local_port = listener.local_addr()?.port();
74
75        let cancel = CancellationToken::new();
76        let task_cancel = cancel.clone();
77
78        let accept_task = tokio::spawn(async move {
79            run_accept_loop(listener, ssh_client, remote_host, remote_port, task_cancel).await;
80        });
81
82        Ok(Self {
83            local_port,
84            cancel,
85            _accept_task: accept_task,
86        })
87    }
88}
89
90impl Drop for SshTunnel {
91    fn drop(&mut self) {
92        // Cancellation is sufficient — the listener is owned by the
93        // accept task, so dropping the JoinHandle (with abort behavior)
94        // and signalling cancel both ensure the listener is closed.
95        self.cancel.cancel();
96    }
97}
98
99async fn run_accept_loop(
100    listener: TcpListener,
101    ssh_client: Arc<RwLock<SshClient>>,
102    remote_host: String,
103    remote_port: u16,
104    cancel: CancellationToken,
105) {
106    loop {
107        tokio::select! {
108            _ = cancel.cancelled() => {
109                tracing::debug!("postgres tunnel accept loop cancelled");
110                return;
111            }
112            res = listener.accept() => {
113                match res {
114                    Ok((local_stream, peer)) => {
115                        let ssh_client = ssh_client.clone();
116                        let remote_host = remote_host.clone();
117                        let conn_cancel = cancel.clone();
118                        tokio::spawn(async move {
119                            if let Err(e) = forward_one(
120                                local_stream,
121                                ssh_client,
122                                &remote_host,
123                                remote_port,
124                                conn_cancel,
125                            )
126                            .await
127                            {
128                                tracing::warn!(
129                                    peer = %peer,
130                                    error = %e,
131                                    "postgres tunnel forwarder ended with error"
132                                );
133                            }
134                        });
135                    }
136                    Err(e) => {
137                        tracing::warn!("postgres tunnel accept failed: {e}");
138                        // Don't tight-loop on a fatal listener error —
139                        // a brief yield lets the runtime mark the
140                        // listener dead, after which subsequent accepts
141                        // also fail and we exit on cancel.
142                        tokio::task::yield_now().await;
143                    }
144                }
145            }
146        }
147    }
148}
149
150async fn forward_one(
151    mut local_stream: tokio::net::TcpStream,
152    ssh_client: Arc<RwLock<SshClient>>,
153    remote_host: &str,
154    remote_port: u16,
155    cancel: CancellationToken,
156) -> anyhow::Result<()> {
157    // Brief read-lock window: open the SSH channel, then drop the
158    // guard so other tasks can use the SshClient while we splice.
159    let channel = {
160        let guard = ssh_client.read().await;
161        guard.open_direct_tcpip(remote_host, remote_port).await?
162    };
163
164    let mut stream = channel.into_stream();
165    let (mut local_read, mut local_write) = local_stream.split();
166    let (mut ssh_read, mut ssh_write) = tokio::io::split(&mut stream);
167
168    // Bidirectional splice. `tokio::io::copy` returns when its source
169    // EOFs. Either direction finishing tears down both — Postgres
170    // connections are duplex and a half-open state is never useful.
171    let local_to_ssh = async {
172        let r = tokio::io::copy(&mut local_read, &mut ssh_write).await;
173        let _ = ssh_write.shutdown().await;
174        r
175    };
176    let ssh_to_local = async {
177        let r = tokio::io::copy(&mut ssh_read, &mut local_write).await;
178        let _ = local_write.shutdown().await;
179        r
180    };
181
182    tokio::select! {
183        _ = cancel.cancelled() => {
184            tracing::debug!("postgres tunnel forwarder cancelled");
185            Ok(())
186        }
187        res = async {
188            tokio::try_join!(local_to_ssh, ssh_to_local).map(|_| ())
189        } => {
190            res.map_err(anyhow::Error::from)
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh::HostKeyStore;

    /// Build a never-connected `SshClient` whose known-hosts store lives at
    /// `<temp_dir>/<known_hosts_file>`.
    fn disconnected_client(known_hosts_file: &str) -> Arc<RwLock<SshClient>> {
        let known_hosts = std::env::temp_dir().join(known_hosts_file);
        let host_keys = Arc::new(HostKeyStore::new(known_hosts));
        Arc::new(RwLock::new(SshClient::new(host_keys)))
    }

    /// `SshTunnel::open` only binds the listener. The SSH client is first
    /// used when a TCP connection arrives, so construction succeeds even
    /// with a disconnected client and failures surface on first use. A
    /// real connected client would need a server, so this test covers the
    /// bind step and port assignment only.
    #[tokio::test]
    async fn open_binds_local_port_immediately() {
        let tunnel = SshTunnel::open(
            disconnected_client("r-shell-tunnel-test-known-hosts"),
            "irrelevant".to_string(),
            5432,
        )
        .await
        .expect("bind should succeed");
        let port = tunnel.local_port();
        assert!(port > 0);

        // The listener stays reachable for as long as the tunnel lives.
        let probe = tokio::net::TcpStream::connect(("127.0.0.1", port)).await;
        assert!(probe.is_ok(), "listener should accept connections");
    }

    /// Dropping the tunnel must release its listener. This is checked by
    /// re-binding the exact same address afterwards.
    #[tokio::test]
    async fn drop_releases_local_port() {
        let tunnel = SshTunnel::open(
            disconnected_client("r-shell-tunnel-test-known-hosts-2"),
            "irrelevant".to_string(),
            5432,
        )
        .await
        .expect("bind");
        let port = tunnel.local_port();
        drop(tunnel);

        // One scheduler tick lets the cancelled accept task finish and drop
        // its listener. SO_REUSEADDR (set by tokio on Unix listeners)
        // tolerates lingering TIME_WAIT state but not a live listener on the
        // same address, so a successful re-bind means the original is gone.
        tokio::task::yield_now().await;
        let rebind = TcpListener::bind(("127.0.0.1", port)).await;
        assert!(rebind.is_ok(), "port {port} should be reusable after drop");
    }
}