tldr_cli/commands/daemon/
ipc.rs1use std::io;
21use std::path::{Path, PathBuf};
22
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
24
25use crate::commands::daemon::error::{DaemonError, DaemonResult};
26use crate::commands::daemon::pid::compute_hash;
27use crate::commands::daemon::types::{DaemonCommand, DaemonResponse};
28
/// Largest IPC message, in bytes, that this module will send or accept (10 MiB).
pub const MAX_MESSAGE_SIZE: usize = 10 * (1 << 20);

/// Number of seconds a client waits for the connection to be established.
pub const CONNECTION_TIMEOUT_SECS: u64 = 5;

/// Number of seconds a read waits for a complete line before timing out.
pub const READ_TIMEOUT_SECS: u64 = 30;
43
44#[cfg(unix)]
58pub fn compute_socket_path(project: &Path) -> PathBuf {
59 let hash = compute_hash(project);
60 let tmp_dir = std::env::temp_dir();
61 tmp_dir.join(format!("tldr-{}.sock", hash))
62}
63
/// Derives the loopback TCP port for `project` on Windows builds.
///
/// The project hash is parsed as a hexadecimal `u64` and folded into the
/// range `49152..59152`. If the hash cannot be parsed the offset is 0, so the
/// port is 49152.
// NOTE(review): `u64::from_str_radix` also fails on overflow, so a hash longer
// than 16 hex digits would always map to port 49152 — confirm the length of
// `compute_hash` output.
#[cfg(windows)]
pub fn compute_tcp_port(project: &Path) -> u16 {
    const BASE_PORT: u16 = 49152;
    const PORT_SPAN: u64 = 10000;

    let digest = compute_hash(project);
    let offset = match u64::from_str_radix(&digest, 16) {
        Ok(value) => value % PORT_SPAN,
        Err(_) => 0,
    };
    BASE_PORT + offset as u16
}
74
/// Returns the socket path for `project` on non-Unix builds.
///
/// Same layout as the Unix variant: `tldr-<hash>.sock` inside
/// `std::env::temp_dir()`, with `<hash>` taken from `compute_hash`.
#[cfg(not(unix))]
pub fn compute_socket_path(project: &Path) -> PathBuf {
    let digest = compute_hash(project);
    let base_dir = std::env::temp_dir();
    let file_name = format!("tldr-{}.sock", digest);
    base_dir.join(file_name)
}
83
84#[cfg(not(windows))]
85pub fn compute_tcp_port(project: &Path) -> u16 {
86 let hash = compute_hash(project);
88 let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
89 49152 + (hash_int % 10000) as u16
90}
91
92pub fn validate_socket_path(socket_path: &Path) -> DaemonResult<()> {
104 let tmp_dir = std::env::temp_dir();
105
106 let canonical_tmp = tmp_dir.canonicalize().unwrap_or(tmp_dir);
108
109 let socket_parent = socket_path.parent().unwrap_or(socket_path);
112
113 let canonical_parent = socket_parent
115 .canonicalize()
116 .unwrap_or_else(|_| socket_parent.to_path_buf());
117
118 if !canonical_parent.starts_with(&canonical_tmp) {
119 return Err(DaemonError::PermissionDenied {
120 path: socket_path.to_path_buf(),
121 });
122 }
123
124 if let Some(filename) = socket_path.file_name() {
126 let filename_str = filename.to_string_lossy();
127 if filename_str.contains("..") || filename_str.contains('/') || filename_str.contains('\\')
128 {
129 return Err(DaemonError::PermissionDenied {
130 path: socket_path.to_path_buf(),
131 });
132 }
133 }
134
135 Ok(())
136}
137
138#[cfg(unix)]
144pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
145 if let Ok(metadata) = std::fs::symlink_metadata(path) {
146 if metadata.file_type().is_symlink() {
147 return Err(DaemonError::PermissionDenied {
148 path: path.to_path_buf(),
149 });
150 }
151 }
152 Ok(())
153}
154
/// Rejects `path` if it is a symbolic link (non-Unix builds).
///
/// Uses `symlink_metadata`, so the link itself is examined. When the
/// metadata cannot be read the path is accepted.
///
/// # Errors
///
/// Returns `DaemonError::PermissionDenied` carrying `path` when it is a
/// symlink.
#[cfg(not(unix))]
pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
    match std::fs::symlink_metadata(path) {
        Ok(meta) if meta.file_type().is_symlink() => Err(DaemonError::PermissionDenied {
            path: path.to_path_buf(),
        }),
        _ => Ok(()),
    }
}
167
168pub struct IpcListener {
174 #[cfg(unix)]
175 inner: tokio::net::UnixListener,
176 #[cfg(windows)]
177 inner: tokio::net::TcpListener,
178 #[allow(dead_code)]
180 socket_path: PathBuf,
181}
182
183impl IpcListener {
184 pub async fn bind(project: &Path) -> DaemonResult<Self> {
197 #[cfg(unix)]
198 {
199 Self::bind_unix(project).await
200 }
201 #[cfg(windows)]
202 {
203 Self::bind_tcp(project).await
204 }
205 }
206
207 #[cfg(unix)]
208 async fn bind_unix(project: &Path) -> DaemonResult<Self> {
209 use std::os::unix::fs::PermissionsExt;
210
211 let socket_path = compute_socket_path(project);
212
213 validate_socket_path(&socket_path)?;
215
216 check_not_symlink(&socket_path)?;
218
219 if socket_path.exists() {
221 check_not_symlink(&socket_path)?;
223 std::fs::remove_file(&socket_path).map_err(DaemonError::SocketBindFailed)?;
224 }
225
226 let listener = tokio::net::UnixListener::bind(&socket_path)
228 .map_err(DaemonError::SocketBindFailed)?;
229
230 let permissions = std::fs::Permissions::from_mode(0o600);
232 std::fs::set_permissions(&socket_path, permissions)
233 .map_err(DaemonError::SocketBindFailed)?;
234
235 Ok(Self {
236 inner: listener,
237 socket_path,
238 })
239 }
240
241 #[cfg(windows)]
242 async fn bind_tcp(project: &Path) -> DaemonResult<Self> {
243 let socket_path = compute_socket_path(project); let port = compute_tcp_port(project);
245 let addr = format!("127.0.0.1:{}", port);
246
247 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
248 if e.kind() == io::ErrorKind::AddrInUse {
249 DaemonError::AddressInUse { addr }
250 } else {
251 DaemonError::SocketBindFailed(e)
252 }
253 })?;
254
255 Ok(Self {
256 inner: listener,
257 socket_path,
258 })
259 }
260
261 pub async fn accept(&self) -> DaemonResult<IpcStream> {
265 #[cfg(unix)]
266 {
267 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
268 Ok(IpcStream {
269 inner: IpcStreamInner::Unix(stream),
270 })
271 }
272 #[cfg(windows)]
273 {
274 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
275 Ok(IpcStream {
276 inner: IpcStreamInner::Tcp(stream),
277 })
278 }
279 }
280}
281
282enum IpcStreamInner {
288 #[cfg(unix)]
289 Unix(tokio::net::UnixStream),
290 #[cfg(windows)]
291 Tcp(tokio::net::TcpStream),
292 #[cfg(all(not(unix), not(windows)))]
294 Dummy,
295}
296
297pub struct IpcStream {
299 inner: IpcStreamInner,
300}
301
302impl IpcStream {
303 pub async fn connect(project: &Path) -> DaemonResult<Self> {
311 #[cfg(unix)]
312 {
313 Self::connect_unix(project).await
314 }
315 #[cfg(windows)]
316 {
317 Self::connect_tcp(project).await
318 }
319 }
320
321 #[cfg(unix)]
322 async fn connect_unix(project: &Path) -> DaemonResult<Self> {
323 let socket_path = compute_socket_path(project);
324
325 validate_socket_path(&socket_path)?;
327
328 if !socket_path.exists() {
330 return Err(DaemonError::NotRunning);
331 }
332
333 check_not_symlink(&socket_path)?;
335
336 let connect_future = tokio::net::UnixStream::connect(&socket_path);
338 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
339
340 match tokio::time::timeout(timeout, connect_future).await {
341 Ok(Ok(stream)) => Ok(Self {
342 inner: IpcStreamInner::Unix(stream),
343 }),
344 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
345 Err(DaemonError::ConnectionRefused)
346 }
347 Ok(Err(e)) if e.kind() == io::ErrorKind::NotFound => Err(DaemonError::NotRunning),
348 Ok(Err(e)) => Err(DaemonError::Io(e)),
349 Err(_) => Err(DaemonError::ConnectionTimeout {
350 timeout_secs: CONNECTION_TIMEOUT_SECS,
351 }),
352 }
353 }
354
355 #[cfg(windows)]
356 async fn connect_tcp(project: &Path) -> DaemonResult<Self> {
357 let port = compute_tcp_port(project);
358 let addr = format!("127.0.0.1:{}", port);
359
360 let connect_future = tokio::net::TcpStream::connect(&addr);
362 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
363
364 match tokio::time::timeout(timeout, connect_future).await {
365 Ok(Ok(stream)) => Ok(Self {
366 inner: IpcStreamInner::Tcp(stream),
367 }),
368 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
369 Err(DaemonError::ConnectionRefused)
370 }
371 Ok(Err(e)) => Err(DaemonError::Io(e)),
372 Err(_) => Err(DaemonError::ConnectionTimeout {
373 timeout_secs: CONNECTION_TIMEOUT_SECS,
374 }),
375 }
376 }
377
378 pub async fn send_command(&mut self, cmd: &DaemonCommand) -> DaemonResult<()> {
382 let json = serde_json::to_string(cmd)?;
383 self.send_raw(&json).await
384 }
385
386 pub async fn send_raw(&mut self, json: &str) -> DaemonResult<()> {
390 if json.len() > MAX_MESSAGE_SIZE {
392 return Err(DaemonError::InvalidMessage(format!(
393 "message too large: {} bytes (max {})",
394 json.len(),
395 MAX_MESSAGE_SIZE
396 )));
397 }
398
399 let mut message = json.to_string();
400 message.push('\n');
401
402 match &mut self.inner {
403 #[cfg(unix)]
404 IpcStreamInner::Unix(stream) => {
405 stream.write_all(message.as_bytes()).await?;
406 stream.flush().await?;
407 }
408 #[cfg(windows)]
409 IpcStreamInner::Tcp(stream) => {
410 stream.write_all(message.as_bytes()).await?;
411 stream.flush().await?;
412 }
413 #[cfg(all(not(unix), not(windows)))]
414 IpcStreamInner::Dummy => {}
415 }
416
417 Ok(())
418 }
419
420 pub async fn recv_response(&mut self) -> DaemonResult<DaemonResponse> {
424 let json = self.recv_raw().await?;
425 let response: DaemonResponse = serde_json::from_str(&json)?;
426 Ok(response)
427 }
428
429 pub async fn recv_raw(&mut self) -> DaemonResult<String> {
433 let timeout = tokio::time::Duration::from_secs(READ_TIMEOUT_SECS);
434
435 match &mut self.inner {
436 #[cfg(unix)]
437 IpcStreamInner::Unix(stream) => {
438 let mut reader = BufReader::new(stream);
439 let mut line = String::new();
440
441 let read_future = reader.read_line(&mut line);
442
443 match tokio::time::timeout(timeout, read_future).await {
444 Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
446 "response too large: {} bytes (max {})",
447 n, MAX_MESSAGE_SIZE
448 ))),
449 Ok(Ok(_)) => Ok(line.trim_end().to_string()),
450 Ok(Err(e)) => Err(DaemonError::Io(e)),
451 Err(_) => Err(DaemonError::ConnectionTimeout {
452 timeout_secs: READ_TIMEOUT_SECS,
453 }),
454 }
455 }
456 #[cfg(windows)]
457 IpcStreamInner::Tcp(stream) => {
458 let mut reader = BufReader::new(stream);
459 let mut line = String::new();
460
461 let read_future = reader.read_line(&mut line);
462
463 match tokio::time::timeout(timeout, read_future).await {
464 Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
466 "response too large: {} bytes (max {})",
467 n, MAX_MESSAGE_SIZE
468 ))),
469 Ok(Ok(_)) => Ok(line.trim_end().to_string()),
470 Ok(Err(e)) => Err(DaemonError::Io(e)),
471 Err(_) => Err(DaemonError::ConnectionTimeout {
472 timeout_secs: READ_TIMEOUT_SECS,
473 }),
474 }
475 }
476 #[cfg(all(not(unix), not(windows)))]
477 IpcStreamInner::Dummy => Err(DaemonError::NotRunning),
478 }
479 }
480}
481
482pub async fn read_command(stream: &mut IpcStream) -> DaemonResult<DaemonCommand> {
490 let json = stream.recv_raw().await?;
491
492 if json.len() > MAX_MESSAGE_SIZE {
494 return Err(DaemonError::InvalidMessage(format!(
495 "command too large: {} bytes (max {})",
496 json.len(),
497 MAX_MESSAGE_SIZE
498 )));
499 }
500
501 let cmd: DaemonCommand = serde_json::from_str(&json)?;
502 Ok(cmd)
503}
504
505pub async fn send_response(stream: &mut IpcStream, response: &DaemonResponse) -> DaemonResult<()> {
509 let json = serde_json::to_string(response)?;
510 stream.send_raw(&json).await
511}
512
513pub fn cleanup_socket(project: &Path) -> DaemonResult<()> {
521 let socket_path = compute_socket_path(project);
522
523 if socket_path.exists() {
524 check_not_symlink(&socket_path)?;
526 std::fs::remove_file(&socket_path)?;
527 }
528
529 Ok(())
530}
531
532pub async fn check_socket_alive(project: &Path) -> bool {
536 (IpcStream::connect(project).await).is_ok()
537}
538
539pub async fn send_command(project: &Path, cmd: &DaemonCommand) -> DaemonResult<DaemonResponse> {
547 let mut stream = IpcStream::connect(project).await?;
548 stream.send_command(cmd).await?;
549 stream.recv_response().await
550}
551
552pub async fn send_raw_command(project: &Path, json: &str) -> DaemonResult<String> {
556 let mut stream = IpcStream::connect(project).await?;
557 stream.send_raw(json).await?;
558 stream.recv_raw().await
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// The socket file name has the form `tldr-….sock`.
    #[test]
    fn test_compute_socket_path_format() {
        let socket_path = compute_socket_path(&PathBuf::from("/test/project"));

        let name = socket_path.file_name().unwrap().to_str().unwrap();
        assert!(name.starts_with("tldr-"));
        assert!(name.ends_with(".sock"));
    }

    /// The same project always maps to the same socket path.
    #[test]
    fn test_compute_socket_path_deterministic() {
        let project = PathBuf::from("/test/project");
        let first = compute_socket_path(&project);
        let second = compute_socket_path(&project);
        assert_eq!(first, second);
    }

    /// Distinct projects map to distinct socket paths.
    #[test]
    fn test_compute_socket_path_different_projects() {
        let first = compute_socket_path(&PathBuf::from("/test/project1"));
        let second = compute_socket_path(&PathBuf::from("/test/project2"));
        assert_ne!(first, second);
    }

    /// The computed port lies in `49152..59152`.
    #[test]
    fn test_compute_tcp_port_range() {
        let port = compute_tcp_port(&PathBuf::from("/test/project"));
        assert!((49152..59152).contains(&port));
    }

    /// The same project always maps to the same port.
    #[test]
    fn test_compute_tcp_port_deterministic() {
        let project = PathBuf::from("/test/project");
        let first = compute_tcp_port(&project);
        let second = compute_tcp_port(&project);
        assert_eq!(first, second);
    }

    /// A socket placed directly in the temp dir is accepted.
    #[test]
    fn test_validate_socket_path_valid() {
        let candidate = std::env::temp_dir().join("tldr-test.sock");
        assert!(validate_socket_path(&candidate).is_ok());
    }

    /// A path that escapes the temp dir through `..` is not accepted.
    #[test]
    fn test_validate_socket_path_traversal() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("../etc/passwd");
        let outcome = validate_socket_path(&socket_path);
        assert!(outcome.is_err() || !socket_path.starts_with(&tmp_dir));
    }

    /// A file name containing `..` is rejected.
    #[test]
    fn test_validate_socket_path_bad_filename() {
        let candidate = std::env::temp_dir().join("test..sock");
        assert!(validate_socket_path(&candidate).is_err());
    }

    /// Pins the message size limit at 10 MiB.
    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_MESSAGE_SIZE, 10 * 1024 * 1024);
    }

    /// Cleaning up when no socket exists succeeds.
    #[test]
    fn test_cleanup_socket_nonexistent() {
        let scratch = TempDir::new().unwrap();
        let project = scratch.path().join("nonexistent");

        assert!(cleanup_socket(&project).is_ok());
    }

    /// A regular file passes the symlink check.
    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_regular_file() {
        let scratch = TempDir::new().unwrap();
        let regular = scratch.path().join("regular.txt");
        std::fs::write(&regular, "test").unwrap();

        assert!(check_not_symlink(&regular).is_ok());
    }

    /// A symlink fails the symlink check.
    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_symlink() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("regular.txt");
        let link = scratch.path().join("symlink.txt");

        std::fs::write(&target, "test").unwrap();
        std::os::unix::fs::symlink(&target, &link).unwrap();

        assert!(check_not_symlink(&link).is_err());
    }

    /// A path that does not exist passes the symlink check.
    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_nonexistent() {
        let scratch = TempDir::new().unwrap();
        let missing = scratch.path().join("nonexistent");

        assert!(check_not_symlink(&missing).is_ok());
    }

    /// Connecting for a project with no daemon yields `NotRunning`.
    #[tokio::test]
    async fn test_connect_nonexistent_daemon() {
        let scratch = TempDir::new().unwrap();

        let outcome = IpcStream::connect(scratch.path()).await;
        assert!(matches!(outcome, Err(DaemonError::NotRunning)));
    }
}