1#![warn(missing_docs)]
60#![cfg(unix)]
64
65use std::fs;
66use std::io;
67use std::os::unix::net::UnixDatagram;
68use std::path::{Path, PathBuf};
69use std::time::Duration;
70
71use zerodds_rtps::wire_types::{Locator, LocatorKind};
72use zerodds_transport::{ReceivedDatagram, RecvError, SendError, Transport};
73
/// Directory in which socket files are created when no explicit
/// [`UdsConfig::base_dir`] is configured.
pub const DEFAULT_BASE_DIR: &str = "/tmp/zerodds/uds";

/// Default upper bound, in bytes, for a single datagram (64 KiB).
pub const DEFAULT_MAX_DATAGRAM: usize = 64 * 1024;

/// Configuration for a Unix-domain-socket datagram transport.
#[derive(Debug, Clone)]
pub struct UdsConfig {
    /// Directory holding one socket file per participant id.
    pub base_dir: PathBuf,
    /// Largest payload `send` accepts, and the size of the buffer `recv`
    /// reads into (a longer incoming datagram is truncated to this size).
    pub max_datagram: usize,
    /// Read timeout installed on the socket at bind time; `None` leaves the
    /// socket fully blocking, so `recv` waits indefinitely.
    pub recv_timeout: Option<Duration>,
}

