Skip to main content

tldr_cli/commands/daemon/
ipc.rs

1//! Cross-platform IPC layer for daemon communication
2//!
3//! This module provides socket-based IPC for the TLDR daemon using:
4//! - Unix domain sockets on Unix systems (Linux, macOS)
5//! - TCP localhost connections on Windows
6//!
7//! # Security Mitigations
8//!
9//! - TIGER-P3-01: Socket path validation (no temp dir escapes)
10//! - TIGER-P3-03: Message size limits (10MB max) to prevent OOM
11//! - TIGER-P3-04: Symlink rejection at socket path
12//! - Unix sockets created with 0600 permissions (owner-only)
13//!
14//! # Protocol
15//!
16//! Newline-delimited JSON:
17//! - Client sends: `{"cmd": "...", ...}\n`
18//! - Server responds: `{...}\n`
19
20use std::io;
21use std::path::{Path, PathBuf};
22
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
24
25use crate::commands::daemon::error::{DaemonError, DaemonResult};
26use crate::commands::daemon::pid::compute_hash;
27use crate::commands::daemon::types::{DaemonCommand, DaemonResponse};
28
29// =============================================================================
30// Constants
31// =============================================================================
32
/// Upper bound, in bytes, on a single newline-delimited message (10 MiB).
///
/// Guards against a peer forcing unbounded allocation with an oversized
/// message (TIGER-P3-03).
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Seconds to wait for a connection to the daemon to be established.
pub const CONNECTION_TIMEOUT_SECS: u64 = 5;

