//! IPC transport between the tldr CLI and its per-project daemon (tldr_cli/commands/daemon/).
ipc.rs1use std::io;
21use std::path::{Path, PathBuf};
22
23use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
24
25use crate::commands::daemon::error::{DaemonError, DaemonResult};
26use crate::commands::daemon::pid::compute_hash;
27use crate::commands::daemon::types::{DaemonCommand, DaemonResponse};
28
/// Largest accepted message payload in bytes (10 MiB), not counting the trailing
/// newline delimiter.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Seconds a client waits for a connection to the daemon to be established.
pub const CONNECTION_TIMEOUT_SECS: u64 = 5;

/// Seconds to wait for a complete newline-terminated message to arrive.
pub const READ_TIMEOUT_SECS: u64 = 30;
43
44#[cfg(unix)]
58pub fn compute_socket_path(project: &Path) -> PathBuf {
59 let hash = compute_hash(project);
60 let tmp_dir = std::env::temp_dir();
61 tmp_dir.join(format!("tldr-{}.sock", hash))
62}
63
64#[cfg(windows)]
69pub fn compute_tcp_port(project: &Path) -> u16 {
70 let hash = compute_hash(project);
71 let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
72 49152 + (hash_int % 10000) as u16
73}
74
75#[cfg(not(unix))]
77pub fn compute_socket_path(project: &Path) -> PathBuf {
78 let hash = compute_hash(project);
80 let tmp_dir = std::env::temp_dir();
81 tmp_dir.join(format!("tldr-{}.sock", hash))
82}
83
84#[cfg(not(windows))]
85pub fn compute_tcp_port(project: &Path) -> u16 {
86 let hash = compute_hash(project);
88 let hash_int = u64::from_str_radix(&hash, 16).unwrap_or(0);
89 49152 + (hash_int % 10000) as u16
90}
91
92pub fn validate_socket_path(socket_path: &Path) -> DaemonResult<()> {
104 let tmp_dir = std::env::temp_dir();
105
106 let canonical_tmp = tmp_dir.canonicalize().unwrap_or(tmp_dir);
108
109 let socket_parent = socket_path.parent().unwrap_or(socket_path);
112
113 let canonical_parent = socket_parent
115 .canonicalize()
116 .unwrap_or_else(|_| socket_parent.to_path_buf());
117
118 if !canonical_parent.starts_with(&canonical_tmp) {
119 return Err(DaemonError::PermissionDenied {
120 path: socket_path.to_path_buf(),
121 });
122 }
123
124 if let Some(filename) = socket_path.file_name() {
126 let filename_str = filename.to_string_lossy();
127 if filename_str.contains("..") || filename_str.contains('/') || filename_str.contains('\\')
128 {
129 return Err(DaemonError::PermissionDenied {
130 path: socket_path.to_path_buf(),
131 });
132 }
133 }
134
135 Ok(())
136}
137
138#[cfg(unix)]
144pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
145 if let Ok(metadata) = std::fs::symlink_metadata(path) {
146 if metadata.file_type().is_symlink() {
147 return Err(DaemonError::PermissionDenied {
148 path: path.to_path_buf(),
149 });
150 }
151 }
152 Ok(())
153}
154
155#[cfg(not(unix))]
156pub fn check_not_symlink(path: &Path) -> DaemonResult<()> {
157 if let Ok(metadata) = std::fs::symlink_metadata(path) {
159 if metadata.file_type().is_symlink() {
160 return Err(DaemonError::PermissionDenied {
161 path: path.to_path_buf(),
162 });
163 }
164 }
165 Ok(())
166}
167
/// Daemon-side IPC endpoint that accepts client connections.
///
/// Backed by a Unix-domain socket on Unix and by a loopback TCP listener on Windows.
pub struct IpcListener {
    /// Bound Unix-domain socket listener.
    #[cfg(unix)]
    inner: tokio::net::UnixListener,
    /// Bound TCP listener on `127.0.0.1`.
    #[cfg(windows)]
    inner: tokio::net::TcpListener,
    // Not read anywhere in this module. It is kept so the listener records which
    // socket path it was bound for, hence the lint allowance.
    #[allow(dead_code)]
    /// Socket path from `compute_socket_path` for this project. It is stored on
    /// Windows too, where no socket file is created.
    socket_path: PathBuf,
}
182
183impl IpcListener {
184 pub async fn bind(project: &Path) -> DaemonResult<Self> {
197 #[cfg(unix)]
198 {
199 Self::bind_unix(project).await
200 }
201 #[cfg(windows)]
202 {
203 Self::bind_tcp(project).await
204 }
205 }
206
207 #[cfg(unix)]
208 async fn bind_unix(project: &Path) -> DaemonResult<Self> {
209 use std::os::unix::fs::PermissionsExt;
210
211 let socket_path = compute_socket_path(project);
212
213 validate_socket_path(&socket_path)?;
215
216 check_not_symlink(&socket_path)?;
218
219 let listener = tokio::net::UnixListener::bind(&socket_path).map_err(|e| {
227 if e.kind() == io::ErrorKind::AddrInUse {
228 DaemonError::AddressInUse {
229 addr: socket_path.display().to_string(),
230 }
231 } else {
232 DaemonError::SocketBindFailed(e)
233 }
234 })?;
235
236 let permissions = std::fs::Permissions::from_mode(0o600);
238 std::fs::set_permissions(&socket_path, permissions)
239 .map_err(DaemonError::SocketBindFailed)?;
240
241 Ok(Self {
242 inner: listener,
243 socket_path,
244 })
245 }
246
247 #[cfg(windows)]
248 async fn bind_tcp(project: &Path) -> DaemonResult<Self> {
249 let socket_path = compute_socket_path(project); let port = compute_tcp_port(project);
251 let addr = format!("127.0.0.1:{}", port);
252
253 let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| {
254 if e.kind() == io::ErrorKind::AddrInUse {
255 DaemonError::AddressInUse { addr }
256 } else {
257 DaemonError::SocketBindFailed(e)
258 }
259 })?;
260
261 Ok(Self {
262 inner: listener,
263 socket_path,
264 })
265 }
266
267 pub async fn accept(&self) -> DaemonResult<IpcStream> {
271 #[cfg(unix)]
272 {
273 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
274 Ok(IpcStream {
275 inner: IpcStreamInner::Unix(stream),
276 })
277 }
278 #[cfg(windows)]
279 {
280 let (stream, _addr) = self.inner.accept().await.map_err(DaemonError::Io)?;
281 Ok(IpcStream {
282 inner: IpcStreamInner::Tcp(stream),
283 })
284 }
285 }
286}
287
/// Platform-specific connected transport wrapped by `IpcStream`.
enum IpcStreamInner {
    /// Connected Unix-domain socket (Unix platforms).
    #[cfg(unix)]
    Unix(tokio::net::UnixStream),
    /// Connected TCP stream to `127.0.0.1` (Windows).
    #[cfg(windows)]
    Tcp(tokio::net::TcpStream),
    /// Placeholder variant for platforms that are neither Unix nor Windows.
    /// `send_raw` treats it as a no-op and `recv_raw` fails with `DaemonError::NotRunning`.
    #[cfg(all(not(unix), not(windows)))]
    Dummy,
}
302
/// A connected message stream between the CLI and the daemon.
///
/// Messages are newline-delimited; `send_command`/`recv_response` carry them as
/// JSON. Obtained from `IpcStream::connect` (client side) or `IpcListener::accept`
/// (daemon side).
pub struct IpcStream {
    /// Platform-specific transport.
    inner: IpcStreamInner,
}
307
308impl IpcStream {
309 pub async fn connect(project: &Path) -> DaemonResult<Self> {
317 #[cfg(unix)]
318 {
319 Self::connect_unix(project).await
320 }
321 #[cfg(windows)]
322 {
323 Self::connect_tcp(project).await
324 }
325 }
326
327 #[cfg(unix)]
328 async fn connect_unix(project: &Path) -> DaemonResult<Self> {
329 let socket_path = compute_socket_path(project);
330
331 validate_socket_path(&socket_path)?;
333
334 if !socket_path.exists() {
336 return Err(DaemonError::NotRunning);
337 }
338
339 check_not_symlink(&socket_path)?;
341
342 let connect_future = tokio::net::UnixStream::connect(&socket_path);
344 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
345
346 match tokio::time::timeout(timeout, connect_future).await {
347 Ok(Ok(stream)) => Ok(Self {
348 inner: IpcStreamInner::Unix(stream),
349 }),
350 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
351 Err(DaemonError::ConnectionRefused)
352 }
353 Ok(Err(e)) if e.kind() == io::ErrorKind::NotFound => Err(DaemonError::NotRunning),
354 Ok(Err(e)) => Err(DaemonError::Io(e)),
355 Err(_) => Err(DaemonError::ConnectionTimeout {
356 timeout_secs: CONNECTION_TIMEOUT_SECS,
357 }),
358 }
359 }
360
361 #[cfg(windows)]
362 async fn connect_tcp(project: &Path) -> DaemonResult<Self> {
363 let port = compute_tcp_port(project);
364 let addr = format!("127.0.0.1:{}", port);
365
366 let connect_future = tokio::net::TcpStream::connect(&addr);
368 let timeout = tokio::time::Duration::from_secs(CONNECTION_TIMEOUT_SECS);
369
370 match tokio::time::timeout(timeout, connect_future).await {
371 Ok(Ok(stream)) => Ok(Self {
372 inner: IpcStreamInner::Tcp(stream),
373 }),
374 Ok(Err(e)) if e.kind() == io::ErrorKind::ConnectionRefused => {
375 Err(DaemonError::ConnectionRefused)
376 }
377 Ok(Err(e)) => Err(DaemonError::Io(e)),
378 Err(_) => Err(DaemonError::ConnectionTimeout {
379 timeout_secs: CONNECTION_TIMEOUT_SECS,
380 }),
381 }
382 }
383
384 pub async fn send_command(&mut self, cmd: &DaemonCommand) -> DaemonResult<()> {
388 let json = serde_json::to_string(cmd)?;
389 self.send_raw(&json).await
390 }
391
392 pub async fn send_raw(&mut self, json: &str) -> DaemonResult<()> {
396 if json.len() > MAX_MESSAGE_SIZE {
398 return Err(DaemonError::InvalidMessage(format!(
399 "message too large: {} bytes (max {})",
400 json.len(),
401 MAX_MESSAGE_SIZE
402 )));
403 }
404
405 let mut message = json.to_string();
406 message.push('\n');
407
408 match &mut self.inner {
409 #[cfg(unix)]
410 IpcStreamInner::Unix(stream) => {
411 stream.write_all(message.as_bytes()).await?;
412 stream.flush().await?;
413 }
414 #[cfg(windows)]
415 IpcStreamInner::Tcp(stream) => {
416 stream.write_all(message.as_bytes()).await?;
417 stream.flush().await?;
418 }
419 #[cfg(all(not(unix), not(windows)))]
420 IpcStreamInner::Dummy => {}
421 }
422
423 Ok(())
424 }
425
426 pub async fn recv_response(&mut self) -> DaemonResult<DaemonResponse> {
430 let json = self.recv_raw().await?;
431 let response: DaemonResponse = serde_json::from_str(&json)?;
432 Ok(response)
433 }
434
435 pub async fn recv_raw(&mut self) -> DaemonResult<String> {
444 let timeout = tokio::time::Duration::from_secs(READ_TIMEOUT_SECS);
445 let limit = (MAX_MESSAGE_SIZE + 1) as u64;
448
449 match &mut self.inner {
450 #[cfg(unix)]
451 IpcStreamInner::Unix(stream) => recv_raw_from(stream, limit, timeout).await,
452 #[cfg(windows)]
453 IpcStreamInner::Tcp(stream) => recv_raw_from(stream, limit, timeout).await,
454 #[cfg(all(not(unix), not(windows)))]
455 IpcStreamInner::Dummy => Err(DaemonError::NotRunning),
456 }
457 }
458}
459
460async fn recv_raw_from<R>(
471 stream: &mut R,
472 limit: u64,
473 timeout: tokio::time::Duration,
474) -> DaemonResult<String>
475where
476 R: tokio::io::AsyncRead + Unpin,
477{
478 let limited = AsyncReadExt::take(stream, limit);
479 let mut reader = BufReader::new(limited);
480 let mut line = String::new();
481
482 let read_future = reader.read_line(&mut line);
483
484 match tokio::time::timeout(timeout, read_future).await {
485 Ok(Ok(0)) if line.is_empty() => Err(DaemonError::ConnectionRefused),
487 Ok(Ok(_)) if !line.ends_with('\n') => Err(DaemonError::InvalidMessage(format!(
492 "message exceeds size limit of {} bytes",
493 MAX_MESSAGE_SIZE
494 ))),
495 Ok(Ok(_)) => Ok(line.trim_end().to_string()),
496 Ok(Err(e)) => Err(DaemonError::Io(e)),
497 Err(_) => Err(DaemonError::ConnectionTimeout {
498 timeout_secs: READ_TIMEOUT_SECS,
499 }),
500 }
501}
502
503pub async fn read_command(stream: &mut IpcStream) -> DaemonResult<DaemonCommand> {
516 let json = stream.recv_raw().await?;
517 let cmd: DaemonCommand = serde_json::from_str(&json)?;
518 Ok(cmd)
519}
520
521pub async fn send_response(stream: &mut IpcStream, response: &DaemonResponse) -> DaemonResult<()> {
525 let json = serde_json::to_string(response)?;
526 stream.send_raw(&json).await
527}
528
529pub fn cleanup_socket(project: &Path) -> DaemonResult<()> {
537 let socket_path = compute_socket_path(project);
538
539 if socket_path.exists() {
540 check_not_symlink(&socket_path)?;
542 std::fs::remove_file(&socket_path)?;
543 }
544
545 Ok(())
546}
547
548pub async fn check_socket_alive(project: &Path) -> bool {
552 (IpcStream::connect(project).await).is_ok()
553}
554
555pub async fn send_command(project: &Path, cmd: &DaemonCommand) -> DaemonResult<DaemonResponse> {
563 let mut stream = IpcStream::connect(project).await?;
564 stream.send_command(cmd).await?;
565 stream.recv_response().await
566}
567
568pub async fn send_raw_command(project: &Path, json: &str) -> DaemonResult<String> {
572 let mut stream = IpcStream::connect(project).await?;
573 stream.send_raw(json).await?;
574 stream.recv_raw().await
575}
576
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Connected in-process stream pair for exercising message framing without a daemon.
    #[cfg(unix)]
    fn ipc_pair() -> (IpcStream, IpcStream) {
        let (a, b) = tokio::net::UnixStream::pair().expect("create socket pair");
        (
            IpcStream {
                inner: IpcStreamInner::Unix(a),
            },
            IpcStream {
                inner: IpcStreamInner::Unix(b),
            },
        )
    }

    #[test]
    fn test_compute_socket_path_format() {
        let project = PathBuf::from("/test/project");
        let socket_path = compute_socket_path(&project);

        let filename = socket_path.file_name().unwrap().to_str().unwrap();
        assert!(filename.starts_with("tldr-"));
        assert!(filename.ends_with(".sock"));
    }

    #[test]
    fn test_compute_socket_path_deterministic() {
        let project = PathBuf::from("/test/project");
        let path1 = compute_socket_path(&project);
        let path2 = compute_socket_path(&project);
        assert_eq!(path1, path2);
    }

    #[test]
    fn test_compute_socket_path_different_projects() {
        let project1 = PathBuf::from("/test/project1");
        let project2 = PathBuf::from("/test/project2");
        let path1 = compute_socket_path(&project1);
        let path2 = compute_socket_path(&project2);
        assert_ne!(path1, path2);
    }

    #[test]
    fn test_compute_tcp_port_range() {
        let project = PathBuf::from("/test/project");
        let port = compute_tcp_port(&project);
        assert!(port >= 49152);
        assert!(port < 59152);
    }

    #[test]
    fn test_compute_tcp_port_deterministic() {
        let project = PathBuf::from("/test/project");
        let port1 = compute_tcp_port(&project);
        let port2 = compute_tcp_port(&project);
        assert_eq!(port1, port2);
    }

    #[test]
    fn test_validate_socket_path_valid() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("tldr-test.sock");
        assert!(validate_socket_path(&socket_path).is_ok());
    }

    #[test]
    fn test_validate_socket_path_traversal() {
        let tmp_dir = std::env::temp_dir();
        // `Path::starts_with` compares components lexically, so this path still
        // "starts with" the temp dir; the validator itself must reject it.
        let socket_path = tmp_dir.join("../etc/passwd");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_validate_socket_path_outside_tmp() {
        let socket_path = PathBuf::from("/nonexistent-tldr-dir/tldr-test.sock");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_validate_socket_path_bad_filename() {
        let tmp_dir = std::env::temp_dir();
        let socket_path = tmp_dir.join("test..sock");
        assert!(validate_socket_path(&socket_path).is_err());
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_MESSAGE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn test_cleanup_socket_nonexistent() {
        let temp = TempDir::new().unwrap();
        let project = temp.path().join("nonexistent");

        let result = cleanup_socket(&project);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_regular_file() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        std::fs::write(&file_path, "test").unwrap();

        assert!(check_not_symlink(&file_path).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_symlink() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("regular.txt");
        let link_path = temp.path().join("symlink.txt");

        std::fs::write(&file_path, "test").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

        assert!(check_not_symlink(&link_path).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_check_not_symlink_nonexistent() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("nonexistent");

        assert!(check_not_symlink(&path).is_ok());
    }

    // Only the Unix connect path maps a missing endpoint to `NotRunning`; the Windows
    // TCP path has no such mapping, so this expectation is Unix-specific.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_connect_nonexistent_daemon() {
        let temp = TempDir::new().unwrap();
        let project = temp.path();

        let result = IpcStream::connect(project).await;
        assert!(matches!(result, Err(DaemonError::NotRunning)));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_send_recv_roundtrip() {
        let (mut tx, mut rx) = ipc_pair();
        assert!(tx.send_raw(r#"{"cmd":"ping"}"#).await.is_ok());

        let received = rx.recv_raw().await;
        assert!(matches!(received.as_deref(), Ok(r#"{"cmd":"ping"}"#)));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_send_raw_rejects_oversized_message() {
        let (mut tx, _rx) = ipc_pair();
        let big = "x".repeat(MAX_MESSAGE_SIZE + 1);

        let result = tx.send_raw(&big).await;
        assert!(matches!(result, Err(DaemonError::InvalidMessage(_))));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_recv_raw_on_closed_connection() {
        let (tx, mut rx) = ipc_pair();
        drop(tx);

        let result = rx.recv_raw().await;
        assert!(matches!(result, Err(DaemonError::ConnectionRefused)));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_recv_raw_rejects_unterminated_message() {
        use tokio::io::AsyncWriteExt;

        let (mut writer, reader) = tokio::net::UnixStream::pair().expect("create socket pair");
        writer.write_all(b"{\"partial\":").await.unwrap();
        drop(writer);

        let mut rx = IpcStream {
            inner: IpcStreamInner::Unix(reader),
        };
        let result = rx.recv_raw().await;
        assert!(matches!(result, Err(DaemonError::InvalidMessage(_))));
    }
}