impl Default for UdsConfig {
    /// Uses [`DEFAULT_BASE_DIR`], [`DEFAULT_MAX_DATAGRAM`] and no receive timeout.
    fn default() -> Self {
        let base_dir = PathBuf::from(DEFAULT_BASE_DIR);
        Self {
            base_dir,
            max_datagram: DEFAULT_MAX_DATAGRAM,
            recv_timeout: None,
        }
    }
}
102
/// Returns the socket-file path for participant `id`:
/// `<base_dir>/<id as 32 lower-case hex digits>.sock`.
#[must_use]
pub fn socket_path(base_dir: &Path, id: [u8; 16]) -> PathBuf {
    use std::fmt::Write as _;

    let mut file_name = id.iter().fold(String::with_capacity(32), |mut acc, octet| {
        // Formatting into a `String` cannot fail.
        let _ = write!(acc, "{octet:02x}");
        acc
    });
    file_name.push_str(".sock");
    base_dir.join(file_name)
}
116
117pub struct UdsTransport {
119 socket: UnixDatagram,
120 local_id: [u8; 16],
121 config: UdsConfig,
122}
123
124impl UdsTransport {
125 pub fn bind(local_id: [u8; 16], config: UdsConfig) -> io::Result<Self> {
132 ensure_base_dir(&config.base_dir)?;
133 let path = socket_path(&config.base_dir, local_id);
134 if path.exists() {
137 fs::remove_file(&path)?;
138 }
139 let socket = UnixDatagram::bind(&path)?;
140 if let Some(t) = config.recv_timeout {
141 socket.set_read_timeout(Some(t))?;
142 }
143 Ok(Self {
144 socket,
145 local_id,
146 config,
147 })
148 }
149}
150
151impl Drop for UdsTransport {
152 fn drop(&mut self) {
153 let path = socket_path(&self.config.base_dir, self.local_id);
158 let _ = fs::remove_file(path);
159 }
160}
161
/// Ensures `path` exists as a real directory (not a symlink) for socket files.
///
/// * Existing path: accepted as-is if it is a directory; rejected if it is a
///   symlink or any other non-directory entry. Its permissions are not changed.
/// * Missing path: created together with any missing parents. New directories
///   get mode `0o700` at creation time (`DirBuilderExt::mode`), so they are
///   never briefly listable by other users. The final directory is then pinned
///   to exactly `0o700`, because the process umask may have masked bits off.
///
/// # Errors
///
/// * `io::ErrorKind::InvalidInput` if `path` is a symlink or non-directory,
///   or became a symlink right after creation.
/// * Any I/O error from inspecting, creating, or changing permissions on the
///   directory.
///
/// NOTE(review): an already-existing directory is trusted regardless of its
/// owner or mode (e.g. one pre-created world-writable by another local user);
/// std exposes no `geteuid`, so ownership is not verified here.
fn ensure_base_dir(path: &Path) -> io::Result<()> {
    use std::os::unix::fs::{DirBuilderExt, PermissionsExt};

    // `symlink_metadata` does not follow links, so a symlinked base_dir is
    // detected here instead of being silently resolved.
    match fs::symlink_metadata(path) {
        Ok(meta) if meta.file_type().is_symlink() => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "uds base_dir is a symlink — refusing (TOCTOU-hardening)",
            ));
        }
        Ok(meta) if !meta.is_dir() => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "uds base_dir exists but is not a directory",
            ));
        }
        Ok(_) => return Ok(()),
        Err(e) if e.kind() == io::ErrorKind::NotFound => {}
        Err(e) => return Err(e),
    }

    // Create owner-only from the start instead of create-then-chmod.
    fs::DirBuilder::new()
        .recursive(true)
        .mode(0o700)
        .create(path)?;

    // `set_permissions` follows symlinks, so re-check just before it.
    if fs::symlink_metadata(path)?.file_type().is_symlink() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "uds base_dir became a symlink between create and chmod",
        ));
    }
    fs::set_permissions(path, fs::Permissions::from_mode(0o700))
}
216
217impl Transport for UdsTransport {
218 fn send(&self, dest: &Locator, data: &[u8]) -> Result<(), SendError> {
219 if dest.kind != LocatorKind::Uds {
220 return Err(SendError::UnsupportedLocator);
221 }
222 if data.len() > self.config.max_datagram {
223 return Err(SendError::PayloadTooLarge {
224 size: data.len(),
225 limit: self.config.max_datagram,
226 });
227 }
228 let path = socket_path(&self.config.base_dir, dest.address);
229 match self.socket.send_to(data, &path) {
230 Ok(_) => Ok(()),
231 Err(e) => Err(classify_send_error(&e)),
232 }
233 }
234
235 fn recv(&self) -> Result<ReceivedDatagram, RecvError> {
236 let mut buf = vec![0u8; self.config.max_datagram];
237 match self.socket.recv_from(&mut buf) {
238 Ok((len, addr)) => {
239 buf.truncate(len);
240 let source = source_locator(&addr, &self.config.base_dir);
241 Ok(ReceivedDatagram { source, data: buf })
242 }
243 Err(e)
244 if matches!(
245 e.kind(),
246 io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
247 ) =>
248 {
249 Err(RecvError::Timeout)
250 }
251 Err(_) => Err(RecvError::Io {
252 message: "uds recv failed",
253 }),
254 }
255 }
256
257 fn local_locator(&self) -> Locator {
258 Locator::uds(self.local_id)
259 }
260}
261
262fn classify_send_error(e: &io::Error) -> SendError {
268 match e.kind() {
269 io::ErrorKind::NotFound => SendError::Io {
270 message: "uds: destination socket missing",
271 },
272 io::ErrorKind::PermissionDenied => SendError::Io {
273 message: "uds: permission denied on destination",
274 },
275 io::ErrorKind::WouldBlock => SendError::Io {
276 message: "uds: kernel send buffer full (EAGAIN)",
277 },
278 io::ErrorKind::ConnectionRefused => SendError::Io {
279 message: "uds: peer socket exists but refused (ECONNREFUSED)",
280 },
281 _ => SendError::Io {
282 message: "uds: send failed",
283 },
284 }
285}
286
287fn source_locator(addr: &std::os::unix::net::SocketAddr, base_dir: &Path) -> Locator {
292 let Some(path) = addr.as_pathname() else {
293 return Locator::INVALID;
294 };
295 let Ok(rel) = path.strip_prefix(base_dir) else {
296 return Locator::INVALID;
297 };
298 let Some(stem) = rel.file_stem() else {
299 return Locator::INVALID;
300 };
301 let Some(stem_str) = stem.to_str() else {
302 return Locator::INVALID;
303 };
304 let Ok(id) = parse_hex_id(stem_str) else {
305 return Locator::INVALID;
306 };
307 Locator::uds(id)
308}
309
/// Parses exactly 32 hex digits (either case) into a 16-byte id.
///
/// Returns `Err(())` if `s` is not exactly 32 bytes long or contains any
/// byte that is not an ASCII hex digit.
fn parse_hex_id(s: &str) -> Result<[u8; 16], ()> {
    let digits = s.as_bytes();
    if digits.len() != 32 {
        return Err(());
    }
    let mut id = [0u8; 16];
    for (slot, pair) in id.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_nibble(pair[0])? << 4) | hex_nibble(pair[1])?;
    }
    Ok(id)
}

