tldr_cli/commands/daemon/
ipc.rs1use std::io;
21use std::path::{Path, PathBuf};
22
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
24
25use crate::commands::daemon::error::{DaemonError, DaemonResult};
26use crate::commands::daemon::pid::compute_hash;
27use crate::commands::daemon::types::{DaemonCommand, DaemonResponse};
28
/// Upper bound, in bytes, on a single IPC message payload (10 MiB).
///
/// Enforced both when sending and when receiving.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// How many seconds a client waits for its connection to the daemon to open.
pub const CONNECTION_TIMEOUT_SECS: u64 = 5;

/// How many seconds a reader waits for a complete newline-terminated message.
pub const READ_TIMEOUT_SECS: u64 = 30;
44#[cfg(unix)]
58pub fn compute_socket_path(project: &Path) -> PathBuf {
59 let hash = compute_hash(project);
60 let tmp_dir = std::env::temp_dir();
61 tmp_dir.join(format!("tldr-{}.sock", hash))
62}
63
64#[cfg(windows)]
69pub fn compute_tcp_port(project: &Path) -> u16 {
70 let hash = compute_hash(project);
71 let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
72 49152 + (hash_int % 10000) as u16
73}
74
75#[cfg(not(unix))]
77pub fn compute_socket_path(project: &Path) -> PathBuf {
78 let hash = compute_hash(project);
80 let tmp_dir = std::env::temp_dir();
81 tmp_dir.join(format!("tldr-{}.sock", hash))
82}
83
84#[cfg(not(windows))]
85pub fn compute_tcp_port(project: &Path) -> u16 {
86 let hash = compute_hash(project);
88 let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
89 49152 + (hash_int % 10000) as u16
90}
91
92pub fn validate_socket_path(socket_path: &Path) -> DaemonResult<()> {
104 let tmp_dir = std::env::temp_dir();
105
106 let canonical_tmp = tmp_dir.canonicalize().unwrap_or(tmp_dir);
108
109 let socket_parent = socket_path.parent().unwrap_or(socket_path);
112
113 let canonical_parent = socket_parent
115 .canonicalize()
116 .unwrap_or_else(|_| socket_parent.to_path_buf());
117
118 if !canonical_parent.starts_with(&canonical_tmp) {
119 return Err(DaemonError::PermissionDenied {
120 path: socket_path.to_path_buf(),
121 });
122 }
123
124 if let Some(filename) = socket_path.file_name() {
126 let filename_str = filename.to_string_lossy();
127 if filename_str.contains("..") || filename_str.contains('/') || filename_str.contains('\\')
128 {
129 return Err(DaemonError::PermissionDenied {
130 path: socket_path.to_path_buf(),
131 });
132 }
133 }
134
135 Ok(())
136}
137
138#[cfg(unix)]
144pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
145 if let Ok(metadata) = std::fs::symlink_metadata(path) {
146 if metadata.file_type().is_symlink() {
147 return Err(DaemonError::PermissionDenied {
148 path: path.to_path_buf(),
149 });
150 }
151 }
152 Ok(())
153}
154
155#[cfg(not(unix))]
156pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
157 if let Ok(metadata) = std::fs::symlink_metadata(path) {
159 if metadata.file_type().is_symlink() {
160 return Err(DaemonError::PermissionDenied {
161 path: path.to_path_buf(),
162 });
163 }
164 }
165 Ok(())
166}
167
168pub struct IpcListener {
174 #[cfg(unix)]
175 inner: tokio::net::UnixListener,
176 #[cfg(windows)]
177 inner: tokio::net::TcpListener,
178 #[allow(dead_code)]
180 socket_path: PathBuf,
181}
182
183impl IpcListener {
184 pub async fn bind(project: &Path) -> DaemonResult<Self> {
197 #[cfg(unix)]
198 {
199 Self::bind_unix(project).await
200 }
201 #[cfg(windows)]
202 {
203 Self::bind_tcp(project).await
204 }
205 }
206
207 #[cfg(unix)]
208 async fn bind_unix(project: &Path) -> DaemonResult<Self> {
209 use std::os::unix::fs::PermissionsExt;
210
211 let socket_path = compute_socket_path(project);
212
213 validate_socket_path(&socket_path)?;
215
216 check_not_symlink(&socket_path)?;
218
219 let listener = tokio::net::UnixListener::bind(&socket_path).map_err(|e| {
227 if e.kind() == io::ErrorKind::AddrInUse {
228 DaemonError::AddressInUse {
229 addr: socket_path.display().to_string(),
230 }
231 } else {
232 DaemonError::SocketBindFailed(e)
233 }
234 })?;
235
236 let permissions = std::fs::Permissions::from_mode(0o600);
238 std::fs::set_permissions(&socket_path, permissions)
239 .map_err(DaemonError::SocketBindFailed)?;
240
241 Ok(Self {
242 inner: listener,
243 socket_path,
244 })
245 }
246
247 #[cfg(windows)]
248 async fn bind_tcp(project: &Path) -> DaemonResult<Self> {
249 let socket_path = compute_socket_path(project); let port = compute_tcp_port(project);
251 let addr = format!("127.0.0.1:{}", port);
252
253 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
254 if e.kind() == io::ErrorKind::AddrInUse {
255 DaemonError::AddressInUse { addr }
256 } else {
257 DaemonError::SocketBindFailed(e)
258 }
259 })?;
260
261 Ok(Self {
262 inner: listener,
263 socket_path,
264 })
265 }
266
267 pub async fn accept(&self) -> DaemonResult<IpcStream> {
271 #[cfg(unix)]
272 {
273 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
274 Ok(IpcStream {
275 inner: IpcStreamInner::Unix(stream),
276 })
277 }
278 #[cfg(windows)]
279 {
280 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
281 Ok(IpcStream {
282 inner: IpcStreamInner::Tcp(stream),
283 })
284 }
285 }
286}
287
288enum IpcStreamInner {
294 #[cfg(unix)]
295 Unix(tokio::net::UnixStream),
296 #[cfg(windows)]
297 Tcp(tokio::net::TcpStream),
298 #[cfg(all(not(unix), not(windows)))]
300 Dummy,
301}
302
303pub struct IpcStream {
305 inner: IpcStreamInner,
306}
307
308impl IpcStream {
309 pub async fn connect(project: &Path) -> DaemonResult<Self> {
317 #[cfg(unix)]
318 {
319 Self::connect_unix(project).await
320 }
321 #[cfg(windows)]
322 {
323 Self::connect_tcp(project).await
324 }
325 }
326
327 #[cfg(unix)]
328 async fn connect_unix(project: &Path) -> DaemonResult<Self> {
329 let socket_path = compute_socket_path(project);
330
331 validate_socket_path(&socket_path)?;
333
334 if !socket_path.exists() {
336 return Err(DaemonError::NotRunning);
337 }
338
339 check_not_symlink(&socket_path)?;
341
342 let connect_future = tokio::net::UnixStream::connect(&socket_path);
344 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
345
346 match tokio::time::timeout(timeout, connect_future).await {
347 Ok(Ok(stream)) => Ok(Self {
348 inner: IpcStreamInner::Unix(stream),
349 }),
350 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
351 Err(DaemonError::ConnectionRefused)
352 }
353 Ok(Err(e)) if e.kind() == io::ErrorKind::NotFound => Err(DaemonError::NotRunning),
354 Ok(Err(e)) => Err(DaemonError::Io(e)),
355 Err(_) => Err(DaemonError::ConnectionTimeout {
356 timeout_secs: CONNECTION_TIMEOUT_SECS,
357 }),
358 }
359 }
360
361 #[cfg(windows)]
362 async fn connect_tcp(project: &Path) -> DaemonResult<Self> {
363 let port = compute_tcp_port(project);
364 let addr = format!("127.0.0.1:{}", port);
365
366 let connect_future = tokio::net::TcpStream::connect(&addr);
368 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
369
370 match tokio::time::timeout(timeout, connect_future).await {
371 Ok(Ok(stream)) => Ok(Self {
372 inner: IpcStreamInner::Tcp(stream),
373 }),
374 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
375 Err(DaemonError::ConnectionRefused)
376 }
377 Ok(Err(e)) => Err(DaemonError::Io(e)),
378 Err(_) => Err(DaemonError::ConnectionTimeout {
379 timeout_secs: CONNECTION_TIMEOUT_SECS,
380 }),
381 }
382 }
383
384 pub async fn send_command(&mut self, cmd: &DaemonCommand) -> DaemonResult<()> {
388 let json = serde_json::to_string(cmd)?;
389 self.send_raw(&json).await
390 }
391
392 pub async fn send_raw(&mut self, json: &str) -> DaemonResult<()> {
396 if json.len() > MAX_MESSAGE_SIZE {
398 return Err(DaemonError::InvalidMessage(format!(
399 "message too large: {} bytes (max {})",
400 json.len(),
401 MAX_MESSAGE_SIZE
402 )));
403 }
404
405 let mut message = json.to_string();
406 message.push('\n');
407
408 match &mut self.inner {
409 #[cfg(unix)]
410 IpcStreamInner::Unix(stream) => {
411 stream.write_all(message.as_bytes()).await?;
412 stream.flush().await?;
413 }
414 #[cfg(windows)]
415 IpcStreamInner::Tcp(stream) => {
416 stream.write_all(message.as_bytes()).await?;
417 stream.flush().await?;
418 }
419 #[cfg(all(not(unix), not(windows)))]
420 IpcStreamInner::Dummy => {}
421 }
422
423 Ok(())
424 }
425
426 pub async fn recv_response(&mut self) -> DaemonResult<DaemonResponse> {
430 let json = self.recv_raw().await?;
431 let response: DaemonResponse = serde_json::from_str(&json)?;
432 Ok(response)
433 }
434
435 pub async fn recv_raw(&mut self) -> DaemonResult<String> {
439 let timeout = tokio::time::Duration::from_secs(READ_TIMEOUT_SECS);
440
441 match &mut self.inner {
442 #[cfg(unix)]
443 IpcStreamInner::Unix(stream) => {
444 let mut reader = BufReader::new(stream);
445 let mut line = String::new();
446
447 let read_future = reader.read_line(&mut line);
448
449 match tokio::time::timeout(timeout, read_future).await {
450 Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
452 "response too large: {} bytes (max {})",
453 n, MAX_MESSAGE_SIZE
454 ))),
455 Ok(Ok(_)) => Ok(line.trim_end().to_string()),
456 Ok(Err(e)) => Err(DaemonError::Io(e)),
457 Err(_) => Err(DaemonError::ConnectionTimeout {
458 timeout_secs: READ_TIMEOUT_SECS,
459 }),
460 }
461 }
462 #[cfg(windows)]
463 IpcStreamInner::Tcp(stream) => {
464 let mut reader = BufReader::new(stream);
465 let mut line = String::new();
466
467 let read_future = reader.read_line(&mut line);
468
469 match tokio::time::timeout(timeout, read_future).await {
470 Ok(Ok(0)) => Err(DaemonError::ConnectionRefused), Ok(Ok(n)) if n > MAX_MESSAGE_SIZE => Err(DaemonError::InvalidMessage(format!(
472 "response too large: {} bytes (max {})",
473 n, MAX_MESSAGE_SIZE
474 ))),
475 Ok(Ok(_)) => Ok(line.trim_end().to_string()),
476 Ok(Err(e)) => Err(DaemonError::Io(e)),
477 Err(_) => Err(DaemonError::ConnectionTimeout {
478 timeout_secs: READ_TIMEOUT_SECS,
479 }),
480 }
481 }
482 #[cfg(all(not(unix), not(windows)))]
483 IpcStreamInner::Dummy => Err(DaemonError::NotRunning),
484 }
485 }
486}
487
488pub async fn read_command(stream: &mut IpcStream) -> DaemonResult<DaemonCommand> {
496 let json = stream.recv_raw().await?;
497
498 if json.len() > MAX_MESSAGE_SIZE {
500 return Err(DaemonError::InvalidMessage(format!(
501 "command too large: {} bytes (max {})",
502 json.len(),
503 MAX_MESSAGE_SIZE
504 )));
505 }
506
507 let cmd: DaemonCommand = serde_json::from_str(&json)?;
508 Ok(cmd)
509}
510
511pub async fn send_response(stream: &mut IpcStream, response: &DaemonResponse) -> DaemonResult<()> {
515 let json = serde_json::to_string(response)?;
516 stream.send_raw(&json).await
517}
518
519pub fn cleanup_socket(project: &Path) -> DaemonResult<()> {
527 let socket_path = compute_socket_path(project);
528
529 if socket_path.exists() {
530 check_not_symlink(&socket_path)?;
532 std::fs::remove_file(&socket_path)?;
533 }
534
535 Ok(())
536}
537
538pub async fn check_socket_alive(project: &Path) -> bool {
542 (IpcStream::connect(project).await).is_ok()
543}
544
545pub async fn send_command(project: &Path, cmd: &DaemonCommand) -> DaemonResult<DaemonResponse> {
553 let mut stream = IpcStream::connect(project).await?;
554 stream.send_command(cmd).await?;
555 stream.recv_response().await
556}
557
558pub async fn send_raw_command(project: &Path, json: &str) -> DaemonResult<String> {
562 let mut stream = IpcStream::connect(project).await?;
563 stream.send_raw(json).await?;
564 stream.recv_raw().await
565}
566
/// Unit tests for path derivation, path validation, symlink checks, and the
/// "daemon not running" connect path.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn test_compute_socket_path_format() {
        let project = PathBuf::from("/test/project");
        let socket_path = compute_socket_path(&project);

        let filename = socket_path.file_name().unwrap().to_str().unwrap();
        assert!(filename.starts_with("tldr-"));
        assert!(filename.ends_with(".sock"));
    }

    #[test]
    fn test_compute_socket_path_deterministic() {
        let project = PathBuf::from("/test/project");
        let path1 = compute_socket_path(&project);
        let path2 = compute_socket_path(&project);
        assert_eq!(path1, path2);
    }

    #[test]
    fn test_compute_socket_path_different_projects() {
        let project1 = PathBuf::from("/test/project1");
        let project2 = PathBuf::from("/test/project2");
        let path1 = compute_socket_path(&project1);
        let path2 = compute_socket_path(&project2);
        assert_ne!(path1, path2);
    }

    #[test]
    fn test_compute_tcp_port_range() {
        let project = PathBuf::from("/test/project");
        let port = compute_tcp_port(&project);
        assert!(port >= 49152);
        assert!(port < 59152);
    }

    #[test]
    fn test_compute_tcp_port_deterministic() {
        let project = PathBuf::from("/test/project");
        let port1 = compute_tcp_port(&project);
        let port2 = compute_tcp_port(&project);
        assert_eq!(port1, port2);
    }

    #[test]
    fn test_validate_socket_path_valid() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("tldr-test.sock");
        assert!(validate_socket_path(&socket_path).is_ok());
    }

    #[test]
    fn test_validate_socket_path_traversal() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("../etc/passwd");
        // `Path::join` keeps the `..` component verbatim, so the path still
        // lexically starts with the temp dir; only the validator can catch
        // the escape.
        assert!(socket_path.starts_with(&tmp_dir));
        assert!(
            validate_socket_path(&socket_path).is_err(),
            "path escaping the temp dir was accepted: {}",
            socket_path.display()
        );
    }

    #[test]
    fn test_validate_socket_path_bad_filename() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("test..sock");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_MESSAGE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn test_cleanup_socket_nonexistent() {
        let temp = TempDir::new().unwrap();
        let project = temp.path().join("nonexistent");

        let result = cleanup_socket(&project);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_regular_file() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        std::fs::write(&file_path, "test").unwrap();

        assert!(check_not_symlink(&file_path).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_symlink() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        let link_path = temp.path().join("symlink.txt");

        std::fs::write(&file_path, "test").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

        assert!(check_not_symlink(&link_path).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_nonexistent() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("nonexistent");

        assert!(check_not_symlink(&path).is_ok());
    }

    #[tokio::test]
    async fn test_connect_nonexistent_daemon() {
        let temp = TempDir::new().unwrap();
        let project = temp.path();

        let result = IpcStream::connect(project).await;
        assert!(matches!(result, Err(DaemonError::NotRunning)));
    }
}