Skip to main content

tnnl/
client.rs

1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
5use futures::AsyncWriteExt;
6use futures::future::poll_fn;
7use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
8use tokio::net::TcpStream;
9use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
10use yamux::{Config, Connection, Mode};
11
12use crate::log as tlog;
13use crate::protocol::{self, ControlMsg};
14use crate::proxy;
15use crate::store;
16use crate::update;
17
/// Upper bound for the reconnect delay used by `run`.
const MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Reconnect delay for the first retry, and the value `run` resets to.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Maximum number of body bytes `read_body_exact` / `read_chunked` will
/// buffer (10 MiB); used by inspect mode and by `replay`.
const BODY_CAP: usize = 10 * 1024 * 1024;
21
/// Options for one tunnel client run; all string data is borrowed for `'a`.
pub struct TunnelOpts<'a> {
    /// Port of the local service; tunneled requests go to `local_host:local_port`.
    pub local_port: u16,
    /// Host of the local service; tunneled requests go to `local_host:local_port`.
    pub local_host: &'a str,
    /// Tunnel server host. `127.0.0.1` / `localhost` make the banner show an
    /// `http://` URL; anything else shows `https://`.
    pub server_addr: &'a str,
    /// Tunnel server port.
    pub server_port: u16,
    /// Shared secret used to HMAC the auth nonce; when empty, no HMAC is sent.
    pub token: &'a str,
    /// Subdomain requested from the server in the `Auth` message, if any.
    pub subdomain: Option<&'a str>,
    /// Credentials incoming requests must present as HTTP Basic auth. The
    /// value is base64-encoded into the expected `Authorization: Basic …`
    /// header (presumably in `user:pass` form — TODO confirm with the CLI).
    pub auth: Option<&'a str>,
    /// When true, request/response bodies are buffered and printed instead
    /// of being streamed straight through.
    pub inspect: bool,
}
32
33pub async fn run(opts: TunnelOpts<'_>) -> Result<()> {
34    store::init();
35    let expected_auth = opts.auth.map(|a| format!("Basic {}", B64.encode(a)));
36    let mut backoff = INITIAL_BACKOFF;
37
38    loop {
39        tlog::info(&format!(
40            "connecting to {}:{}...",
41            opts.server_addr, opts.server_port
42        ));
43
44        let attempt_start = std::time::Instant::now();
45        match connect_and_tunnel(&opts, expected_auth.as_deref()).await {
46            Ok(()) => {
47                tlog::info("connection closed");
48                break;
49            }
50            Err(e) => {
51                tlog::error(&format!("{e:#}"));
52                tlog::info(&format!("reconnecting in {}s...", backoff.as_secs()));
53                tokio::time::sleep(backoff).await;
54                // If we were connected long enough to show the banner, reset backoff.
55                if attempt_start.elapsed() > Duration::from_secs(5) {
56                    backoff = INITIAL_BACKOFF;
57                } else {
58                    backoff = (backoff * 2).min(MAX_BACKOFF);
59                }
60            }
61        }
62    }
63
64    Ok(())
65}
66
67async fn connect_and_tunnel(opts: &TunnelOpts<'_>, expected_auth: Option<&str>) -> Result<()> {
68    let TunnelOpts {
69        local_port,
70        local_host,
71        server_addr,
72        server_port,
73        token,
74        subdomain,
75        inspect,
76        ..
77    } = opts;
78    let socket = tokio::time::timeout(
79        Duration::from_secs(10),
80        TcpStream::connect(format!("{server_addr}:{server_port}")),
81    )
82    .await
83    .context("connection timed out")?
84    .context("failed to connect to server")?;
85
86    let mut config = Config::default();
87    config.set_split_send_size(16 * 1024);
88
89    let mut connection = Connection::new(socket.compat(), config, Mode::Client);
90
91    let mut control_stream = poll_fn(|cx| connection.poll_new_outbound(cx))
92        .await
93        .context("failed to open control stream")?;
94
95    let (inbound_tx, mut inbound_rx) = tokio::sync::mpsc::channel::<yamux::Stream>(32);
96    tokio::spawn(async move {
97        loop {
98            match poll_fn(|cx| connection.poll_next_inbound(cx)).await {
99                Some(Ok(stream)) => {
100                    if inbound_tx.send(stream).await.is_err() {
101                        break;
102                    }
103                }
104                Some(Err(e)) => {
105                    tlog::error(&format!("yamux: {e}"));
106                    break;
107                }
108                None => break,
109            }
110        }
111    });
112
113    let nonce = {
114        use rand::Rng;
115        let bytes: [u8; 32] = rand::rng().random();
116        B64.encode(bytes)
117    };
118    let hmac = if token.is_empty() {
119        None
120    } else {
121        Some(protocol::compute_hmac(token, &nonce))
122    };
123    let auth = ControlMsg::Auth {
124        subdomain: subdomain.map(|s| s.to_string()),
125        nonce,
126        hmac,
127    };
128    control_stream.write_all(&auth.encode()?).await?;
129    control_stream.flush().await?;
130
131    let resp = protocol::read_msg(&mut control_stream).await?;
132    let tunnel_url = match resp {
133        ControlMsg::AuthOk { url, .. } => url,
134        ControlMsg::Error { message } => anyhow::bail!("server error: {message}"),
135        _ => anyhow::bail!("unexpected response from server"),
136    };
137
138    let display_url = if *server_addr == "127.0.0.1" || *server_addr == "localhost" {
139        format!("http://{tunnel_url}")
140    } else {
141        format!("https://{tunnel_url}")
142    };
143    tlog::banner(&display_url, local_host, *local_port, *inspect);
144    update::check_in_background();
145
146    let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
147    tokio::spawn(async move {
148        tokio::signal::ctrl_c().await.ok();
149        eprintln!();
150        tlog::info("shutting down...");
151        let _ = shutdown_tx.send(());
152    });
153
154    loop {
155        tokio::select! {
156            biased;
157            _ = &mut shutdown_rx => {
158                control_stream.close().await.ok();
159                break;
160            }
161            stream = inbound_rx.recv() => {
162                match stream {
163                    Some(s) => { tokio::spawn(handle_stream(s, *local_port, local_host.to_string(), expected_auth.map(|s| s.to_string()), *inspect)); }
164                    None => break,
165                }
166            }
167        }
168    }
169
170    Ok(())
171}
172
173async fn handle_stream(
174    stream: yamux::Stream,
175    local_port: u16,
176    local_host: String,
177    expected_auth: Option<String>,
178    inspect: bool,
179) {
180    if let Err(e) = proxy_to_local(
181        stream,
182        local_port,
183        &local_host,
184        expected_auth.as_deref(),
185        inspect,
186    )
187    .await
188    {
189        tlog::error(&format!("proxy: {e}"));
190    }
191}
192
193async fn proxy_to_local(
194    stream: yamux::Stream,
195    local_port: u16,
196    local_host: &str,
197    expected_auth: Option<&str>,
198    inspect: bool,
199) -> Result<()> {
200    let mut tunnel = stream.compat();
201
202    let req_head = proxy::read_http_head(&mut tunnel).await?;
203    let (method, path) = proxy::parse_request_line(&req_head);
204    let start = std::time::Instant::now();
205    let id = store::next_id();
206
207    let head_end = proxy::headers_end(&req_head).unwrap_or(req_head.len());
208    let req_headers = &req_head[..head_end];
209    let body_prefix = req_head[head_end..].to_vec();
210    let content_length = proxy::parse_content_length(req_headers);
211
212    if let Some(expected) = expected_auth {
213        let provided = proxy::extract_authorization(req_headers);
214        if provided.as_deref() != Some(expected) {
215            proxy::write_401(&mut tunnel).await.ok();
216            tlog::request(&method, &path, 401, start.elapsed().as_millis() as u64, id);
217            return Ok(());
218        }
219    }
220
221    let mut local = match TcpStream::connect(format!("{local_host}:{local_port}")).await {
222        Ok(s) => s,
223        Err(_) => {
224            proxy::write_502(&mut tunnel).await.ok();
225            tlog::request(&method, &path, 502, start.elapsed().as_millis() as u64, id);
226            return Ok(());
227        }
228    };
229
230    if inspect {
231        let req_body = read_body_exact(&mut tunnel, body_prefix, content_length).await;
232
233        local.write_all(req_headers).await?;
234        local.write_all(&req_body).await?;
235        local.flush().await?;
236
237        store::store(store::StoredRequest {
238            id,
239            port: local_port,
240            method: method.clone(),
241            path: path.clone(),
242            raw_headers: String::from_utf8_lossy(req_headers).into_owned(),
243            body_b64: B64.encode(&req_body),
244        });
245
246        let resp_head = proxy::read_http_head(&mut local).await.unwrap_or_default();
247        let status = proxy::parse_response_status(&resp_head);
248        let resp_head_end = proxy::headers_end(&resp_head).unwrap_or(resp_head.len());
249        let resp_headers = &resp_head[..resp_head_end];
250        let resp_already = resp_head[resp_head_end..].to_vec();
251
252        let resp_cl = proxy::parse_content_length(resp_headers);
253        let resp_body = if resp_cl > 0 {
254            read_body_exact(&mut local, resp_already, resp_cl).await
255        } else if proxy::is_chunked(resp_headers) {
256            read_chunked(&mut local, resp_already).await
257        } else {
258            resp_already
259        };
260
261        let out_headers = if proxy::is_chunked(resp_headers) {
262            rebuild_resp_headers(resp_headers, resp_body.len())
263        } else {
264            resp_headers.to_vec()
265        };
266
267        tunnel.write_all(&out_headers).await?;
268        tunnel.write_all(&resp_body).await?;
269        tunnel.flush().await.ok();
270
271        let elapsed = start.elapsed().as_millis() as u64;
272        tlog::request(&method, &path, status, elapsed, id);
273
274        let req_raw = String::from_utf8_lossy(req_headers);
275        let req_body_str = String::from_utf8_lossy(&req_body);
276        tlog::inspect_request(id, &req_raw, &req_body_str);
277
278        let resp_raw = String::from_utf8_lossy(&out_headers);
279        let resp_body_str = String::from_utf8_lossy(&resp_body);
280        tlog::inspect_response(status, &resp_raw, &resp_body_str, id);
281    } else {
282        local.write_all(req_headers).await?;
283        local.write_all(&body_prefix).await?;
284        local.flush().await?;
285
286        // Store what we have (headers + prefix). For requests without a body or
287        // where the body fit in the initial read, this is complete.
288        store::store(store::StoredRequest {
289            id,
290            port: local_port,
291            method: method.clone(),
292            path: path.clone(),
293            raw_headers: String::from_utf8_lossy(req_headers).into_owned(),
294            body_b64: B64.encode(&body_prefix),
295        });
296
297        let mut peek = [0u8; 512];
298        let n = local.read(&mut peek).await.unwrap_or(0);
299        let status = proxy::parse_response_status(&peek[..n]);
300        tunnel.write_all(&peek[..n]).await?;
301        tokio::io::copy_bidirectional(&mut local, &mut tunnel)
302            .await
303            .ok();
304        tlog::request(
305            &method,
306            &path,
307            status,
308            start.elapsed().as_millis() as u64,
309            id,
310        );
311    }
312
313    Ok(())
314}
315
316async fn read_body_exact<R: tokio::io::AsyncRead + Unpin>(
317    reader: &mut R,
318    mut buf: Vec<u8>,
319    total: usize,
320) -> Vec<u8> {
321    let target = total.min(BODY_CAP);
322    let mut tmp = [0u8; 8192];
323    while buf.len() < target {
324        let want = (target - buf.len()).min(8192);
325        match reader.read(&mut tmp[..want]).await {
326            Ok(0) | Err(_) => break,
327            Ok(n) => buf.extend_from_slice(&tmp[..n]),
328        }
329    }
330    buf
331}
332
333async fn read_chunked<R: tokio::io::AsyncRead + Unpin>(
334    reader: &mut R,
335    initial: Vec<u8>,
336) -> Vec<u8> {
337    let mut raw = initial;
338    let mut body = Vec::new();
339    let mut tmp = [0u8; 8192];
340
341    'outer: loop {
342        let mut pos = 0;
343        loop {
344            let slice = &raw[pos..];
345            let Some(crlf) = slice.windows(2).position(|w| w == b"\r\n") else {
346                break;
347            };
348            let size_str = std::str::from_utf8(&slice[..crlf])
349                .unwrap_or("0")
350                .split(';')
351                .next()
352                .unwrap_or("0")
353                .trim();
354            let chunk_size = usize::from_str_radix(size_str, 16).unwrap_or(0);
355            if chunk_size == 0 {
356                let after_size_line = pos + crlf + 2;
357                if after_size_line + 2 <= raw.len() {
358                    break 'outer;
359                }
360                break;
361            }
362            let data_start = pos + crlf + 2;
363            let data_end = data_start + chunk_size;
364            if data_end + 2 > raw.len() {
365                break;
366            }
367            body.extend_from_slice(&raw[data_start..data_end]);
368            pos = data_end + 2;
369            if body.len() >= BODY_CAP {
370                break 'outer;
371            }
372        }
373        raw.drain(..pos);
374        match reader.read(&mut tmp).await {
375            Ok(0) | Err(_) => break,
376            Ok(n) => raw.extend_from_slice(&tmp[..n]),
377        }
378    }
379    body
380}
381
/// Re-frame a response header block for a de-chunked body.
///
/// Drops every `Transfer-Encoding:` and `Content-Length:` line (the match
/// is case-insensitive) and skips empty lines. When the first line (the
/// status line) is kept, `Content-Length: body_len` is inserted right after
/// it. The result always ends with the blank line that terminates the
/// header block. Non-UTF-8 bytes are replaced lossily.
fn rebuild_resp_headers(headers: &[u8], body_len: usize) -> Vec<u8> {
    const DROPPED: [&str; 2] = ["transfer-encoding:", "content-length:"];

    let text = String::from_utf8_lossy(headers);
    let mut rebuilt = Vec::with_capacity(headers.len() + 32);
    for (idx, line) in text.split("\r\n").enumerate() {
        let lower = line.to_ascii_lowercase();
        let framing = DROPPED.iter().any(|prefix| lower.starts_with(*prefix));
        if line.is_empty() || framing {
            continue;
        }
        rebuilt.extend_from_slice(line.as_bytes());
        rebuilt.extend_from_slice(b"\r\n");
        if idx == 0 {
            rebuilt.extend_from_slice(format!("Content-Length: {body_len}\r\n").as_bytes());
        }
    }
    rebuilt.extend_from_slice(b"\r\n");
    rebuilt
}
402
403pub async fn replay(id: u64) -> Result<()> {
404    let req = store::find(id).ok_or_else(|| anyhow::anyhow!("request #{id} not found"))?;
405
406    tlog::info(&format!("replaying #{id}: {} {}", req.method, req.path));
407
408    let mut local = TcpStream::connect(format!("127.0.0.1:{}", req.port))
409        .await
410        .with_context(|| format!("failed to connect to localhost:{}", req.port))?;
411
412    local.write_all(req.raw_headers.as_bytes()).await?;
413    let body = B64.decode(&req.body_b64).unwrap_or_default();
414    local.write_all(&body).await?;
415    local.flush().await?;
416
417    let resp_head = proxy::read_http_head(&mut local).await.unwrap_or_default();
418    let resp_head_end = proxy::headers_end(&resp_head).unwrap_or(resp_head.len());
419    let resp_headers = &resp_head[..resp_head_end];
420    let resp_already = resp_head[resp_head_end..].to_vec();
421
422    let resp_cl = proxy::parse_content_length(resp_headers);
423    let resp_body = if resp_cl > 0 {
424        read_body_exact(&mut local, resp_already, resp_cl).await
425    } else if proxy::is_chunked(resp_headers) {
426        read_chunked(&mut local, resp_already).await
427    } else {
428        resp_already
429    };
430
431    let body_str = String::from_utf8_lossy(&resp_body);
432    let status = proxy::parse_response_status(&resp_head);
433    tlog::success(&format!("replayed #{id} → {status}"));
434    if !body_str.trim().is_empty() {
435        eprintln!("{body_str}");
436    }
437    Ok(())
438}