ssh_commander_core/postgres/tunnel.rs
1//! SSH local-port forwarding for the Postgres explorer.
2//!
3//! Binds a local TCP listener on `127.0.0.1:<ephemeral>` and, for every
4//! inbound connection, opens a fresh `direct-tcpip` SSH channel to the
5//! configured remote endpoint and bidirectionally splices bytes between
6//! the local socket and the SSH channel. Pattern matches `ssh -L`.
7//!
8//! # Lifetime
9//!
//! `SshTunnel` owns a `CancellationToken` and a `JoinHandle` for the
//! accept loop. Dropping the tunnel cancels the token. This stops the
//! accept loop, which releases the local listener. It also tears down
//! every per-connection forwarder, because each one races its byte
//! splice against a clone of the same token. In-flight Postgres
//! connections therefore do not outlive the tunnel: their local sockets
//! and SSH channels are closed once each forwarder observes the
//! cancellation.
17//!
18//! # Concurrency
19//!
20//! The accept loop holds an `Arc<RwLock<SshClient>>`. Each accepted
21//! connection acquires a *read* lock for the duration of the
22//! `channel_open_direct_tcpip` round-trip only — the lock is dropped
23//! before splicing begins so the same SSH session can host many
24//! simultaneous Postgres connections without serializing channel
25//! opens.
26
27use std::sync::Arc;
28
29use tokio::io::AsyncWriteExt;
30use tokio::net::TcpListener;
31use tokio::sync::RwLock;
32use tokio::task::JoinHandle;
33use tokio_util::sync::CancellationToken;
34
35use crate::ssh::SshClient;
36
37/// Live SSH local-forward to `(remote_host, remote_port)`.
38///
39/// Public surface is intentionally tiny: the local port the caller
40/// should target, and an opaque drop guard. The accept loop, channel
41/// management, and byte splicing are private.
42pub struct SshTunnel {
43 /// Loopback port the local listener is bound to. Stable for the
44 /// lifetime of the tunnel.
45 local_port: u16,
46 cancel: CancellationToken,
47 /// Held only so the accept-loop task is aborted on drop. Never
48 /// awaited externally.
49 _accept_task: JoinHandle<()>,
50}
51
52impl SshTunnel {
53 pub fn local_port(&self) -> u16 {
54 self.local_port
55 }
56
57 /// Open the listener and start the accept loop. Returns once the
58 /// listener is bound; per-connection channels open lazily on
59 /// inbound traffic.
60 ///
61 /// Errors:
62 /// - `bind` fails (vanishingly rare on `127.0.0.1:0`)
63 /// - the SshClient is not connected (caller bug — should have
64 /// confirmed before invoking)
65 pub async fn open(
66 ssh_client: Arc<RwLock<SshClient>>,
67 remote_host: String,
68 remote_port: u16,
69 ) -> anyhow::Result<Self> {
70 // 0.0.0.0 would expose the forward to the LAN — `127.0.0.1` keeps
71 // it loopback-only, matching `ssh -L` defaults.
72 let listener = TcpListener::bind("127.0.0.1:0").await?;
73 let local_port = listener.local_addr()?.port();
74
75 let cancel = CancellationToken::new();
76 let task_cancel = cancel.clone();
77
78 let accept_task = tokio::spawn(async move {
79 run_accept_loop(listener, ssh_client, remote_host, remote_port, task_cancel).await;
80 });
81
82 Ok(Self {
83 local_port,
84 cancel,
85 _accept_task: accept_task,
86 })
87 }
88}
89
90impl Drop for SshTunnel {
91 fn drop(&mut self) {
92 // Cancellation is sufficient — the listener is owned by the
93 // accept task, so dropping the JoinHandle (with abort behavior)
94 // and signalling cancel both ensure the listener is closed.
95 self.cancel.cancel();
96 }
97}
98
99async fn run_accept_loop(
100 listener: TcpListener,
101 ssh_client: Arc<RwLock<SshClient>>,
102 remote_host: String,
103 remote_port: u16,
104 cancel: CancellationToken,
105) {
106 loop {
107 tokio::select! {
108 _ = cancel.cancelled() => {
109 tracing::debug!("postgres tunnel accept loop cancelled");
110 return;
111 }
112 res = listener.accept() => {
113 match res {
114 Ok((local_stream, peer)) => {
115 let ssh_client = ssh_client.clone();
116 let remote_host = remote_host.clone();
117 let conn_cancel = cancel.clone();
118 tokio::spawn(async move {
119 if let Err(e) = forward_one(
120 local_stream,
121 ssh_client,
122 &remote_host,
123 remote_port,
124 conn_cancel,
125 )
126 .await
127 {
128 tracing::warn!(
129 peer = %peer,
130 error = %e,
131 "postgres tunnel forwarder ended with error"
132 );
133 }
134 });
135 }
136 Err(e) => {
137 tracing::warn!("postgres tunnel accept failed: {e}");
138 // Don't tight-loop on a fatal listener error —
139 // a brief yield lets the runtime mark the
140 // listener dead, after which subsequent accepts
141 // also fail and we exit on cancel.
142 tokio::task::yield_now().await;
143 }
144 }
145 }
146 }
147 }
148}
149
150async fn forward_one(
151 mut local_stream: tokio::net::TcpStream,
152 ssh_client: Arc<RwLock<SshClient>>,
153 remote_host: &str,
154 remote_port: u16,
155 cancel: CancellationToken,
156) -> anyhow::Result<()> {
157 // Brief read-lock window: open the SSH channel, then drop the
158 // guard so other tasks can use the SshClient while we splice.
159 let channel = {
160 let guard = ssh_client.read().await;
161 guard.open_direct_tcpip(remote_host, remote_port).await?
162 };
163
164 let mut stream = channel.into_stream();
165 let (mut local_read, mut local_write) = local_stream.split();
166 let (mut ssh_read, mut ssh_write) = tokio::io::split(&mut stream);
167
168 // Bidirectional splice. `tokio::io::copy` returns when its source
169 // EOFs. Either direction finishing tears down both — Postgres
170 // connections are duplex and a half-open state is never useful.
171 let local_to_ssh = async {
172 let r = tokio::io::copy(&mut local_read, &mut ssh_write).await;
173 let _ = ssh_write.shutdown().await;
174 r
175 };
176 let ssh_to_local = async {
177 let r = tokio::io::copy(&mut ssh_read, &mut local_write).await;
178 let _ = local_write.shutdown().await;
179 r
180 };
181
182 tokio::select! {
183 _ = cancel.cancelled() => {
184 tracing::debug!("postgres tunnel forwarder cancelled");
185 Ok(())
186 }
187 res = async {
188 tokio::try_join!(local_to_ssh, ssh_to_local).map(|_| ())
189 } => {
190 res.map_err(anyhow::Error::from)
191 }
192 }
193}
194
195#[cfg(test)]
196mod tests {
197 use super::*;
198
199 /// `SshTunnel::open` doesn't try to use the SSH client until a TCP
200 /// connection arrives. So binding the listener succeeds even when
201 /// the underlying SshClient is disconnected — which is the right
202 /// behavior: failure surfaces on first use, not on construction.
203 /// We can't easily build a real connected SshClient without a server,
204 /// but we can verify the bind step and the local port assignment.
205 #[tokio::test]
206 async fn open_binds_local_port_immediately() {
207 use crate::ssh::HostKeyStore;
208 let host_keys = Arc::new(HostKeyStore::new(
209 std::env::temp_dir().join("r-shell-tunnel-test-known-hosts"),
210 ));
211 let client = Arc::new(RwLock::new(SshClient::new(host_keys)));
212 let tunnel = SshTunnel::open(client, "irrelevant".to_string(), 5432)
213 .await
214 .expect("bind should succeed");
215 assert!(tunnel.local_port() > 0);
216 // Listener is reachable as long as the tunnel is alive.
217 let probe = tokio::net::TcpStream::connect(("127.0.0.1", tunnel.local_port())).await;
218 assert!(probe.is_ok(), "listener should accept connections");
219 }
220
221 /// Dropping the tunnel cancels the accept loop and releases the
222 /// listener. Asserted by binding a *new* listener on the same port
223 /// after drop — succeeds only if the original is gone.
224 #[tokio::test]
225 async fn drop_releases_local_port() {
226 use crate::ssh::HostKeyStore;
227 let host_keys = Arc::new(HostKeyStore::new(
228 std::env::temp_dir().join("r-shell-tunnel-test-known-hosts-2"),
229 ));
230 let client = Arc::new(RwLock::new(SshClient::new(host_keys)));
231 let tunnel = SshTunnel::open(client, "irrelevant".to_string(), 5432)
232 .await
233 .expect("bind");
234 let port = tunnel.local_port();
235 drop(tunnel);
236
237 // Give the runtime a tick to process the cancellation. SO_REUSEADDR
238 // is on by default for ephemeral ports on macOS/Linux, so a re-bind
239 // attempt is the cleanest assertion that the slot is free.
240 tokio::task::yield_now().await;
241 let rebind = TcpListener::bind(("127.0.0.1", port)).await;
242 assert!(rebind.is_ok(), "port {port} should be reusable after drop");
243 }
244}