Skip to main content

synapse_pingora/tunnel/
shell.rs

1//! Remote shell handler for the tunnel client.
2
3use base64::engine::general_purpose::STANDARD;
4use base64::Engine;
5use dashmap::DashMap;
6use portable_pty::{native_pty_system, CommandBuilder, PtySize};
7use std::io::{Read, Write};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10use tokio::sync::broadcast;
11use tracing::{debug, warn};
12
13use super::client::TunnelClientHandle;
14use super::types::LegacyTunnelMessage;
15use crate::metrics::MetricsRegistry;
16
/// Default PTY width, in columns, when a shell message does not specify one.
const DEFAULT_COLS: u16 = 80;
/// Default PTY height, in rows, when a shell message does not specify one.
const DEFAULT_ROWS: u16 = 24;

/// Dangerous terminal escape sequence prefixes that can be used for terminal injection attacks.
/// These allow arbitrary command injection, clipboard access, or window manipulation.
///
/// Every entry is exactly two bytes: ESC (0x1b) followed by the string introducer.
const DANGEROUS_ESCAPE_PREFIXES: &[&[u8]] = &[
    b"\x1b]", // OSC - Operating System Command (clipboard, window title injection)
    b"\x1bP", // DCS - Device Control String (Sixel, ReGIS, DECRQSS attacks)
    b"\x1b_", // APC - Application Program Command
    b"\x1b^", // PM - Privacy Message
    b"\x1bX", // SOS - Start of String
];