/// Seconds to wait for a complete message to arrive on an open connection.
pub const READ_TIMEOUT_SECS: u64 = 30;
43
44// =============================================================================
45// Path/Port Computation
46// =============================================================================
47
48/// Compute the socket path for a project (Unix).
49///
50/// Path format: `{temp_dir}/tldr-{hash}.sock`
51/// Uses same hash as PID file for consistency.
52///
53/// # Security (TIGER-P3-01)
54///
55/// The path is validated to ensure it stays within the temp directory
56/// and doesn't escape via symlinks or path traversal.
57#[cfg(unix)]
58pub fn compute_socket_path(project: &Path) -> PathBuf {
59    let hash = compute_hash(project);
60    let tmp_dir = std::env::temp_dir();
61    tmp_dir.join(format!("tldr-{}.sock", hash))
62}
63
64/// Compute the TCP port for a project (Windows).
65///
66/// Port range: 49152-59151 (dynamic/private port range)
67/// Uses hash to deterministically map project to port.
68#[cfg(windows)]
69pub fn compute_tcp_port(project: &Path) -> u16 {
70    let hash = compute_hash(project);
71    let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
72    49152 + (hash_int % 10000) as u16
73}
74
75// For cross-platform code that needs socket path on all platforms
76#[cfg(not(unix))]
77pub fn compute_socket_path(project: &Path) -> PathBuf {
78    // On Windows, return a path that won't be used (TCP is used instead)
79    let hash = compute_hash(project);
80    let tmp_dir = std::env::temp_dir();
81    tmp_dir.join(format!("tldr-{}.sock", hash))
82}
83
84#[cfg(not(windows))]
85pub fn compute_tcp_port(project: &Path) -> u16 {
86    // On Unix, return a port that won't be used (Unix socket is used instead)
87    let hash = compute_hash(project);
88    let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
89    49152 + (hash_int % 10000) as u16
90}
91
92// =============================================================================
93// Security Validation
94// =============================================================================
95
96/// Validate that a socket path is safe to use.
97///
98/// # Security Checks (TIGER-P3-01, TIGER-P3-04)
99///
100/// 1. Path must be within the system temp directory
101/// 2. Path must not contain symlinks
102/// 3. Path must not escape temp dir via `..` traversal
103pub fn validate_socket_path(socket_path: &Path) -> DaemonResult<()> {
104    let tmp_dir = std::env::temp_dir();
105
106    // Canonicalize temp dir (resolve symlinks in temp dir itself)
107    let canonical_tmp = tmp_dir.canonicalize().unwrap_or(tmp_dir);
108
109    // Check that socket path starts with temp dir
110    // We use the parent directory since the socket file doesn't exist yet
111    let socket_parent = socket_path.parent().unwrap_or(socket_path);
112
113    // Canonicalize parent if it exists
114    let canonical_parent = socket_parent
115        .canonicalize()
116        .unwrap_or_else(|_| socket_parent.to_path_buf());
117
118    if !canonical_parent.starts_with(&canonical_tmp) {
119        return Err(DaemonError::PermissionDenied {
120            path: socket_path.to_path_buf(),
121        });
122    }
123
124    // Check for path traversal attempts in the filename
125    if let Some(filename) = socket_path.file_name() {
126        let filename_str = filename.to_string_lossy();
127        if filename_str.contains("..") || filename_str.contains('/') || filename_str.contains('\\')
128        {
129            return Err(DaemonError::PermissionDenied {
130                path: socket_path.to_path_buf(),
131            });
132        }
133    }
134
135    Ok(())
136}
137
138/// Check if a path is a symlink.
139///
140/// # Security (TIGER-P3-04)
141///
142/// Rejects symlinks at socket path to prevent symlink attacks.
143#[cfg(unix)]
144pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
145    if let Ok(metadata) = std::fs::symlink_metadata(path) {
146        if metadata.file_type().is_symlink() {
147            return Err(DaemonError::PermissionDenied {
148                path: path.to_path_buf(),
149            });
150        }
151    }
152    Ok(())
153}
154
155#[cfg(not(unix))]
156pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
157    // Windows symlink check
158    if let Ok(metadata) = std::fs::symlink_metadata(path) {
159        if metadata.file_type().is_symlink() {
160            return Err(DaemonError::PermissionDenied {
161                path: path.to_path_buf(),
162            });
163        }
164    }
165    Ok(())
166}
167
168// =============================================================================
169// IpcListener - Server Side
170// =============================================================================
171
172/// Platform-agnostic IPC listener
173pub struct IpcListener {
174    #[cfg(unix)]
175    inner: tokio::net::UnixListener,
176    #[cfg(windows)]
177    inner: tokio::net::TcpListener,
178    /// Path to socket file (for cleanup)
179    #[allow(dead_code)]
180    socket_path: PathBuf,
181}
182
183impl IpcListener {
184    /// Bind a new IPC listener for the given project.
185    ///
186    /// # Unix
187    /// Creates a Unix domain socket at `/tmp/tldr-{hash}.sock`
188    /// with permissions 0600 (owner-only).
189    ///
190    /// # Windows
191    /// Binds to TCP localhost on a deterministic port.
192    ///
193    /// # Security
194    /// - TIGER-P3-01: Validates socket path stays in temp dir
195    /// - TIGER-P3-04: Rejects symlinks at socket path
196    pub async fn bind(project: &Path) -> DaemonResult<Self> {
197        #[cfg(unix)]
198        {
199            Self::bind_unix(project).await
200        }
201        #[cfg(windows)]
202        {
203            Self::bind_tcp(project).await
204        }
205    }
206
207    #[cfg(unix)]
208    async fn bind_unix(project: &Path) -> DaemonResult<Self> {
209        use std::os::unix::fs::PermissionsExt;
210
211        let socket_path = compute_socket_path(project);
212
213        // Validate socket path security
214        validate_socket_path(&socket_path)?;
215
216        // Check for existing symlink (TIGER-P3-04)
217        check_not_symlink(&socket_path)?;
218
219        // Remove existing socket if present (could be stale)
220        if socket_path.exists() {
221            // Double-check it's not a symlink before removing
222            check_not_symlink(&socket_path)?;
223            std::fs::remove_file(&socket_path).map_err(DaemonError::SocketBindFailed)?;
224        }
225
226        // Bind the Unix socket
227        let listener = tokio::net::UnixListener::bind(&socket_path)
228            .map_err(DaemonError::SocketBindFailed)?;
229
230        // Set socket permissions to 0600 (owner-only) - TIGER-P3-01
231        let permissions = std::fs::Permissions::from_mode(0o600);
232        std::fs::set_permissions(&socket_path, permissions)
233            .map_err(DaemonError::SocketBindFailed)?;
234
235        Ok(Self {
236            inner: listener,
237            socket_path,
238        })
239    }
240
241    #[cfg(windows)]
242    async fn bind_tcp(project: &Path) -> DaemonResult<Self> {
243        let socket_path = compute_socket_path(project); // For reference only
244        let port = compute_tcp_port(project);
245        let addr = format!("127.0.0.1:{}", port);
246
247        let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
248            if e.kind() == io::ErrorKind::AddrInUse {
249                DaemonError::AddressInUse { addr }
250            } else {
251                DaemonError::SocketBindFailed(e)
252            }
253        })?;
254
255        Ok(Self {
256            inner: listener,
257            socket_path,
258        })
259    }
260
261    /// Accept a new connection.
262    ///
263    /// Returns an `IpcStream` that can be used for bidirectional communication.
264    pub async fn accept(&self) -> DaemonResult<IpcStream> {
265        #[cfg(unix)]
266        {
267            let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
268            Ok(IpcStream {
269                inner: IpcStreamInner::Unix(stream),
270            })
271        }
272        #[cfg(windows)]
273        {
274            let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
275            Ok(IpcStream {
276                inner: IpcStreamInner::Tcp(stream),
277            })
278        }
279    }
280}
281
282// =============================================================================
283// IpcStream - Bidirectional Communication
284// =============================================================================
285
286/// Inner stream type for platform abstraction
287enum IpcStreamInner {
288    #[cfg(unix)]
289    Unix(tokio::net::UnixStream),
290    #[cfg(windows)]
291    Tcp(tokio::net::TcpStream),
292    // Allow both variants on all platforms for testing
293    #[cfg(all(not(unix), not(windows)))]
294    Dummy,
295}
296
297/// Platform-agnostic IPC stream for bidirectional communication.
298pub struct IpcStream {
299    inner: IpcStreamInner,
300}
301
302impl IpcStream {
303    /// Connect to a daemon for the given project.
304    ///
305    /// # Unix
306    /// Connects to Unix domain socket at `/tmp/tldr-{hash}.sock`
307    ///
308    /// # Windows
309    /// Connects to TCP localhost on a deterministic port.
310    pub async fn connect(project: &Path) -> DaemonResult<Self> {
311        #[cfg(unix)]
312        {
313            Self::connect_unix(project).await
314        }
315        #[cfg(windows)]
316        {
317            Self::connect_tcp(project).await
318        }
319    }
320
321    #[cfg(unix)]
322    async fn connect_unix(project: &Path) -> DaemonResult<Self> {
323        let socket_path = compute_socket_path(project);
324
325        // Validate socket path security
326        validate_socket_path(&socket_path)?;
327
328        // Check socket exists
329        if !socket_path.exists() {
330            return Err(DaemonError::NotRunning);
331        }
332
333        // Check for symlink attack (TIGER-P3-04)
334        check_not_symlink(&socket_path)?;
335
336        // Connect with timeout
337        let connect_future = tokio::net::UnixStream::connect(&socket_path);
338        let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
339
340        match tokio::time::timeout(timeout, connect_future).await {
341            Ok(Ok(stream)) => Ok(Self {
342                inner: IpcStreamInner::Unix(stream),
343            }),
344            Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
345                Err(DaemonError::ConnectionRefused)
346            }
347            Ok(Err(e)) if e.kind() == io::ErrorKind::NotFound => Err(DaemonError::NotRunning),
348            Ok(Err(e)) => Err(DaemonError::Io(e)),
349            Err(_) => Err(DaemonError::ConnectionTimeout {
350                timeout_secs: CONNECTION_TIMEOUT_SECS,
351            }),
352        }
353    }
354
355    #[cfg(windows)]
356    async fn connect_tcp(project: &Path) -> DaemonResult<Self> {
357        let port = compute_tcp_port(project);
358        let addr = format!("127.0.0.1:{}", port);
359
360        // Connect with timeout
361        let connect_future = tokio::net::TcpStream::connect(&addr);
362        let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
363
364        match tokio::time::timeout(timeout, connect_future).await {
365            Ok(Ok(stream)) => Ok(Self {
366                inner: IpcStreamInner::Tcp(stream),
367            }),
368            Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
369                Err(DaemonError::ConnectionRefused)
370            }
371            Ok(Err(e)) => Err(DaemonError::Io(e)),
372            Err(_) => Err(DaemonError::ConnectionTimeout {
373                timeout_secs: CONNECTION_TIMEOUT_SECS,
374            }),
375        }
376    }
377
378    /// Send a command to the daemon.
379    ///
380    /// Serializes the command to JSON and sends with a newline delimiter.
381    pub async fn send_command(&mut self, cmd: &DaemonCommand) -> DaemonResult<()> {
382        let json = serde_json::to_string(cmd)?;
383        self.send_raw(&json).await
384    }
385
386    /// Send a raw JSON string to the daemon.
387    ///
388    /// Adds newline delimiter automatically.
389    pub async fn send_raw(&mut self, json: &str) -> DaemonResult<()> {
390        // Check message size (TIGER-P3-03)
391        if json.len() > MAX_MESSAGE_SIZE {
392            return Err(DaemonError::InvalidMessage(format!(
393                "message too large: {} bytes (max {})",
394                json.len(),
395                MAX_MESSAGE_SIZE
396            )));
397        }
398
399        let mut message = json.to_string();
400        message.push('\n');
401
402        match &mut self.inner {
403            #[cfg(unix)]
404            IpcStreamInner::Unix(stream) => {
405                stream.write_all(message.as_bytes()).await?;
406                stream.flush().await?;
407            }
408            #[cfg(windows)]
409            IpcStreamInner::Tcp(stream) => {
410                stream.write_all(message.as_bytes()).await?;
411                stream.flush().await?;
412            }
413            #[cfg(all(not(unix), not(windows)))]
414            IpcStreamInner::Dummy => {}
415        }
416
417        Ok(())
418    }
419
420    /// Receive a response from the daemon.
421    ///
422    /// Reads a newline-delimited JSON response and deserializes it.
423    pub async fn recv_response(&mut self) -> DaemonResult<DaemonResponse> {
424        let json = self.recv_raw().await?;
425        let response: DaemonResponse = serde_json::from_str(&json)?;
426        Ok(response)
427    }
428
429    /// Receive a raw JSON string from the daemon.
430    ///
431    /// Reads until newline delimiter.
432    pub async fn recv_raw(&mut self) -> DaemonResult<String> {
433        let timeout = tokio::time::Duration::from_secs(READ_TIMEOUT_SECS);
434
435        match &mut self.inner {
436            #[cfg(unix)]
437            IpcStreamInner::Unix(stream) => {
438                let mut reader = BufReader::new(stream);
439                let mut line = String::new();
440
441                let read_future = reader.read_line(&mut line);
442
443                match tokio::time::timeout(timeout, read_future).await {
444                    Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), // EOF
445                    Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
446                        "response too large: {} bytes (max {})",
447                        n, MAX_MESSAGE_SIZE
448                    ))),
449                    Ok(Ok(_)) => Ok(line.trim_end().to_string()),
450                    Ok(Err(e)) => Err(DaemonError::Io(e)),
451                    Err(_) => Err(DaemonError::ConnectionTimeout {
452                        timeout_secs: READ_TIMEOUT_SECS,
453                    }),
454                }
455            }
456            #[cfg(windows)]
457            IpcStreamInner::Tcp(stream) => {
458                let mut reader = BufReader::new(stream);
459                let mut line = String::new();
460
461                let read_future = reader.read_line(&mut line);
462
463                match tokio::time::timeout(timeout, read_future).await {
464                    Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), // EOF
465                    Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
466                        "response too large: {} bytes (max {})",
467                        n, MAX_MESSAGE_SIZE
468                    ))),
469                    Ok(Ok(_)) => Ok(line.trim_end().to_string()),
470                    Ok(Err(e)) => Err(DaemonError::Io(e)),
471                    Err(_) => Err(DaemonError::ConnectionTimeout {
472                        timeout_secs: READ_TIMEOUT_SECS,
473                    }),
474                }
475            }
476            #[cfg(all(not(unix), not(windows)))]
477            IpcStreamInner::Dummy => Err(DaemonError::NotRunning),
478        }
479    }
480}
481
482// =============================================================================
483// Server-side message handling
484// =============================================================================
485
486/// Read a command from a client connection.
487///
488/// Used by the daemon to receive commands from clients.
489pub async fn read_command(stream: &mut IpcStream) -> DaemonResult<DaemonCommand> {
490    let json = stream.recv_raw().await?;
491
492    // Check message size (TIGER-P3-03)
493    if json.len() > MAX_MESSAGE_SIZE {
494        return Err(DaemonError::InvalidMessage(format!(
495            "command too large: {} bytes (max {})",
496            json.len(),
497            MAX_MESSAGE_SIZE
498        )));
499    }
500
501    let cmd: DaemonCommand = serde_json::from_str(&json)?;
502    Ok(cmd)
503}
504
505/// Send a response to a client connection.
506///
507/// Used by the daemon to respond to client commands.
508pub async fn send_response(stream: &mut IpcStream, response: &DaemonResponse) -> DaemonResult<()> {
509    let json = serde_json::to_string(response)?;
510    stream.send_raw(&json).await
511}
512
513// =============================================================================
514// Cleanup
515// =============================================================================
516
517/// Clean up the socket file for a project.
518///
519/// Safe to call even if socket doesn't exist.
520pub fn cleanup_socket(project: &Path) -> DaemonResult<()> {
521    let socket_path = compute_socket_path(project);
522
523    if socket_path.exists() {
524        // Safety check: don't remove symlinks
525        check_not_symlink(&socket_path)?;
526        std::fs::remove_file(&socket_path)?;
527    }
528
529    Ok(())
530}
531
532/// Check if a socket exists and is connectable.
533///
534/// Used to detect stale sockets that can be cleaned up.
535pub async fn check_socket_alive(project: &Path) -> bool {
536    (IpcStream::connect(project).await).is_ok()
537}
538
539// =============================================================================
540// High-level client functions
541// =============================================================================
542
543/// Send a command to the daemon and receive a response.
544///
545/// Convenience function that handles connection, send, and receive.
546pub async fn send_command(project: &Path, cmd: &DaemonCommand) -> DaemonResult<DaemonResponse> {
547    let mut stream = IpcStream::connect(project).await?;
548    stream.send_command(cmd).await?;
549    stream.recv_response().await
550}
551
552/// Send a raw JSON command to the daemon and receive a raw response.
553///
554/// Useful for low-level debugging or custom commands.
555pub async fn send_raw_command(project: &Path, json: &str) -> DaemonResult<String> {
556    let mut stream = IpcStream::connect(project).await?;
557    stream.send_raw(json).await?;
558    stream.recv_raw().await
559}
560
561// =============================================================================
562// Tests
563// =============================================================================
564
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn test_compute_socket_path_format() {
        let project = PathBuf::from("/test/project");
        let socket_path = compute_socket_path(&project);

        let filename = socket_path.file_name().unwrap().to_str().unwrap();
        assert!(filename.starts_with("tldr-"));
        assert!(filename.ends_with(".sock"));
    }

    #[test]
    fn test_compute_socket_path_deterministic() {
        let project = PathBuf::from("/test/project");
        let path1 = compute_socket_path(&project);
        let path2 = compute_socket_path(&project);
        assert_eq!(path1, path2);
    }

    #[test]
    fn test_compute_socket_path_different_projects() {
        let project1 = PathBuf::from("/test/project1");
        let project2 = PathBuf::from("/test/project2");
        let path1 = compute_socket_path(&project1);
        let path2 = compute_socket_path(&project2);
        assert_ne!(path1, path2);
    }

    #[test]
    fn test_compute_tcp_port_range() {
        let project = PathBuf::from("/test/project");
        let port = compute_tcp_port(&project);
        assert!(port >= 49152);
        assert!(port < 59152);
    }

    #[test]
    fn test_compute_tcp_port_deterministic() {
        let project = PathBuf::from("/test/project");
        let port1 = compute_tcp_port(&project);
        let port2 = compute_tcp_port(&project);
        assert_eq!(port1, port2);
    }

    #[test]
    fn test_validate_socket_path_valid() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("tldr-test.sock");
        assert!(validate_socket_path(&socket_path).is_ok());
    }

    #[test]
    fn test_validate_socket_path_traversal() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("../etc/passwd");
        // `Path::starts_with` is lexical, so the escaping path still "starts with" the
        // temp dir. Only the validator's own resolution of the parent can catch it.
        assert!(socket_path.starts_with(&tmp_dir));
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_validate_socket_path_bad_filename() {
        let tmp_dir = std::env::temp_dir();
        // Create a path with .. in filename (not directory traversal)
        let socket_path = tmp_dir.join("test..sock");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_max_message_size_constant() {
        // Verify 10MB limit
        assert_eq!(MAX_MESSAGE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn test_cleanup_socket_nonexistent() {
        let temp = TempDir::new().unwrap();
        let project = temp.path().join("nonexistent");

        // Should not error on nonexistent socket
        let result = cleanup_socket(&project);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_regular_file() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        std::fs::write(&file_path, "test").unwrap();

        assert!(check_not_symlink(&file_path).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_symlink() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        let link_path = temp.path().join("symlink.txt");

        std::fs::write(&file_path, "test").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

        assert!(check_not_symlink(&link_path).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_nonexistent() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("nonexistent");

        // Nonexistent path should be OK (nothing to check)
        assert!(check_not_symlink(&path).is_ok());
    }

    // Unix only: the Unix transport checks for the socket file and reports
    // `NotRunning`. The Windows TCP transport has no file to check, so a missing
    // daemon surfaces as a connection error instead.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_connect_nonexistent_daemon() {
        let temp = TempDir::new().unwrap();
        let project = temp.path();

        let result = IpcStream::connect(project).await;
        assert!(matches!(result, Err(DaemonError::NotRunning)));
    }

    // Integration tests for listener/stream would require a running daemon
    // Those are tested in daemon_test.rs Phase 5+
}