/// Value (0..=15) of one ASCII hex digit (`0-9`, `a-f`, `A-F`), else `Err(())`.
fn hex_nibble(b: u8) -> Result<u8, ()> {
    // `to_digit(16)` only yields 0..=15, so the narrowing cast is lossless.
    char::from(b).to_digit(16).map(|d| d as u8).ok_or(())
}
331
// Unit tests for the UDS transport. Every socket-binding test uses its own
// `tempfile::tempdir()` as `base_dir`, so tests neither collide with each other
// nor touch `DEFAULT_BASE_DIR`.
#[cfg(test)]
// unwrap/expect/panic are the intended way to fail a test.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    // Builds a 16-byte id that is all zeros except for the last byte `n`.
    fn id_for(n: u8) -> [u8; 16] {
        let mut a = [0u8; 16];
        a[15] = n;
        a
    }

    // Test configuration: base_dir = `dir`, 8 KiB datagram limit, and a 500 ms
    // read timeout so that idle `recv` calls return instead of hanging.
    fn cfg_with_tmp(dir: &Path) -> UdsConfig {
        UdsConfig {
            base_dir: dir.to_path_buf(),
            max_datagram: 8192,
            recv_timeout: Some(Duration::from_millis(500)),
        }
    }

    // The file name is the id rendered as lower-case hex plus ".sock".
    #[test]
    fn socket_path_is_hex_under_base_dir() {
        let mut id = [0u8; 16];
        id[0] = 0xDE;
        id[1] = 0xAD;
        id[15] = 0xFF;
        let p = socket_path(Path::new("/tmp/xyz"), id);
        let s = p.to_string_lossy();
        assert!(s.starts_with("/tmp/xyz/dead"));
        assert!(s.ends_with("ff.sock"));
    }

    // Hex produced the same way as `socket_path` parses back to the same id.
    #[test]
    fn parse_hex_id_roundtrip() {
        let id = [0x11u8; 16];
        let mut hex = String::new();
        for b in id {
            use std::fmt::Write;
            let _ = write!(hex, "{b:02x}");
        }
        assert_eq!(parse_hex_id(&hex), Ok(id));
    }

    #[test]
    fn parse_hex_id_rejects_wrong_length() {
        assert_eq!(parse_hex_id("abc"), Err(()));
    }

    // Correct length (32 bytes) but not hex digits.
    #[test]
    fn parse_hex_id_rejects_non_hex() {
        assert_eq!(parse_hex_id(&"zz".repeat(16)), Err(()));
    }

    // `bind` creates the socket file and `Drop` removes it again.
    #[test]
    fn bind_creates_socket_file() {
        let tmp = tempfile::tempdir().unwrap();
        let t = UdsTransport::bind(id_for(1), cfg_with_tmp(tmp.path())).unwrap();
        let expected = socket_path(tmp.path(), id_for(1));
        assert!(expected.exists(), "socket file should exist: {expected:?}");
        drop(t);
        assert!(
            !expected.exists(),
            "Drop should clean up socket file at {expected:?}"
        );
    }

    // A leftover regular file at the socket path must not block `bind`.
    #[test]
    fn bind_reuses_path_after_stale_leftover() {
        let tmp = tempfile::tempdir().unwrap();
        let path = socket_path(tmp.path(), id_for(2));
        fs::create_dir_all(tmp.path()).unwrap();
        fs::write(&path, b"stale").unwrap();
        let _t = UdsTransport::bind(id_for(2), cfg_with_tmp(tmp.path()))
            .expect("bind must remove stale file and succeed");
        assert!(path.exists());
    }

    // End-to-end: payload arrives intact, and the source locator is derived
    // from the sender's bound socket path.
    #[test]
    fn send_and_recv_roundtrip_same_process() {
        let tmp = tempfile::tempdir().unwrap();
        let rx = UdsTransport::bind(id_for(10), cfg_with_tmp(tmp.path())).unwrap();
        let tx = UdsTransport::bind(id_for(11), cfg_with_tmp(tmp.path())).unwrap();

        tx.send(&Locator::uds(id_for(10)), b"hello zerodds")
            .unwrap();

        let got = rx.recv().expect("recv");
        assert_eq!(got.data, b"hello zerodds");
        assert_eq!(got.source, Locator::uds(id_for(11)));
    }

    #[test]
    fn send_rejects_non_uds_locator() {
        let tmp = tempfile::tempdir().unwrap();
        let tx = UdsTransport::bind(id_for(20), cfg_with_tmp(tmp.path())).unwrap();
        let res = tx.send(&Locator::udp_v4([127, 0, 0, 1], 7400), b"x");
        assert_eq!(res, Err(SendError::UnsupportedLocator));
    }

    // 10_000 bytes exceeds the 8192-byte test limit. The size check runs
    // before any socket I/O, so the unbound destination id 31 is irrelevant.
    #[test]
    fn send_rejects_oversize_payload() {
        let tmp = tempfile::tempdir().unwrap();
        let tx = UdsTransport::bind(id_for(30), cfg_with_tmp(tmp.path())).unwrap();
        let big = vec![0u8; 10_000]; let res = tx.send(&Locator::uds(id_for(31)), &big);
        assert!(matches!(res, Err(SendError::PayloadTooLarge { .. })));
    }

    // No socket is bound for id 99, so `send_to` fails and is mapped to `Io`.
    #[test]
    fn send_to_missing_peer_is_io_error() {
        let tmp = tempfile::tempdir().unwrap();
        let tx = UdsTransport::bind(id_for(40), cfg_with_tmp(tmp.path())).unwrap();
        let res = tx.send(&Locator::uds(id_for(99)), b"nobody home");
        assert!(matches!(res, Err(SendError::Io { .. })));
    }

    // Relies on the 500 ms read timeout from `cfg_with_tmp`.
    #[test]
    fn recv_times_out_when_idle() {
        let tmp = tempfile::tempdir().unwrap();
        let rx = UdsTransport::bind(id_for(50), cfg_with_tmp(tmp.path())).unwrap();
        let res = rx.recv();
        assert_eq!(res, Err(RecvError::Timeout));
    }

    #[test]
    fn local_locator_reflects_bind_id() {
        let tmp = tempfile::tempdir().unwrap();
        let t = UdsTransport::bind(id_for(60), cfg_with_tmp(tmp.path())).unwrap();
        assert_eq!(t.local_locator(), Locator::uds(id_for(60)));
    }

    // Each handled `io::ErrorKind` maps to its own message; any other kind
    // falls back to the generic "send failed".
    #[test]
    fn classify_send_error_maps_kinds() {
        use std::io;
        let cases = [
            (io::ErrorKind::NotFound, "destination socket missing"),
            (io::ErrorKind::PermissionDenied, "permission denied"),
            (io::ErrorKind::WouldBlock, "kernel send buffer full"),
            (
                io::ErrorKind::ConnectionRefused,
                "peer socket exists but refused",
            ),
            (io::ErrorKind::Other, "send failed"),
        ];
        for (kind, expected_substr) in cases {
            let e = io::Error::new(kind, "synthetic");
            match classify_send_error(&e) {
                SendError::Io { message } => {
                    assert!(
                        message.contains(expected_substr),
                        "kind {kind:?}: got {message:?}, want substring {expected_substr:?}",
                    );
                }
                other => panic!("expected Io, got {other:?}"),
            }
        }
    }

    // A symlink pointing at a real directory is still refused.
    #[cfg(unix)]
    #[test]
    fn ensure_base_dir_rejects_symlink() {
        let tmp = tempfile::tempdir().unwrap();
        let real = tmp.path().join("real");
        fs::create_dir(&real).unwrap();
        let link = tmp.path().join("link");
        std::os::unix::fs::symlink(&real, &link).unwrap();

        let err = ensure_base_dir(&link).expect_err("symlink must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn ensure_base_dir_rejects_file_path() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("not-a-dir");
        fs::write(&file, b"").unwrap();
        let err = ensure_base_dir(&file).expect_err("regular file must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn ensure_base_dir_creates_missing_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let missing = tmp.path().join("created");
        ensure_base_dir(&missing).unwrap();
        assert!(missing.is_dir());
    }
}
522
523#[cfg(target_os = "linux")]
525pub mod abstract_dgram;