Skip to main content

rustgate/c2/
server.rs

1use crate::cert::CertificateAuthority;
2use crate::error::{ProxyError, Result};
3use crate::protocol::{
4    frame_tunnel_data, parse_tunnel_data, Command, CommandResponse, ControlMessage, WsTextMessage,
5};
6use crate::ws::{self, ChannelMap};
7use bytes::Bytes;
8use futures_util::{SinkExt, StreamExt};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicU32, Ordering};
11use std::sync::Arc;
12use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
13use tokio::net::{TcpListener, TcpStream};
14use tokio::sync::mpsc;
15use tokio_rustls::TlsAcceptor;
16use tokio_tungstenite::tungstenite::Message;
17use tracing::{error, info, warn};
18
19struct ClientHandle {
20    cn: String,
21    session_id: u64,
22    ws_tx: mpsc::Sender<Message>,
23    shutdown_tx: tokio::sync::watch::Sender<bool>,
24    channels: Arc<ChannelMap>,
25    /// Pending reverse tunnels: tunnel_id -> remote_port (waiting for client Ok)
26    pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>>,
27    /// Pending SOCKS tunnels: tunnel_ids waiting for client Ok before authorization
28    pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
29    /// Tunnel IDs authorized by the operator (SOCKS commands, granted on client Ok)
30    authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
31    /// Active reverse tunnel listeners: tunnel_id -> abort handle
32    reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
33}
34
35struct ServerState {
36    clients: Arc<tokio::sync::RwLock<HashMap<String, ClientHandle>>>,
37    next_tunnel_id: AtomicU32,
38    next_session_id: std::sync::atomic::AtomicU64,
39}
40
41impl ServerState {
42    fn alloc_tunnel_id(&self) -> u32 {
43        self.next_tunnel_id.fetch_add(1, Ordering::Relaxed)
44    }
45}
46
47/// Run the C2 server.
48pub async fn run(
49    host: &str,
50    port: u16,
51    server_name: &str,
52    ca: Arc<CertificateAuthority>,
53) -> Result<()> {
54    let listen_addr = format!("{host}:{port}");
55
56    // Generate server cert using the advertised server_name, not the bind address
57    let server_ck = ca.generate_server_cert(server_name)?;
58    let ca_cert_der = ca.ca_cert_der();
59    let tls_config =
60        crate::tls::make_mtls_server_config(server_ck.cert_der, server_ck.key_der, ca_cert_der)?;
61    let acceptor = TlsAcceptor::from(tls_config);
62
63    let state = Arc::new(ServerState {
64        clients: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
65        next_session_id: std::sync::atomic::AtomicU64::new(1),
66        next_tunnel_id: AtomicU32::new(1),
67    });
68
69    let listener = TcpListener::bind(&listen_addr).await?;
70    info!(
71        "C2 server listening on {listen_addr} (cert name: {server_name}, mTLS required)"
72    );
73
74    let state_stdin = state.clone();
75    tokio::spawn(async move {
76        if let Err(e) = stdin_command_loop(state_stdin).await {
77            error!("Stdin command loop error: {e}");
78        }
79    });
80
81    // Limit concurrent handshakes to prevent pre-auth exhaustion
82    let handshake_semaphore = Arc::new(tokio::sync::Semaphore::new(64));
83
84    loop {
85        let (stream, peer) = listener.accept().await?;
86        let acceptor = acceptor.clone();
87        let state = state.clone();
88        let sem = handshake_semaphore.clone();
89
90        tokio::spawn(async move {
91            // Acquire permit for handshake only
92            let permit = match sem.try_acquire() {
93                Ok(p) => p,
94                Err(_) => {
95                    warn!("Rejecting {peer}: too many concurrent handshakes");
96                    return;
97                }
98            };
99
100            // Perform TLS + WS handshake under the permit
101            let handshake_result = perform_handshake(stream, peer, &acceptor).await;
102            drop(permit); // Release immediately after handshake
103
104            match handshake_result {
105                Ok((ws_stream, fingerprint, cn)) => {
106                    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
107                    match run_session(ws_stream, peer, fingerprint, cn, state, shutdown_tx, shutdown_rx).await {
108                        Ok(()) => info!("Client {peer} disconnected"),
109                        Err(e) => warn!("Client {peer} error: {e}"),
110                    }
111                }
112                Err(e) => warn!("Client {peer} handshake error: {e}"),
113            }
114        });
115    }
116}
117
118/// Perform TLS + WebSocket handshake with timeouts. Returns (ws_stream, fingerprint, cn).
119async fn perform_handshake(
120    stream: TcpStream,
121    peer: std::net::SocketAddr,
122    acceptor: &TlsAcceptor,
123) -> Result<(ws::ServerWsStream, String, String)> {
124    let tls_stream = tokio::time::timeout(
125        std::time::Duration::from_secs(15),
126        acceptor.accept(stream),
127    )
128    .await
129    .map_err(|_| ProxyError::Other(format!("TLS handshake timed out for {peer}")))?
130    .map_err(|e| ProxyError::Other(format!("TLS handshake failed for {peer}: {e}")))?;
131
132    let (fingerprint, cn) = extract_client_identity(&tls_stream);
133    info!("Client authenticated: {cn} [{fingerprint}] ({peer})");
134
135    let ws_stream = tokio::time::timeout(
136        std::time::Duration::from_secs(10),
137        ws::accept_ws(tls_stream),
138    )
139    .await
140    .map_err(|_| ProxyError::Other(format!("WebSocket upgrade timed out for {peer}")))?
141    ?;
142
143    Ok((ws_stream, fingerprint, cn))
144}
145
146/// Run the authenticated C2 session after handshake.
147async fn run_session(
148    ws_stream: ws::ServerWsStream,
149    _peer: std::net::SocketAddr,
150    fingerprint: String,
151    cn: String,
152    state: Arc<ServerState>,
153    shutdown_tx: tokio::sync::watch::Sender<bool>,
154    mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
155) -> Result<()> {
156    let client_label = format!("{cn} [{fingerprint}]");
157    let (mut ws_sink, mut ws_source) = ws_stream.split();
158
159    let channels = Arc::new(ChannelMap::new(2)); // Server uses even IDs
160    let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(256);
161
162    let reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>> =
163        Arc::new(tokio::sync::RwLock::new(HashMap::new()));
164    let pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>> =
165        Arc::new(tokio::sync::RwLock::new(HashMap::new()));
166    let authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>> =
167        Arc::new(tokio::sync::RwLock::new(std::collections::HashSet::new()));
168    let pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>> =
169        Arc::new(tokio::sync::RwLock::new(std::collections::HashSet::new()));
170
171    let session_id = state
172        .next_session_id
173        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
174
175    // If a session with the same fingerprint exists, evict it (stale/half-open).
176    {
177        let mut clients = state.clients.write().await;
178        if let Some(old) = clients.remove(&fingerprint) {
179            warn!("[{client_label}] Evicting stale session for reconnect");
180            // Wipe all authorization/pending state so the old task cannot act
181            old.authorized_tunnels.write().await.clear();
182            old.pending_socks.write().await.clear();
183            old.pending_reverse.write().await.clear();
184            // Close all channels — drops senders, unblocking relay tasks
185            old.channels.close_all().await;
186            // Abort reverse listeners so ports are freed
187            for handle in old.reverse_listeners.write().await.drain() {
188                handle.1.abort();
189            }
190            // Signal the old session to shut down
191            let _ = old.shutdown_tx.send(true);
192            drop(old);
193        }
194        clients.insert(
195            fingerprint.clone(),
196            ClientHandle {
197                cn: cn.clone(),
198                session_id,
199                ws_tx: ws_tx.clone(),
200                shutdown_tx,
201                channels: channels.clone(),
202                pending_reverse: pending_reverse.clone(),
203                pending_socks: pending_socks.clone(),
204                authorized_tunnels: authorized_tunnels.clone(),
205                reverse_listeners: reverse_listeners.clone(),
206            },
207        );
208    }
209
210    // Writer task
211    let label_writer = client_label.clone();
212    let writer_handle = tokio::spawn(async move {
213        while let Some(msg) = ws_rx.recv().await {
214            if ws_sink.send(msg).await.is_err() {
215                info!("[{label_writer}] WS write closed");
216                break;
217            }
218        }
219    });
220
221    // Reader loop
222    let channels_reader = channels.clone();
223    let ws_tx_reader = ws_tx.clone();
224    let label_reader = client_label.clone();
225    let tunnel_state = ClientTunnelState {
226        pending_reverse: pending_reverse.clone(),
227        pending_socks: pending_socks.clone(),
228        authorized_tunnels: authorized_tunnels.clone(),
229        reverse_listeners: reverse_listeners.clone(),
230    };
231    loop {
232        let msg_result = tokio::select! {
233            msg = ws_source.next() => msg,
234            _ = shutdown_rx.changed() => {
235                info!("[{label_reader}] Session shutdown signal received");
236                break;
237            }
238        };
239        let msg = match msg_result {
240            Some(Ok(m)) => m,
241            Some(Err(e)) => {
242                warn!("[{label_reader}] WebSocket read error: {e}");
243                break;
244            }
245            None => break,
246        };
247
248        match msg {
249            Message::Text(text) => match serde_json::from_str::<WsTextMessage>(&text) {
250                Ok(WsTextMessage::Response(resp)) => {
251                    handle_response(
252                        &label_reader,
253                        &resp,
254                        &tunnel_state,
255                        &channels_reader,
256                        ws_tx_reader.clone(),
257                    )
258                    .await;
259                }
260                Ok(WsTextMessage::Control(ctrl)) => {
261                    handle_server_control(
262                        &label_reader,
263                        ctrl,
264                        channels_reader.clone(),
265                        &tunnel_state.authorized_tunnels,
266                        ws_tx_reader.clone(),
267                    )
268                    .await;
269                }
270                Ok(WsTextMessage::Command(_)) => {
271                    warn!("[{label_reader}] Unexpected command from client");
272                }
273                Err(e) => {
274                    warn!("[{label_reader}] Failed to parse message: {e}");
275                }
276            },
277            Message::Binary(data) => {
278                if let Some((channel_id, payload)) = parse_tunnel_data(&data) {
279                    if !channels_reader
280                        .send(channel_id, Bytes::copy_from_slice(payload))
281                        .await
282                    {
283                        warn!("[{label_reader}] Data for unknown channel {channel_id}");
284                    }
285                }
286            }
287            Message::Close(_) => break,
288            _ => {}
289        }
290    }
291
292    // Cleanup: abort writer, close channels, abort reverse listeners
293    writer_handle.abort();
294    channels.close_all().await;
295    {
296        let listeners = reverse_listeners.read().await;
297        for handle in listeners.values() {
298            handle.abort();
299        }
300    }
301    // Only remove from clients map if this session is still the current one (generation check)
302    {
303        let mut clients = state.clients.write().await;
304        if let Some(existing) = clients.get(&fingerprint) {
305            if existing.session_id == session_id {
306                clients.remove(&fingerprint);
307            }
308        }
309    }
310    info!("[{client_label}] Client removed");
311
312    Ok(())
313}
314
315/// Per-client tunnel state used by handle_response.
316struct ClientTunnelState {
317    pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>>,
318    pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
319    authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
320    reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
321}
322
323/// Handle client responses — authorize tunnels on Ok, revoke on Error.
324async fn handle_response(
325    label: &str,
326    resp: &CommandResponse,
327    ts: &ClientTunnelState,
328    channels: &Arc<ChannelMap>,
329    ws_tx: mpsc::Sender<Message>,
330) {
331    match resp {
332        CommandResponse::SocksReady { tunnel_id: tid } => {
333            if ts.pending_socks.write().await.remove(tid) {
334                ts.authorized_tunnels.write().await.insert(*tid);
335                info!("[{label}] SOCKS tunnel {tid} authorized via SocksReady");
336            } else {
337                warn!("[{label}] Unexpected SocksReady for tunnel {tid}");
338            }
339        }
340        CommandResponse::ReverseTunnelReady { tunnel_id: tid } => {
341            // Client validated the target — now start the reverse listener
342            let remote_port = ts.pending_reverse.write().await.remove(tid);
343            if let Some(port) = remote_port {
344                info!("[{label}] Starting reverse listener on 127.0.0.1:{port} (tunnel {tid})");
345                let channels = channels.clone();
346                let tid = *tid;
347                let label = label.to_string();
348                let handle = tokio::spawn(async move {
349                    if let Err(e) =
350                        reverse_listen_loop(port, tid, channels, ws_tx, &label).await
351                    {
352                        warn!("[{label}] Reverse listener error: {e}");
353                    }
354                });
355                ts.reverse_listeners
356                    .write()
357                    .await
358                    .insert(tid, handle.abort_handle());
359            } else {
360                info!("[{label}] Ok response: tunnel_id={tid}");
361            }
362        }
363        CommandResponse::Ok { .. } => {
364            info!("[{label}] Ok response");
365        }
366        CommandResponse::Error { tunnel_id, message } => {
367            // Revoke the specific failed tunnel, not all pending
368            if let Some(tid) = tunnel_id {
369                if ts.pending_socks.write().await.remove(tid) {
370                    ts.authorized_tunnels.write().await.remove(tid);
371                    info!("[{label}] Revoked failed SOCKS tunnel {tid}");
372                }
373                ts.pending_reverse.write().await.remove(tid);
374            }
375            warn!("[{label}] Error response: {message}");
376        }
377        CommandResponse::Pong { seq } => {
378            info!("[{label}] Pong seq={seq}");
379        }
380    }
381}
382
383/// Accept loop for reverse tunnel: binds remote_port, sends ChannelOpen for each connection.
384async fn reverse_listen_loop(
385    port: u16,
386    tunnel_id: u32,
387    channels: Arc<ChannelMap>,
388    ws_tx: mpsc::Sender<Message>,
389    label: &str,
390) -> Result<()> {
391    let listener = TcpListener::bind(format!("127.0.0.1:{port}")).await?;
392    info!("[{label}] Reverse tunnel {tunnel_id} listening on 127.0.0.1:{port}");
393
394    loop {
395        let (tcp, peer) = listener.accept().await?;
396        let channel_id = channels.alloc_id();
397        info!("[{label}] Reverse connection from {peer}, channel {channel_id}");
398
399        // Register data channel and readiness waiter BEFORE sending ChannelOpen
400        // so inbound data frames are buffered even if peer responds instantly.
401        let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
402        channels.insert_with_tunnel(channel_id, tunnel_id, data_tx).await;
403        let ready_rx = channels.wait_ready(channel_id).await;
404
405        let open = WsTextMessage::Control(ControlMessage::ChannelOpen {
406            channel_id,
407            tunnel_id,
408            target: None,
409        });
410        if let Ok(json) = serde_json::to_string(&open) {
411            if ws_tx.send(Message::Text(json)).await.is_err() {
412                break Ok(());
413            }
414        }
415
416        let channels = channels.clone();
417        let ws_tx = ws_tx.clone();
418        let label = label.to_string();
419        tokio::spawn(async move {
420            // Timeout readiness wait to prevent indefinite hangs from non-responsive clients
421            let ready_result = tokio::time::timeout(
422                std::time::Duration::from_secs(10),
423                ready_rx,
424            )
425            .await;
426            if ready_result.is_err() || ready_result.unwrap().is_err() {
427                warn!("[{label}] Channel {channel_id} ready timeout or signal dropped");
428                channels.remove(channel_id).await;
429                let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
430                if let Ok(json) = serde_json::to_string(&close) {
431                    let _ = ws_tx.send(Message::Text(json)).await;
432                }
433                return;
434            }
435            relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx, &label).await;
436        });
437    }
438}
439
440/// Handle control messages from client on the server side.
441async fn handle_server_control(
442    label: &str,
443    ctrl: ControlMessage,
444    channels: Arc<ChannelMap>,
445    authorized_tunnels: &tokio::sync::RwLock<std::collections::HashSet<u32>>,
446    ws_tx: mpsc::Sender<Message>,
447) {
448    match ctrl {
449        ControlMessage::ChannelOpen {
450            channel_id,
451            tunnel_id,
452            target,
453        } => {
454            // Validate channel_id: must be odd (client-originated) and not already in use
455            if channel_id % 2 == 0 {
456                warn!("[{label}] Rejected ChannelOpen with even channel_id {channel_id}");
457                return;
458            }
459            if channels.has(channel_id).await {
460                warn!("[{label}] Rejected ChannelOpen with duplicate channel_id {channel_id}");
461                let close =
462                    WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
463                if let Ok(json) = serde_json::to_string(&close) {
464                    let _ = ws_tx.send(Message::Text(json)).await;
465                }
466                return;
467            }
468
469            // Validate: only allow ChannelOpen for operator-authorized tunnels
470            if !authorized_tunnels.read().await.contains(&tunnel_id) {
471                warn!(
472                    "[{label}] Rejected unsolicited ChannelOpen for tunnel {tunnel_id}, channel {channel_id}"
473                );
474                let close =
475                    WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
476                if let Ok(json) = serde_json::to_string(&close) {
477                    let _ = ws_tx.send(Message::Text(json)).await;
478                }
479                return;
480            }
481
482            let target = match target {
483                Some(t) => t,
484                None => {
485                    warn!("[{label}] ChannelOpen without target");
486                    return;
487                }
488            };
489
490            // Reserve the channel_id atomically BEFORE async connect
491            // to prevent duplicate ChannelOpen from creating parallel connections.
492            let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
493            channels
494                .insert_with_tunnel(channel_id, tunnel_id, data_tx)
495                .await;
496
497            info!("[{label}] Channel {channel_id} -> connecting to {target}");
498
499            let channels = channels.clone();
500            let label = label.to_string();
501            tokio::spawn(async move {
502                // Bounded connect timeout to prevent indefinite hangs on blackholed targets
503                let connect_result = tokio::time::timeout(
504                    std::time::Duration::from_secs(10),
505                    TcpStream::connect(&target),
506                )
507                .await;
508                match connect_result {
509                    Ok(Ok(tcp)) => {
510                        // Re-check channel is still registered (not revoked during connect)
511                        if !channels.has(channel_id).await {
512                            warn!("[{label}] Channel {channel_id} revoked during connect, dropping");
513                            drop(tcp);
514                            return;
515                        }
516
517                        info!("[{label}] Channel {channel_id} connected to {target}");
518
519                        let ready = WsTextMessage::Control(ControlMessage::ChannelReady {
520                            channel_id,
521                        });
522                        if let Ok(json) = serde_json::to_string(&ready) {
523                            let _ = ws_tx.send(Message::Text(json)).await;
524                        }
525
526                        relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx.clone(), &label)
527                            .await;
528                    }
529                    Ok(Err(e)) => {
530                        warn!("[{label}] Failed to connect to {target}: {e}");
531                        channels.remove(channel_id).await;
532                        let close =
533                            WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
534                        if let Ok(json) = serde_json::to_string(&close) {
535                            let _ = ws_tx.send(Message::Text(json)).await;
536                        }
537                    }
538                    Err(_) => {
539                        warn!("[{label}] Connect to {target} timed out for channel {channel_id}");
540                        channels.remove(channel_id).await;
541                        let close =
542                            WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
543                        if let Ok(json) = serde_json::to_string(&close) {
544                            let _ = ws_tx.send(Message::Text(json)).await;
545                        }
546                    }
547                }
548            });
549        }
550        ControlMessage::ChannelReady { channel_id } => {
551            channels.signal_ready(channel_id).await;
552            info!("[{label}] Channel {channel_id} ready");
553        }
554        ControlMessage::ChannelClose { channel_id } => {
555            channels.remove(channel_id).await;
556            info!("[{label}] Channel {channel_id} closed");
557        }
558    }
559}
560
561/// Bidirectional relay between a TCP stream and a WS channel.
562/// `data_rx` must already be registered in `channels` before calling this.
563async fn relay_tcp_ws(
564    tcp: TcpStream,
565    channel_id: u32,
566    mut data_rx: mpsc::Receiver<Bytes>,
567    channels: Arc<ChannelMap>,
568    ws_tx: mpsc::Sender<Message>,
569    label: &str,
570) {
571    let (mut tcp_read, mut tcp_write) = tcp.into_split();
572
573    let ws2tcp = tokio::spawn(async move {
574        while let Some(data) = data_rx.recv().await {
575            if tcp_write.write_all(&data).await.is_err() {
576                break;
577            }
578        }
579        let _ = tcp_write.shutdown().await;
580    });
581
582    let ws_tx_data = ws_tx.clone();
583    let tcp2ws = tokio::spawn(async move {
584        let mut buf = vec![0u8; 8192];
585        loop {
586            match tcp_read.read(&mut buf).await {
587                Ok(0) | Err(_) => break,
588                Ok(n) => {
589                    let frame = frame_tunnel_data(channel_id, &buf[..n]);
590                    if ws_tx_data.send(Message::Binary(frame)).await.is_err() {
591                        break;
592                    }
593                }
594            }
595        }
596    });
597
598    // When first direction finishes: notify peer, give grace period to drain,
599    // then remove channel routing and force-abort.
600    let ws2tcp_abort = ws2tcp.abort_handle();
601    let tcp2ws_abort = tcp2ws.abort_handle();
602
603    tokio::select! {
604        _ = ws2tcp => {}
605        _ = tcp2ws => {}
606    }
607
608    // Notify peer that we're closing (channel stays registered for drain)
609    let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
610    if let Ok(json) = serde_json::to_string(&close) {
611        let _ = ws_tx.send(Message::Text(json)).await;
612    }
613
614    // Grace period: channel stays registered so in-flight frames can still be delivered
615    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
616
617    // Now remove channel routing and force-abort any remaining task
618    channels.remove(channel_id).await;
619    ws2tcp_abort.abort();
620    tcp2ws_abort.abort();
621    info!("[{label}] Channel {channel_id} closed");
622}
623
624/// Extract (fingerprint, CN) from a client's peer certificate.
625/// Fingerprint is hex-encoded SHA-256 of the raw DER certificate (first 16 hex chars).
626fn extract_client_identity(
627    tls_stream: &tokio_rustls::server::TlsStream<TcpStream>,
628) -> (String, String) {
629    let (_, server_conn) = tls_stream.get_ref();
630    let certs = server_conn.peer_certificates().unwrap_or_default();
631    let cert_der = match certs.first() {
632        Some(c) => c.as_ref(),
633        None => return ("unknown".into(), "unknown".into()),
634    };
635
636    // SHA-256 fingerprint of the raw DER certificate
637    let digest = ring::digest::digest(&ring::digest::SHA256, cert_der);
638    let fingerprint: String = digest.as_ref().iter().map(|b| format!("{b:02x}")).collect();
639
640    let cn = extract_cn_from_der(cert_der).unwrap_or_else(|| "unknown".into());
641    (fingerprint, cn)
642}
643
/// Extract the subject Common Name from a DER-encoded certificate.
///
/// This is a deliberately minimal scan, not a full ASN.1 parser. It looks for
/// the complete DER encoding of the `id-at-commonName` attribute type,
/// `06 03 55 04 03` (OBJECT IDENTIFIER tag, length 3, OID 2.5.4.3). The OID
/// must be followed by a DirectoryString value: UTF8String, PrintableString,
/// TeletexString, IA5String, UniversalString or BMPString. The value length
/// may be short-form or 1-/2-byte long-form.
///
/// In X.509 the issuer DN is encoded before the subject DN, so the LAST match
/// is returned. That corresponds to the subject (leaf) CN, not the issuer (CA)
/// CN. Requiring the OID tag/length prefix and a string tag keeps stray
/// `55 04 03` byte runs (e.g. inside key or signature bytes) from being taken
/// as a CN.
///
/// Value bytes are decoded as UTF-8; values that are not valid UTF-8 are
/// skipped. Returns `None` if no decodable CN is found.
fn extract_cn_from_der(der: &[u8]) -> Option<String> {
    // DER TLV of OBJECT IDENTIFIER 2.5.4.3 (id-at-commonName).
    const CN_OID_TLV: [u8; 5] = [0x06, 0x03, 0x55, 0x04, 0x03];
    // DirectoryString tags: UTF8String, PrintableString, TeletexString,
    // IA5String, UniversalString, BMPString.
    const STRING_TAGS: [u8; 6] = [0x0C, 0x13, 0x14, 0x16, 0x1C, 0x1E];

    // Parse one `tag length value` element at the start of `buf`. Returns the
    // value bytes if the tag is a string type and the value is complete.
    fn string_value(buf: &[u8]) -> Option<&[u8]> {
        let (&tag, rest) = buf.split_first()?;
        if !STRING_TAGS.contains(&tag) {
            return None;
        }
        let (&len_byte, rest) = rest.split_first()?;
        let (len, rest) = match len_byte {
            0x00..=0x7F => (usize::from(len_byte), rest),
            0x81 => {
                let (&b0, rest) = rest.split_first()?;
                (usize::from(b0), rest)
            }
            0x82 => {
                if rest.len() < 2 {
                    return None;
                }
                let (len_bytes, rest) = rest.split_at(2);
                ((usize::from(len_bytes[0]) << 8) | usize::from(len_bytes[1]), rest)
            }
            // Indefinite or longer forms never occur for a CN in DER.
            _ => return None,
        };
        rest.get(..len)
    }

    der.windows(CN_OID_TLV.len())
        .enumerate()
        .filter(|(_, window)| window[..] == CN_OID_TLV[..])
        .filter_map(|(start, _)| string_value(&der[start + CN_OID_TLV.len()..]))
        .filter_map(|value| std::str::from_utf8(value).ok())
        .last()
        .map(str::to_owned)
}
667
668/// Read commands from stdin and dispatch to connected clients.
669async fn stdin_command_loop(state: Arc<ServerState>) -> Result<()> {
670    let stdin = tokio::io::stdin();
671    let reader = BufReader::new(stdin);
672    let mut lines = reader.lines();
673
674    while let Ok(Some(line)) = lines.next_line().await {
675        let line = line.trim().to_string();
676        if line.is_empty() {
677            continue;
678        }
679
680        let parts: Vec<&str> = line.split_whitespace().collect();
681        match parts.first().copied() {
682            Some("list") => {
683                let clients = state.clients.read().await;
684                if clients.is_empty() {
685                    info!("No connected clients");
686                } else {
687                    for (fp, handle) in clients.iter() {
688                        info!("  - {} [{}]", handle.cn, fp);
689                    }
690                }
691            }
692            Some("socks") if parts.len() == 3 => {
693                let cn = parts[1];
694                let port: u16 = match parts[2].parse() {
695                    Ok(p) => p,
696                    Err(_) => {
697                        warn!("Invalid port: {}", parts[2]);
698                        continue;
699                    }
700                };
701                let tunnel_id = state.alloc_tunnel_id();
702                // Do NOT authorize yet — wait for client Ok response.
703                // Authorization happens in handle_response on success.
704                {
705                    let clients = state.clients.read().await;
706                    if let Some(client) = find_client_in_map(&clients, cn) {
707                        // Track as pending so handle_response knows to authorize on Ok
708                        client
709                            .pending_socks
710                            .write()
711                            .await
712                            .insert(tunnel_id);
713                    }
714                }
715                send_command_to_client(
716                    &state,
717                    cn,
718                    WsTextMessage::Command(Command::Socks { tunnel_id, port }),
719                )
720                .await;
721            }
722            Some("reverse") if parts.len() == 4 => {
723                let cn = parts[1];
724                let remote_port: u16 = match parts[2].parse() {
725                    Ok(p) => p,
726                    Err(_) => {
727                        warn!("Invalid port: {}", parts[2]);
728                        continue;
729                    }
730                };
731                let local_target = parts[3].to_string();
732                let tunnel_id = state.alloc_tunnel_id();
733
734                // Store pending reverse tunnel so we start the listener on Ok response
735                // We need access to the per-client pending_reverse map, but it's inside handle_client.
736                // Instead, we use a shared state approach: store in ServerState.
737                // For simplicity, we broadcast the reverse command and track the tunnel_id
738                // globally. The handle_response for the specific client will pick it up.
739                send_command_to_client_with_reverse(
740                    &state,
741                    cn,
742                    tunnel_id,
743                    remote_port,
744                    local_target,
745                )
746                .await;
747            }
748            Some("stop") if parts.len() == 3 => {
749                let cn = parts[1];
750                let tunnel_id: u32 = match parts[2].parse() {
751                    Ok(id) => id,
752                    Err(_) => {
753                        warn!("Invalid tunnel ID: {}", parts[2]);
754                        continue;
755                    }
756                };
757                // Fully revoke: clear all pending + active state for this tunnel
758                {
759                    let clients = state.clients.read().await;
760                    if let Some(client) = find_client_in_map(&clients, cn) {
761                        // Clear pending state so late acks cannot resurrect the tunnel
762                        client.pending_socks.write().await.remove(&tunnel_id);
763                        client.pending_reverse.write().await.remove(&tunnel_id);
764                        client.authorized_tunnels.write().await.remove(&tunnel_id);
765                        if let Some(handle) =
766                            client.reverse_listeners.write().await.remove(&tunnel_id)
767                        {
768                            handle.abort();
769                            info!("Aborted reverse listener for tunnel {tunnel_id}");
770                        }
771                        let closed = client.channels.close_tunnel(tunnel_id).await;
772                        if !closed.is_empty() {
773                            info!("Closed {} server-side channels for tunnel {tunnel_id}", closed.len());
774                        }
775                    }
776                }
777                send_command_to_client(
778                    &state,
779                    cn,
780                    WsTextMessage::Command(Command::StopTunnel { tunnel_id }),
781                )
782                .await;
783            }
784            Some("help") | Some("?") => {
785                info!("Commands:");
786                info!("  list                                              - List connected clients");
787                info!("  socks <client_cn> <port>                          - Start SOCKS5 on client");
788                info!("  reverse <client_cn> <remote_port> <local_target>  - Reverse tunnel");
789                info!("  stop <client_cn> <tunnel_id>                      - Stop a tunnel");
790            }
791            _ => {
792                warn!("Unknown command. Type 'help' for usage.");
793            }
794        }
795    }
796
797    Ok(())
798}
799
800/// Find a client by CN or fingerprint and send a command.
801async fn send_command_to_client(state: &ServerState, id: &str, msg: WsTextMessage) {
802    let ws_tx = {
803        let clients = state.clients.read().await;
804        match find_client_in_map(&clients, id) {
805            Some(client) => client.ws_tx.clone(),
806            None => return,
807        }
808    };
809    if let Ok(json) = serde_json::to_string(&msg) {
810        if ws_tx.send(Message::Text(json)).await.is_err() {
811            warn!("Failed to send to {id}");
812        } else {
813            info!("Sent command to {id}");
814        }
815    }
816}
817
818/// Send a ReverseTunnel command and register the pending tunnel_id -> remote_port.
819async fn send_command_to_client_with_reverse(
820    state: &ServerState,
821    id: &str,
822    tunnel_id: u32,
823    remote_port: u16,
824    local_target: String,
825) {
826    let msg = WsTextMessage::Command(Command::ReverseTunnel {
827        tunnel_id,
828        remote_port,
829        local_target,
830    });
831    let (ws_tx, pending_reverse) = {
832        let clients = state.clients.read().await;
833        match find_client_in_map(&clients, id) {
834            Some(client) => (client.ws_tx.clone(), client.pending_reverse.clone()),
835            None => return,
836        }
837    };
838    pending_reverse.write().await.insert(tunnel_id, remote_port);
839    if let Ok(json) = serde_json::to_string(&msg) {
840        if ws_tx.send(Message::Text(json)).await.is_err() {
841            warn!("Failed to send to {id}");
842            pending_reverse.write().await.remove(&tunnel_id);
843        } else {
844            info!("Sent reverse tunnel command to {id} (tunnel {tunnel_id}, port {remote_port})");
845        }
846    }
847}
848
849/// Look up a client by fingerprint (prefix) or CN. Returns None on ambiguity.
850fn find_client_in_map<'a>(
851    clients: &'a HashMap<String, ClientHandle>,
852    id: &str,
853) -> Option<&'a ClientHandle> {
854    // Exact fingerprint
855    if let Some(handle) = clients.get(id) {
856        return Some(handle);
857    }
858    // Fingerprint prefix
859    let fp_matches: Vec<_> = clients
860        .iter()
861        .filter(|(fp, _)| fp.starts_with(id))
862        .collect();
863    if fp_matches.len() == 1 {
864        return Some(fp_matches[0].1);
865    }
866    // CN match (reject ambiguous)
867    let cn_matches: Vec<_> = clients.values().filter(|h| h.cn == id).collect();
868    match cn_matches.len() {
869        1 => Some(cn_matches[0]),
870        0 => {
871            warn!("Client not found: {id}");
872            None
873        }
874        n => {
875            warn!("Ambiguous CN '{id}' matches {n} clients. Use fingerprint instead.");
876            None
877        }
878    }
879}