zenoh_link_unixsock_stream/
unicast.rs1use std::{
15 cell::UnsafeCell, collections::HashMap, fmt, fs::remove_file, os::unix::io::RawFd,
16 path::PathBuf, sync::Arc, time::Duration,
17};
18
19use async_trait::async_trait;
20use tokio::{
21 io::{AsyncReadExt, AsyncWriteExt},
22 net::{UnixListener, UnixStream},
23 sync::RwLock as AsyncRwLock,
24 task::JoinHandle,
25};
26use tokio_util::sync::CancellationToken;
27use uuid::Uuid;
28use zenoh_core::{zasyncread, zasyncwrite};
29use zenoh_link_commons::{
30 LinkAuthId, LinkManagerUnicastTrait, LinkUnicast, LinkUnicastTrait, NewLinkChannelSender,
31};
32use zenoh_protocol::{
33 core::{EndPoint, Locator},
34 transport::BatchSize,
35};
36use zenoh_result::{zerror, ZResult};
37
38use super::{
39 get_unix_path_as_string, UNIXSOCKSTREAM_ACCEPT_THROTTLE_TIME, UNIXSOCKSTREAM_DEFAULT_MTU,
40 UNIXSOCKSTREAM_LOCATOR_PREFIX,
41};
42
43pub struct LinkUnicastUnixSocketStream {
44 socket: UnsafeCell<UnixStream>,
46 src_locator: Locator,
48 dst_locator: Locator,
50}
51
52unsafe impl Sync for LinkUnicastUnixSocketStream {}
53
54impl LinkUnicastUnixSocketStream {
55 fn new(socket: UnixStream, src_path: &str, dst_path: &str) -> LinkUnicastUnixSocketStream {
56 LinkUnicastUnixSocketStream {
57 socket: UnsafeCell::new(socket),
58 src_locator: Locator::new(UNIXSOCKSTREAM_LOCATOR_PREFIX, src_path, "").unwrap(),
59 dst_locator: Locator::new(UNIXSOCKSTREAM_LOCATOR_PREFIX, dst_path, "").unwrap(),
60 }
61 }
62
63 #[allow(clippy::mut_from_ref)]
64 fn get_mut_socket(&self) -> &mut UnixStream {
65 unsafe { &mut *self.socket.get() }
66 }
67}
68
69#[async_trait]
70impl LinkUnicastTrait for LinkUnicastUnixSocketStream {
71 async fn close(&self) -> ZResult<()> {
72 tracing::trace!("Closing UnixSocketStream link: {}", self);
73 let res = self.get_mut_socket().shutdown().await;
75 tracing::trace!("UnixSocketStream link shutdown {}: {:?}", self, res);
76 res.map_err(|e| zerror!(e).into())
77 }
78
79 async fn write(&self, buffer: &[u8]) -> ZResult<usize> {
80 self.get_mut_socket().write(buffer).await.map_err(|e| {
81 let e = zerror!("Write error on UnixSocketStream link {}: {}", self, e);
82 tracing::trace!("{}", e);
83 e.into()
84 })
85 }
86
87 async fn write_all(&self, buffer: &[u8]) -> ZResult<()> {
88 self.get_mut_socket().write_all(buffer).await.map_err(|e| {
89 let e = zerror!("Write error on UnixSocketStream link {}: {}", self, e);
90 tracing::trace!("{}", e);
91 e.into()
92 })
93 }
94
95 async fn read(&self, buffer: &mut [u8]) -> ZResult<usize> {
96 self.get_mut_socket().read(buffer).await.map_err(|e| {
97 let e = zerror!("Read error on UnixSocketStream link {}: {}", self, e);
98 tracing::trace!("{}", e);
99 e.into()
100 })
101 }
102
103 async fn read_exact(&self, buffer: &mut [u8]) -> ZResult<()> {
104 self.get_mut_socket()
105 .read_exact(buffer)
106 .await
107 .map(|_len| ())
108 .map_err(|e| {
109 let e = zerror!("Read error on UnixSocketStream link {}: {}", self, e);
110 tracing::trace!("{}", e);
111 e.into()
112 })
113 }
114
115 #[inline(always)]
116 fn get_src(&self) -> &Locator {
117 &self.src_locator
118 }
119
120 #[inline(always)]
121 fn get_dst(&self) -> &Locator {
122 &self.dst_locator
123 }
124
125 #[inline(always)]
126 fn get_mtu(&self) -> BatchSize {
127 *UNIXSOCKSTREAM_DEFAULT_MTU
128 }
129
130 #[inline(always)]
131 fn get_interface_names(&self) -> Vec<String> {
132 tracing::debug!("The get_interface_names for LinkUnicastUnixSocketStream is not supported");
134 vec![]
135 }
136
137 #[inline(always)]
138 fn is_reliable(&self) -> bool {
139 super::IS_RELIABLE
140 }
141
142 #[inline(always)]
143 fn is_streamed(&self) -> bool {
144 true
145 }
146
147 #[inline(always)]
148 fn get_auth_id(&self) -> &LinkAuthId {
149 &LinkAuthId::UnixsockStream
150 }
151}
152
153impl Drop for LinkUnicastUnixSocketStream {
154 fn drop(&mut self) {
155 let _ = zenoh_runtime::ZRuntime::Acceptor
157 .block_in_place(async move { self.get_mut_socket().shutdown().await });
158 }
159}
160
161impl fmt::Display for LinkUnicastUnixSocketStream {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 write!(f, "{} => {}", &self.src_locator, &self.dst_locator)?;
164 Ok(())
165 }
166}
167
168impl fmt::Debug for LinkUnicastUnixSocketStream {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.debug_struct("UnixSocketStream")
171 .field("src", &self.src_locator)
172 .field("dst", &self.dst_locator)
173 .finish()
174 }
175}
176
177struct ListenerUnixSocketStream {
181 endpoint: EndPoint,
182 token: CancellationToken,
183 handle: JoinHandle<ZResult<()>>,
184 lock_fd: RawFd,
185}
186
187impl ListenerUnixSocketStream {
188 fn new(
189 endpoint: EndPoint,
190 token: CancellationToken,
191 handle: JoinHandle<ZResult<()>>,
192 lock_fd: RawFd,
193 ) -> ListenerUnixSocketStream {
194 ListenerUnixSocketStream {
195 endpoint,
196 token,
197 handle,
198 lock_fd,
199 }
200 }
201
202 async fn stop(&self) {
203 self.token.cancel();
204 }
205}
206
207pub struct LinkManagerUnicastUnixSocketStream {
208 manager: NewLinkChannelSender,
209 listeners: Arc<AsyncRwLock<HashMap<String, ListenerUnixSocketStream>>>,
210}
211
212impl LinkManagerUnicastUnixSocketStream {
213 pub fn new(manager: NewLinkChannelSender) -> Self {
214 Self {
215 manager,
216 listeners: Arc::new(AsyncRwLock::new(HashMap::new())),
217 }
218 }
219}
220
221#[async_trait]
222impl LinkManagerUnicastTrait for LinkManagerUnicastUnixSocketStream {
223 async fn new_link(&self, endpoint: EndPoint) -> ZResult<LinkUnicast> {
224 let path = get_unix_path_as_string(endpoint.address());
225
226 let stream = UnixStream::connect(&path).await.map_err(|e| {
228 let e = zerror!(
229 "Can not create a new UnixSocketStream link bound to {:?}: {}",
230 path,
231 e
232 );
233 tracing::warn!("{}", e);
234 e
235 })?;
236
237 let src_addr = stream.local_addr().map_err(|e| {
238 let e = zerror!(
239 "Can not create a new UnixSocketStream link bound to {:?}: {}",
240 path,
241 e
242 );
243 tracing::warn!("{}", e);
244 e
245 })?;
246
247 let _dst_addr = stream.peer_addr().map_err(|e| {
249 let e = zerror!(
250 "Can not create a new UnixSocketStream link bound to {:?}: {}",
251 path,
252 e
253 );
254 tracing::warn!("{}", e);
255 e
256 })?;
257
258 let local_path = match src_addr.as_pathname() {
259 Some(path) => PathBuf::from(path),
260 None => {
261 let e = format!("Can not create a new UnixSocketStream link bound to {path:?}");
262 tracing::warn!("{}", e);
263 PathBuf::from(format!("{}", Uuid::new_v4()))
264 }
265 };
266
267 let local_path_str = local_path.to_str().ok_or_else(|| {
268 let e = zerror!(
269 "Can not create a new UnixSocketStream link bound to {:?}",
270 path
271 );
272 tracing::warn!("{}", e);
273 e
274 })?;
275
276 let remote_path_str = path.as_str();
277
278 let link = Arc::new(LinkUnicastUnixSocketStream::new(
279 stream,
280 local_path_str,
281 remote_path_str,
282 ));
283
284 Ok(LinkUnicast(link))
285 }
286
287 async fn new_listener(&self, mut endpoint: EndPoint) -> ZResult<Locator> {
288 let path = get_unix_path_as_string(endpoint.address());
289
290 let lock_file_path = format!("{path}.lock");
305
306 let mut open_flags = nix::fcntl::OFlag::empty();
309
310 open_flags.insert(nix::fcntl::OFlag::O_CREAT);
311 open_flags.insert(nix::fcntl::OFlag::O_RDONLY);
312
313 let mut open_mode = nix::sys::stat::Mode::empty();
314 open_mode.insert(nix::sys::stat::Mode::S_IRUSR);
315 open_mode.insert(nix::sys::stat::Mode::S_IWUSR);
316
317 let lock_fd = nix::fcntl::open(
318 std::path::Path::new(&lock_file_path),
319 open_flags,
320 open_mode,
321 ).map_err(|e| {
322 let e = zerror!(
323 "Can not create a new UnixSocketStream listener on {} - Unable to open lock file: {}",
324 path, e
325 );
326 tracing::warn!("{}", e);
327 e
328 })?;
329
330 #[allow(deprecated)]
333 nix::fcntl::flock(lock_fd, nix::fcntl::FlockArg::LockExclusiveNonblock).map_err(|e| {
334 let _ = nix::unistd::close(lock_fd);
335 let e = zerror!(
336 "Can not create a new UnixSocketStream listener on {} - Unable to acquire lock: {}",
337 path,
338 e
339 );
340 tracing::warn!("{}", e);
341 e
342 })?;
343
344 let _ = remove_file(path.clone());
348
349 let socket = UnixListener::bind(&path).map_err(|e| {
351 let e = zerror!(
352 "Can not create a new UnixSocketStream listener on {}: {}",
353 path,
354 e
355 );
356 tracing::warn!("{}", e);
357 e
358 })?;
359
360 let local_addr = socket.local_addr().map_err(|e| {
361 let e = zerror!(
362 "Can not create a new UnixSocketStream listener on {}: {}",
363 path,
364 e
365 );
366 tracing::warn!("{}", e);
367 e
368 })?;
369
370 let local_path = PathBuf::from(local_addr.as_pathname().ok_or_else(|| {
371 let e = zerror!("Can not create a new UnixSocketStream listener on {}", path);
372 tracing::warn!("{}", e);
373 e
374 })?);
375
376 let local_path_str = local_path.to_str().ok_or_else(|| {
377 let e = zerror!("Can not create a new UnixSocketStream listener on {}", path);
378 tracing::warn!("{}", e);
379 e
380 })?;
381
382 endpoint = EndPoint::new(
384 endpoint.protocol(),
385 local_path_str,
386 endpoint.metadata(),
387 endpoint.config(),
388 )?;
389
390 let token = CancellationToken::new();
392 let c_token = token.clone();
393 let mut listeners = zasyncwrite!(self.listeners);
394
395 let task = {
396 let manager = self.manager.clone();
397 let listeners = self.listeners.clone();
398 let path = local_path_str.to_owned();
399
400 async move {
401 let res = accept_task(socket, c_token, manager).await;
403 zasyncwrite!(listeners).remove(&path);
404 res
405 }
406 };
407 let handle = zenoh_runtime::ZRuntime::Acceptor.spawn(task);
408
409 let locator = endpoint.to_locator();
410 let listener = ListenerUnixSocketStream::new(endpoint, token, handle, lock_fd);
411 listeners.insert(local_path_str.to_owned(), listener);
412
413 Ok(locator)
414 }
415
416 async fn del_listener(&self, endpoint: &EndPoint) -> ZResult<()> {
417 let path = get_unix_path_as_string(endpoint.address());
418
419 let listener = zasyncwrite!(self.listeners).remove(&path).ok_or_else(|| {
421 let e = zerror!(
422 "Can not delete the UnixSocketStream listener because it has not been found: {}",
423 path
424 );
425 tracing::trace!("{}", e);
426 e
427 })?;
428
429 listener.stop().await;
431 listener.handle.await??;
432
433 #[allow(deprecated)]
436 let _ = nix::fcntl::flock(listener.lock_fd, nix::fcntl::FlockArg::UnlockNonblock);
437 let _ = nix::unistd::close(listener.lock_fd);
438 let _ = remove_file(path.clone());
439
440 let lock_file_path = format!("{path}.lock");
442 let tmp = remove_file(lock_file_path);
443 tracing::trace!("UnixSocketStream Domain Socket removal result: {:?}", tmp);
444
445 Ok(())
446 }
447
448 async fn get_listeners(&self) -> Vec<EndPoint> {
449 zasyncread!(self.listeners)
450 .values()
451 .map(|x| x.endpoint.clone())
452 .collect()
453 }
454
455 async fn get_locators(&self) -> Vec<Locator> {
456 zasyncread!(self.listeners)
457 .values()
458 .map(|x| x.endpoint.to_locator())
459 .collect()
460 }
461}
462
463async fn accept_task(
464 socket: UnixListener,
465 token: CancellationToken,
466 manager: NewLinkChannelSender,
467) -> ZResult<()> {
468 async fn accept(socket: &UnixListener) -> ZResult<UnixStream> {
469 let (stream, _) = socket.accept().await.map_err(|e| zerror!(e))?;
470 Ok(stream)
471 }
472
473 let src_addr = socket.local_addr().map_err(|e| {
474 zerror!("Can not accept UnixSocketStream connections: {}", e);
475 tracing::warn!("{}", e);
476 e
477 })?;
478
479 let local_path = PathBuf::from(src_addr.as_pathname().ok_or_else(|| {
480 let e = zerror!(
481 "Can not create a new UnixSocketStream link bound to {:?}",
482 src_addr
483 );
484 tracing::warn!("{}", e);
485 e
486 })?);
487
488 let src_path = local_path.to_str().ok_or_else(|| {
489 let e = zerror!(
490 "Can not create a new UnixSocketStream link bound to {:?}",
491 src_addr
492 );
493 tracing::warn!("{}", e);
494 e
495 })?;
496
497 tracing::trace!(
499 "Ready to accept UnixSocketStream connections on: {}",
500 src_path
501 );
502
503 loop {
504 tokio::select! {
505 _ = token.cancelled() => break,
506
507 res = accept(&socket) => {
508 match res {
509 Ok(stream) => {
510 let dst_path = format!("{}", Uuid::new_v4());
511
512 tracing::debug!("Accepted UnixSocketStream connection on: {:?}", src_addr,);
513
514 let link = Arc::new(LinkUnicastUnixSocketStream::new(
516 stream, src_path, &dst_path,
517 ));
518
519 if let Err(e) = manager.send_async(LinkUnicast(link)).await {
521 tracing::error!("{}-{}: {}", file!(), line!(), e)
522 }
523
524 }
525 Err(e) => {
526 tracing::warn!("{}. Hint: increase the system open file limit.", e);
527 tokio::time::sleep(Duration::from_micros(*UNIXSOCKSTREAM_ACCEPT_THROTTLE_TIME)).await;
534 }
535 }
536 }
537 }
538 }
539
540 Ok(())
541}