/// Filter dangerous terminal escape sequences from input data.
///
/// Normal CSI sequences (`ESC [`) used for cursor control and colors pass through
/// untouched. OSC, DCS, APC, PM and SOS control strings (see
/// [`DANGEROUS_ESCAPE_PREFIXES`]) are removed along with their body, up to and
/// including the terminator: ST (`ESC \`) or BEL. An unterminated control string
/// is dropped through the end of `input`. Inside a dropped string, an ESC that is
/// not followed by `\` is discarded and scanning continues.
///
/// The introducer test is made against the bytes already *emitted* rather than the
/// raw input. As a result, removing one control string can never splice an earlier
/// ESC onto a later introducer. For example, `ESC` + `ESC ] x BEL` + `] 0;evil BEL`
/// would otherwise collapse into a live `ESC ] 0;evil BEL`.
///
/// NOTE(review): this function keeps no state between calls. A sequence split
/// across two calls (ESC at the end of one buffer, the introducer at the start of
/// the next) is therefore not detected. Closing that gap needs per-session state.
fn filter_dangerous_escapes(input: &[u8]) -> Vec<u8> {
    const ESC: u8 = 0x1b;
    const BEL: u8 = 0x07;

    // Scanner state: copying ordinary bytes (Ground), discarding a dangerous
    // control string (InString), or discarding one just after an ESC, where a
    // following `\` completes the ST terminator (InStringEsc).
    enum State {
        Ground,
        InString,
        InStringEsc,
    }

    let is_dangerous_introducer = |byte: u8| {
        DANGEROUS_ESCAPE_PREFIXES
            .iter()
            .any(|prefix| prefix.len() == 2 && prefix[0] == ESC && prefix[1] == byte)
    };

    let mut output: Vec<u8> = Vec::with_capacity(input.len());
    let mut state = State::Ground;

    for &byte in input {
        state = match state {
            State::Ground => {
                if output.last() == Some(&ESC) && is_dangerous_introducer(byte) {
                    // Retract the ESC that was already emitted and start discarding.
                    output.pop();
                    State::InString
                } else {
                    output.push(byte);
                    State::Ground
                }
            }
            State::InStringEsc if byte == b'\\' => State::Ground,
            // An ESC not followed by `\` is dropped; the current byte is then
            // examined as ordinary string content.
            State::InString | State::InStringEsc => match byte {
                BEL => State::Ground,
                ESC => State::InStringEsc,
                _ => State::InString,
            },
        };
    }

    output
}
81
82struct ShellSession {
83    writer: Arc<Mutex<Box<dyn Write + Send>>>,
84    master: Arc<Mutex<Box<dyn portable_pty::MasterPty + Send>>>,
85    child: Arc<Mutex<Box<dyn portable_pty::Child + Send>>>,
86    shell: String,
87}
88
89/// Remote shell handler for tunnel legacy messages.
90pub struct TunnelShellService {
91    handle: TunnelClientHandle,
92    sessions: Arc<DashMap<String, ShellSession>>,
93    default_shell: String,
94    enabled: bool,
95    metrics: Arc<MetricsRegistry>,
96}
97
98impl TunnelShellService {
99    /// Create a new shell service with the given tunnel handle.
100    pub fn new(handle: TunnelClientHandle, enabled: bool, metrics: Arc<MetricsRegistry>) -> Self {
101        let default_shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string());
102        Self {
103            handle,
104            sessions: Arc::new(DashMap::new()),
105            default_shell,
106            enabled,
107            metrics,
108        }
109    }
110
111    /// Start the shell service (listens for legacy tunnel messages).
112    pub async fn run(self, mut shutdown_rx: broadcast::Receiver<()>) {
113        if !self.enabled {
114            warn!("Remote shell service is disabled by configuration");
115            return;
116        }
117
118        let mut rx = self.handle.subscribe_legacy();
119        loop {
120            tokio::select! {
121                message = rx.recv() => {
122                    match message {
123                        Ok(message) => {
124                            let started = Instant::now();
125                            self.handle_message(message).await;
126                            self.metrics
127                                .tunnel_metrics()
128                                .record_handler_latency_ms(
129                                    super::types::TunnelChannel::Shell,
130                                    started.elapsed().as_millis() as u64,
131                                );
132                        }
133                        Err(broadcast::error::RecvError::Lagged(count)) => {
134                            warn!("Shell service lagged by {} messages", count);
135                            continue;
136                        }
137                        Err(broadcast::error::RecvError::Closed) => {
138                            warn!("Shell service channel closed");
139                            break;
140                        }
141                    }
142                }
143                _ = shutdown_rx.recv() => {
144                    debug!("Shell service shutdown signal received");
145                    break;
146                }
147            }
148        }
149
150        // Cleanup all active sessions
151        let session_ids: Vec<String> = self.sessions.iter().map(|e| e.key().clone()).collect();
152        for id in session_ids {
153            self.end_session(&id, "service shutdown");
154        }
155    }
156
157    async fn handle_message(&self, message: LegacyTunnelMessage) {
158        if !self.enabled {
159            if let Some(session_id) = message.session_id {
160                self.send_shell_error(&session_id, "remote shell is disabled on this sensor");
161            }
162            return;
163        }
164
165        match message.message_type.as_str() {
166            "shell-data" => {
167                self.handle_shell_data(message).await;
168            }
169            "shell-resize" => {
170                self.handle_shell_resize(message).await;
171            }
172            _ => {}
173        }
174    }
175
176    async fn handle_shell_data(&self, message: LegacyTunnelMessage) {
177        let Some(session_id) = message.session_id.clone() else {
178            warn!("shell-data received without sessionId");
179            return;
180        };
181
182        let payload = message.payload;
183
184        if let Some(action) = payload.get("action").and_then(|value| value.as_str()) {
185            match action {
186                "start" => {
187                    let cols = payload
188                        .get("cols")
189                        .and_then(|value| value.as_u64())
190                        .map(|value| value as u16)
191                        .unwrap_or(DEFAULT_COLS);
192                    let rows = payload
193                        .get("rows")
194                        .and_then(|value| value.as_u64())
195                        .map(|value| value as u16)
196                        .unwrap_or(DEFAULT_ROWS);
197                    if let Err(err) = self.start_session(&session_id, cols, rows) {
198                        self.send_shell_error(&session_id, err);
199                    }
200                    return;
201                }
202                "end" => {
203                    self.end_session(&session_id, "session ended");
204                    return;
205                }
206                _ => {}
207            }
208        }
209
210        if let Some(data) = payload.get("data").and_then(|value| value.as_str()) {
211            if let Err(err) = self.write_input(&session_id, data) {
212                self.send_shell_error(&session_id, err);
213            }
214        }
215    }
216
217    async fn handle_shell_resize(&self, message: LegacyTunnelMessage) {
218        let Some(session_id) = message.session_id.clone() else {
219            warn!("shell-resize received without sessionId");
220            return;
221        };
222        let cols = message
223            .payload
224            .get("cols")
225            .and_then(|value| value.as_u64())
226            .map(|value| value as u16)
227            .unwrap_or(DEFAULT_COLS);
228        let rows = message
229            .payload
230            .get("rows")
231            .and_then(|value| value.as_u64())
232            .map(|value| value as u16)
233            .unwrap_or(DEFAULT_ROWS);
234
235        if let Some(session) = self.sessions.get(&session_id) {
236            if let Err(err) = session
237                .master
238                .lock()
239                .unwrap_or_else(|p| p.into_inner())
240                .resize(PtySize {
241                    rows,
242                    cols,
243                    pixel_width: 0,
244                    pixel_height: 0,
245                })
246            {
247                self.send_shell_error(&session_id, format!("resize failed: {}", err));
248            }
249        } else {
250            warn!("shell-resize for unknown session {}", session_id);
251        }
252    }
253
254    fn start_session(&self, session_id: &str, cols: u16, rows: u16) -> Result<(), String> {
255        if self.sessions.contains_key(session_id) {
256            return Err("shell session already active".to_string());
257        }
258
259        let pty_system = native_pty_system();
260        let pair = pty_system
261            .openpty(PtySize {
262                rows,
263                cols,
264                pixel_width: 0,
265                pixel_height: 0,
266            })
267            .map_err(|err| format!("failed to open pty: {}", err))?;
268
269        let shell = self.default_shell.clone();
270        let mut cmd = CommandBuilder::new(&shell);
271        cmd.env("TERM", "xterm-256color");
272        let child = pair
273            .slave
274            .spawn_command(cmd)
275            .map_err(|err| format!("failed to spawn shell: {}", err))?;
276
277        let reader = pair
278            .master
279            .try_clone_reader()
280            .map_err(|err| format!("failed to clone pty reader: {}", err))?;
281        let writer = pair
282            .master
283            .take_writer()
284            .map_err(|err| format!("failed to get pty writer: {}", err))?;
285
286        let session = ShellSession {
287            writer: Arc::new(Mutex::new(writer)),
288            master: Arc::new(Mutex::new(pair.master)),
289            child: Arc::new(Mutex::new(child)),
290            shell: shell.clone(),
291        };
292
293        self.sessions.insert(session_id.to_string(), session);
294        self.send_shell_ready(session_id, &shell);
295
296        self.spawn_reader(session_id.to_string(), reader);
297        self.spawn_waiter(session_id.to_string());
298
299        Ok(())
300    }
301
302    fn spawn_reader(&self, session_id: String, mut reader: Box<dyn Read + Send>) {
303        let handle = self.handle.clone();
304        let sessions = Arc::clone(&self.sessions);
305
306        tokio::task::spawn_blocking(move || {
307            let mut buffer = [0u8; 8192];
308            loop {
309                // Check if session still exists
310                if !sessions.contains_key(&session_id) {
311                    break;
312                }
313
314                match reader.read(&mut buffer) {
315                    Ok(0) => {
316                        debug!("shell output closed for {}", session_id);
317                        break;
318                    }
319                    Ok(bytes_read) => {
320                        let encoded = STANDARD.encode(&buffer[..bytes_read]);
321                        let message = serde_json::json!({
322                            "type": "shell-data",
323                            "sessionId": session_id,
324                            "payload": { "data": encoded },
325                            "timestamp": chrono::Utc::now().to_rfc3339(),
326                        });
327                        // Use sync send for backpressure against PTY reader
328                        let _ = handle.send_json_sync(message);
329                    }
330                    Err(err) => {
331                        // Avoid logging error if it's just the session closing
332                        if sessions.contains_key(&session_id) {
333                            let message = serde_json::json!({
334                                "type": "shell-error",
335                                "sessionId": session_id,
336                                "payload": { "error": format!("shell output error: {}", err) },
337                                "timestamp": chrono::Utc::now().to_rfc3339(),
338                            });
339                            let _ = handle.send_json_sync(message);
340                        }
341                        break;
342                    }
343                }
344            }
345        });
346    }
347
348    fn spawn_waiter(&self, session_id: String) {
349        let sessions = Arc::clone(&self.sessions);
350        let handle = self.handle.clone();
351
352        tokio::task::spawn(async move {
353            loop {
354                let status = {
355                    if let Some(entry) = sessions.get(&session_id) {
356                        let mut child = entry.child.lock().unwrap_or_else(|p| p.into_inner());
357                        match child.try_wait() {
358                            Ok(Some(status)) => Some(Ok(status)),
359                            Ok(None) => None,
360                            Err(err) => Some(Err(err)),
361                        }
362                    } else {
363                        break;
364                    }
365                };
366
367                match status {
368                    Some(Ok(status)) => {
369                        let exit_code = status.exit_code();
370                        let message = serde_json::json!({
371                            "type": "shell-exit",
372                            "sessionId": session_id,
373                            "payload": { "code": exit_code },
374                            "timestamp": chrono::Utc::now().to_rfc3339(),
375                        });
376                        let _ = handle.send_json(message).await;
377                        sessions.remove(&session_id);
378                        break;
379                    }
380                    Some(Err(err)) => {
381                        let message = serde_json::json!({
382                            "type": "shell-error",
383                            "sessionId": session_id,
384                            "payload": { "error": format!("shell wait error: {}", err) },
385                            "timestamp": chrono::Utc::now().to_rfc3339(),
386                        });
387                        let _ = handle.send_json(message).await;
388                        sessions.remove(&session_id);
389                        break;
390                    }
391                    None => {
392                        // Still running, wait a bit
393                        tokio::time::sleep(Duration::from_millis(250)).await;
394                    }
395                }
396            }
397        });
398    }
399
400    fn write_input(&self, session_id: &str, data: &str) -> Result<(), String> {
401        let decoded = STANDARD
402            .decode(data.as_bytes())
403            .map_err(|err| format!("invalid base64 input: {}", err))?;
404
405        // Filter dangerous terminal escape sequences before writing to PTY
406        let sanitized = filter_dangerous_escapes(&decoded);
407
408        if let Some(session) = self.sessions.get(session_id) {
409            let mut writer = session.writer.lock().unwrap_or_else(|p| p.into_inner());
410            writer
411                .write_all(&sanitized)
412                .map_err(|err| format!("failed to write to pty: {}", err))?;
413            writer
414                .flush()
415                .map_err(|err| format!("failed to flush pty: {}", err))?;
416            Ok(())
417        } else {
418            Err("shell session not found".to_string())
419        }
420    }
421
422    fn end_session(&self, session_id: &str, reason: &str) {
423        if let Some((_, session)) = self.sessions.remove(session_id) {
424            let mut child = session.child.lock().unwrap_or_else(|p| p.into_inner());
425            if let Err(err) = child.kill() {
426                warn!("Failed to kill shell session {}: {}", session_id, err);
427            }
428        }
429
430        let message = serde_json::json!({
431            "type": "shell-exit",
432            "sessionId": session_id,
433            "payload": { "code": 0, "reason": reason },
434            "timestamp": chrono::Utc::now().to_rfc3339(),
435        });
436        let _ = self.handle.send_json_blocking(message);
437    }
438
439    fn send_shell_ready(&self, session_id: &str, shell: &str) {
440        let message = serde_json::json!({
441            "type": "shell-ready",
442            "sessionId": session_id,
443            "payload": { "shell": shell },
444            "timestamp": chrono::Utc::now().to_rfc3339(),
445        });
446        let _ = self.handle.send_json_blocking(message);
447    }
448
449    fn send_shell_error(&self, session_id: &str, error_message: impl Into<String>) {
450        let message = serde_json::json!({
451            "type": "shell-error",
452            "sessionId": session_id,
453            "payload": { "error": error_message.into() },
454            "timestamp": chrono::Utc::now().to_rfc3339(),
455        });
456        let _ = self.handle.send_json_blocking(message);
457    }
458}