Skip to main content

tnnl/
server.rs

1use std::collections::VecDeque;
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use anyhow::{Context, Result};
7use dashmap::DashMap;
8use futures::future::poll_fn;
9use futures::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpListener;
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
13use yamux::{Config, Connection, Mode};
14
15use crate::log as tlog;
16use crate::protocol::{self, ControlMsg};
17use crate::proxy;
18
/// Upper bound on tunnels registered across all clients at once.
const MAX_TUNNELS_TOTAL: usize = 200;
/// Upper bound on simultaneous tunnels held by a single peer IP.
const MAX_TUNNELS_PER_IP: usize = 5;

/// Length of the sliding window used for per-IP connection rate limiting.
const RATE_WINDOW: Duration = Duration::from_secs(60);
/// Control connections one IP may open within a single `RATE_WINDOW`.
const MAX_CONNECTS_PER_MINUTE: usize = 15;
23
24type OpenStreamReply = oneshot::Sender<Result<yamux::Stream>>;
25
26struct ClientHandle {
27    stream_tx: mpsc::Sender<OpenStreamReply>,
28}
29
30type Registry = Arc<DashMap<String, ClientHandle>>;
31
32struct IpState {
33    recent_connects: VecDeque<Instant>,
34    active_tunnels: usize,
35}
36
37type IpTracker = Arc<DashMap<IpAddr, IpState>>;
38
39pub async fn run(
40    control_port: u16,
41    http_port: u16,
42    domain: &str,
43    token: Option<&str>,
44) -> Result<()> {
45    let registry: Registry = Arc::new(DashMap::new());
46    let ip_tracker: IpTracker = Arc::new(DashMap::new());
47    let domain = domain.to_string();
48    let token = token.map(|t| t.to_string());
49
50    let control_listener = TcpListener::bind(format!("0.0.0.0:{control_port}"))
51        .await
52        .context(format!("failed to bind control port {control_port}"))?;
53    let http_listener = TcpListener::bind(format!("0.0.0.0:{http_port}"))
54        .await
55        .context(format!("failed to bind http port {http_port}"))?;
56
57    tlog::info(&format!("control listening on 0.0.0.0:{control_port}"));
58    tlog::info(&format!("http listening on 0.0.0.0:{http_port}"));
59    tlog::info(&format!("domain: *.{domain}"));
60    if token.is_none() {
61        tlog::info("token auth disabled — open server");
62    }
63
64    let reg = registry.clone();
65    let ipt = ip_tracker.clone();
66    let tok = token.clone();
67    let dom = domain.clone();
68    let control_task = tokio::spawn(async move {
69        loop {
70            match control_listener.accept().await {
71                Ok((socket, addr)) => {
72                    let ip = addr.ip();
73
74                    if !check_rate_limit(&ipt, ip) {
75                        tlog::info(&format!(
76                            "rate limited {ip} (>{MAX_CONNECTS_PER_MINUTE}/min)"
77                        ));
78                        drop(socket);
79                        continue;
80                    }
81
82                    tlog::info(&format!("client connected from {addr}"));
83                    let reg = reg.clone();
84                    let ipt = ipt.clone();
85                    let tok = tok.clone();
86                    let dom = dom.clone();
87                    tokio::spawn(async move {
88                        if let Err(e) =
89                            handle_client(socket, reg, ipt, ip, tok.as_deref(), &dom).await
90                        {
91                            tlog::error(&format!("client {addr}: {e}"));
92                        }
93                    });
94                }
95                Err(e) => tlog::error(&format!("accept error: {e}")),
96            }
97        }
98    });
99
100    let reg = registry.clone();
101    let dom = domain.clone();
102    let http_task = tokio::spawn(async move {
103        loop {
104            match http_listener.accept().await {
105                Ok((socket, _)) => {
106                    let reg = reg.clone();
107                    let dom = dom.clone();
108                    tokio::spawn(async move {
109                        if let Err(e) = handle_http(socket, reg, &dom).await {
110                            tlog::error(&format!("http: {e}"));
111                        }
112                    });
113                }
114                Err(e) => tlog::error(&format!("http accept error: {e}")),
115            }
116        }
117    });
118
119    tokio::select! {
120        r = control_task => r?,
121        r = http_task => r?,
122    }
123
124    Ok(())
125}
126
127// Returns false if the IP has exceeded the connection rate limit.
128// Also prunes stale timestamps on every call.
129fn check_rate_limit(ip_tracker: &IpTracker, ip: IpAddr) -> bool {
130    let now = Instant::now();
131    let cutoff = now - RATE_WINDOW;
132    let mut entry = ip_tracker.entry(ip).or_insert_with(|| IpState {
133        recent_connects: VecDeque::new(),
134        active_tunnels: 0,
135    });
136    while entry.recent_connects.front().is_some_and(|t| *t < cutoff) {
137        entry.recent_connects.pop_front();
138    }
139    if entry.recent_connects.len() >= MAX_CONNECTS_PER_MINUTE {
140        return false;
141    }
142    entry.recent_connects.push_back(now);
143    true
144}
145
146async fn handle_client(
147    socket: tokio::net::TcpStream,
148    registry: Registry,
149    ip_tracker: IpTracker,
150    peer_ip: IpAddr,
151    expected_token: Option<&str>,
152    domain: &str,
153) -> Result<()> {
154    let mut config = Config::default();
155    config.set_split_send_size(16 * 1024);
156
157    let mut connection = Connection::new(socket.compat(), config, Mode::Server);
158
159    let mut control_stream = poll_fn(|cx| connection.poll_next_inbound(cx))
160        .await
161        .context("no control stream")??;
162
163    let (stream_tx, mut stream_rx) = mpsc::channel::<OpenStreamReply>(32);
164    let conn_task = tokio::spawn(async move {
165        loop {
166            tokio::select! {
167                biased;
168                Some(reply_tx) = stream_rx.recv() => {
169                    let result = poll_fn(|cx| connection.poll_new_outbound(cx)).await;
170                    let _ = reply_tx.send(result.map_err(|e| anyhow::anyhow!("{e}")));
171                }
172                inbound = poll_fn(|cx| connection.poll_next_inbound(cx)) => {
173                    match inbound {
174                        Some(Ok(_)) => {}
175                        Some(Err(e)) => {
176                            tlog::error(&format!("yamux error: {e}"));
177                            break;
178                        }
179                        None => break,
180                    }
181                }
182            }
183        }
184    });
185
186    let msg = protocol::read_msg(&mut control_stream).await?;
187    let (requested_subdomain, nonce, provided_hmac) = match msg {
188        ControlMsg::Auth {
189            subdomain,
190            nonce,
191            hmac,
192        } => (subdomain, nonce, hmac),
193        _ => anyhow::bail!("expected Auth message"),
194    };
195
196    if let Some(secret) = expected_token {
197        let expected_hmac = protocol::compute_hmac(secret, &nonce);
198        if provided_hmac.as_deref() != Some(expected_hmac.as_str()) {
199            let encoded = ControlMsg::Error {
200                message: "invalid secret".into(),
201            }
202            .encode()?;
203            control_stream.write_all(&encoded).await.ok();
204            control_stream.close().await.ok();
205            anyhow::bail!("invalid secret from {peer_ip}");
206        }
207    }
208
209    let ip_active = ip_tracker
210        .get(&peer_ip)
211        .map(|s| s.active_tunnels)
212        .unwrap_or(0);
213    if ip_active >= MAX_TUNNELS_PER_IP {
214        let encoded = ControlMsg::Error {
215            message: format!("max {MAX_TUNNELS_PER_IP} tunnels per IP"),
216        }
217        .encode()?;
218        control_stream.write_all(&encoded).await.ok();
219        control_stream.close().await.ok();
220        anyhow::bail!("{peer_ip} hit per-IP tunnel limit");
221    }
222
223    if registry.len() >= MAX_TUNNELS_TOTAL {
224        let encoded = ControlMsg::Error {
225            message: "server at capacity, try again later".into(),
226        }
227        .encode()?;
228        control_stream.write_all(&encoded).await.ok();
229        control_stream.close().await.ok();
230        anyhow::bail!("global tunnel limit reached");
231    }
232
233    let subdomain = requested_subdomain
234        .filter(|s| {
235            !s.is_empty() && s.len() <= 63 && s.chars().all(|c| c.is_alphanumeric() || c == '-')
236        })
237        .unwrap_or_else(|| {
238            use rand::Rng;
239            let mut rng = rand::rng();
240            format!("{:08x}", rng.random::<u32>())
241        });
242
243    if registry.contains_key(&subdomain) {
244        let encoded = ControlMsg::Error {
245            message: format!("subdomain '{subdomain}' already in use"),
246        }
247        .encode()?;
248        control_stream.write_all(&encoded).await.ok();
249        control_stream.close().await.ok();
250        anyhow::bail!("subdomain collision: {subdomain}");
251    }
252
253    let full_domain = format!("{subdomain}.{domain}");
254    let ok = ControlMsg::AuthOk {
255        subdomain: subdomain.clone(),
256        url: full_domain.clone(),
257    };
258    control_stream.write_all(&ok.encode()?).await?;
259    control_stream.flush().await?;
260
261    registry.insert(subdomain.clone(), ClientHandle { stream_tx });
262    ip_tracker
263        .entry(peer_ip)
264        .and_modify(|s| s.active_tunnels += 1);
265
266    tlog::success(&format!(
267        "tunnel live: {full_domain} (ip={peer_ip}, active={}/{MAX_TUNNELS_PER_IP})",
268        ip_active + 1
269    ));
270
271    let mut buf = [0u8; 1024];
272    loop {
273        match control_stream.read(&mut buf).await {
274            Ok(0) | Err(_) => break,
275            Ok(_) => {}
276        }
277    }
278
279    registry.remove(&subdomain);
280    ip_tracker.entry(peer_ip).and_modify(|s| {
281        if s.active_tunnels > 0 {
282            s.active_tunnels -= 1;
283        }
284    });
285    conn_task.abort();
286
287    tlog::info(&format!("client disconnected, removed {full_domain}"));
288
289    Ok(())
290}
291
292const INSTALL_SH: &str = include_str!("../install.sh");
293
294async fn handle_http(
295    mut socket: tokio::net::TcpStream,
296    registry: Registry,
297    domain: &str,
298) -> Result<()> {
299    let head = proxy::read_http_head(&mut socket).await?;
300    let host = proxy::extract_host(&head).context("no Host header")?;
301
302    if host == domain {
303        let (_, path) = proxy::parse_request_line(&head);
304        let response: Vec<u8> = if path == "/install.sh" {
305            format!(
306                "HTTP/1.1 200 OK\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
307                INSTALL_SH.len(),
308                INSTALL_SH
309            )
310            .into_bytes()
311        } else {
312            b"HTTP/1.1 301 Moved Permanently\r\nLocation: https://github.com/jbingen/tnnl\r\nContent-Length: 0\r\nConnection: close\r\n\r\n".to_vec()
313        };
314        tokio::io::AsyncWriteExt::write_all(&mut socket, &response)
315            .await
316            .ok();
317        return Ok(());
318    }
319
320    let subdomain = host
321        .strip_suffix(&format!(".{domain}"))
322        .context(format!("host '{host}' not a subdomain of {domain}"))?
323        .to_string();
324
325    let stream_tx = match registry.get(&subdomain) {
326        Some(entry) => entry.stream_tx.clone(),
327        None => {
328            proxy::write_404(&mut socket).await.ok();
329            return Ok(());
330        }
331    };
332
333    let (reply_tx, reply_rx) = oneshot::channel();
334    stream_tx
335        .send(reply_tx)
336        .await
337        .map_err(|_| anyhow::anyhow!("client disconnected"))?;
338
339    let tunnel_stream = reply_rx
340        .await
341        .map_err(|_| anyhow::anyhow!("client disconnected"))??;
342
343    let mut tunnel_compat = tunnel_stream.compat();
344
345    tokio::io::AsyncWriteExt::write_all(&mut tunnel_compat, &head).await?;
346
347    tokio::io::copy_bidirectional(&mut socket, &mut tunnel_compat)
348        .await
349        .ok();
350
351    Ok(())
352}