1use crate::poll::{PollEvents, POLLERR, POLLHUP, POLLIN, POLLOUT};
2use crate::vfs::normalize_path;
3use std::collections::{BTreeMap, BTreeSet, VecDeque};
4use std::error::Error;
5use std::fmt;
6use std::net::{Ipv4Addr, Ipv6Addr};
7use std::sync::{Arc, Mutex, MutexGuard};
8
9pub type SocketId = u64;
10pub type SocketResult<T> = Result<T, SocketTableError>;
11
/// A host/port pair identifying an INET socket endpoint.
///
/// The host is kept exactly as the string handed to [`InetSocketAddress::new`];
/// this type does no parsing or validation of its own. The derived ordering
/// compares `host` first and `port` second.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct InetSocketAddress {
    host: String,
    port: u16,
}

impl InetSocketAddress {
    /// Builds an address from anything convertible into an owned host string.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host = host.into();
        Self { host, port }
    }

    /// Returns the host component as a string slice.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Returns the port component.
    pub const fn port(&self) -> u16 {
        self.port
    }
}
34
/// Domain a socket was created in.
///
/// The derived ordering follows declaration order: `Inet < Inet6 < Unix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketDomain {
    /// INET domain (used by `SocketSpec::tcp` and `SocketSpec::udp`).
    Inet,
    /// INET6 domain.
    Inet6,
    /// Unix domain (endpoints are identified by paths).
    Unix,
}
41
/// Kind of socket: stream or datagram.
///
/// The derived ordering follows declaration order: `Stream < Datagram`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketType {
    /// Stream socket.
    Stream,
    /// Datagram socket.
    Datagram,
}
47
/// Lifecycle state of a socket.
///
/// The derived ordering follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketState {
    /// Freshly allocated; not bound to any endpoint yet.
    Created,
    /// Bound to a local endpoint.
    Bound,
    /// Listening for incoming connections.
    Listening,
    /// Connected to a peer.
    Connected,
}

impl SocketState {
    /// Returns `true` only for [`SocketState::Listening`].
    pub const fn counts_as_listener(self) -> bool {
        match self {
            Self::Listening => true,
            Self::Created | Self::Bound | Self::Connected => false,
        }
    }

    /// Returns `true` only for [`SocketState::Connected`].
    pub const fn counts_as_connection(self) -> bool {
        match self {
            Self::Connected => true,
            Self::Created | Self::Bound | Self::Listening => false,
        }
    }
}
65
/// Selects which direction(s) of a connection a shutdown request applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketShutdown {
    /// The read side.
    Read,
    /// The write side.
    Write,
    /// Both the read and the write side.
    Both,
}
72
/// Boolean options that can be toggled on a datagram socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatagramSocketOption {
    /// Toggles the socket's `reuse_addr` flag.
    ReuseAddr,
    /// Toggles the socket's `reuse_port` flag.
    ReusePort,
    /// Toggles the socket's `broadcast` flag.
    Broadcast,
}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq)]
81pub struct SocketSpec {
82 pub domain: SocketDomain,
83 pub socket_type: SocketType,
84}
85
86impl SocketSpec {
87 pub const fn new(domain: SocketDomain, socket_type: SocketType) -> Self {
88 Self {
89 domain,
90 socket_type,
91 }
92 }
93
94 pub const fn tcp() -> Self {
95 Self::new(SocketDomain::Inet, SocketType::Stream)
96 }
97
98 pub const fn udp() -> Self {
99 Self::new(SocketDomain::Inet, SocketType::Datagram)
100 }
101
102 pub const fn unix_stream() -> Self {
103 Self::new(SocketDomain::Unix, SocketType::Stream)
104 }
105
106 pub const fn unix_datagram() -> Self {
107 Self::new(SocketDomain::Unix, SocketType::Datagram)
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq)]
112pub struct SocketRecord {
113 id: SocketId,
114 owner_pid: u32,
115 spec: SocketSpec,
116 state: SocketState,
117 local_address: Option<InetSocketAddress>,
118 peer_address: Option<InetSocketAddress>,
119 local_unix_path: Option<String>,
120 peer_unix_path: Option<String>,
121 listener_state: Option<ListenerState>,
122 connection_state: Option<ConnectionState>,
123 datagram_state: Option<DatagramState>,
124}
125
126impl SocketRecord {
127 pub const fn id(&self) -> SocketId {
128 self.id
129 }
130
131 pub const fn owner_pid(&self) -> u32 {
132 self.owner_pid
133 }
134
135 pub const fn spec(&self) -> SocketSpec {
136 self.spec
137 }
138
139 pub const fn state(&self) -> SocketState {
140 self.state
141 }
142
143 pub fn local_address(&self) -> Option<&InetSocketAddress> {
144 self.local_address.as_ref()
145 }
146
147 pub fn peer_address(&self) -> Option<&InetSocketAddress> {
148 self.peer_address.as_ref()
149 }
150
151 pub fn local_unix_path(&self) -> Option<&str> {
152 self.local_unix_path.as_deref()
153 }
154
155 pub fn peer_unix_path(&self) -> Option<&str> {
156 self.peer_unix_path.as_deref()
157 }
158
159 pub fn listen_backlog(&self) -> Option<usize> {
160 self.listener_state.as_ref().map(|state| state.backlog)
161 }
162
163 pub fn pending_accept_count(&self) -> usize {
164 self.listener_state
165 .as_ref()
166 .map(|state| state.pending_accepts.len())
167 .unwrap_or(0)
168 }
169
170 pub fn peer_socket_id(&self) -> Option<SocketId> {
171 self.connection_state
172 .as_ref()
173 .and_then(|state| state.peer_socket_id)
174 }
175
176 pub fn buffered_read_bytes(&self) -> usize {
177 self.connection_state
178 .as_ref()
179 .map(|state| state.recv_buffer.len())
180 .unwrap_or(0)
181 }
182
183 pub fn read_shutdown(&self) -> bool {
184 self.connection_state
185 .as_ref()
186 .map(|state| state.read_shutdown)
187 .unwrap_or(false)
188 }
189
190 pub fn write_shutdown(&self) -> bool {
191 self.connection_state
192 .as_ref()
193 .map(|state| state.write_shutdown)
194 .unwrap_or(false)
195 }
196
197 pub fn peer_write_shutdown(&self) -> bool {
198 self.connection_state
199 .as_ref()
200 .map(|state| state.peer_write_shutdown)
201 .unwrap_or(false)
202 }
203
204 pub fn queued_datagrams(&self) -> usize {
205 self.datagram_state
206 .as_ref()
207 .map(|state| state.recv_queue.len())
208 .unwrap_or(0)
209 }
210
211 pub fn queued_datagram_bytes(&self) -> usize {
212 self.datagram_state
213 .as_ref()
214 .map(|state| datagram_queue_bytes(&state.recv_queue))
215 .unwrap_or(0)
216 }
217
218 pub fn reuse_address(&self) -> bool {
219 self.datagram_state
220 .as_ref()
221 .map(|state| state.reuse_addr)
222 .unwrap_or(false)
223 }
224
225 pub fn reuse_port(&self) -> bool {
226 self.datagram_state
227 .as_ref()
228 .map(|state| state.reuse_port)
229 .unwrap_or(false)
230 }
231
232 pub fn broadcast_enabled(&self) -> bool {
233 self.datagram_state
234 .as_ref()
235 .map(|state| state.broadcast)
236 .unwrap_or(false)
237 }
238
239 pub fn multicast_membership_count(&self) -> usize {
240 self.datagram_state
241 .as_ref()
242 .map(|state| state.multicast_memberships.len())
243 .unwrap_or(0)
244 }
245
246 pub fn has_multicast_membership(&self, membership: &SocketMulticastMembership) -> bool {
247 self.datagram_state
248 .as_ref()
249 .map(|state| state.multicast_memberships.contains(membership))
250 .unwrap_or(false)
251 }
252}
253
254#[derive(Debug, Clone, PartialEq, Eq)]
255pub struct ReceivedDatagram {
256 source_address: Option<InetSocketAddress>,
257 payload: Vec<u8>,
258}
259
260impl ReceivedDatagram {
261 pub fn source_address(&self) -> Option<&InetSocketAddress> {
262 self.source_address.as_ref()
263 }
264
265 pub fn payload(&self) -> &[u8] {
266 &self.payload
267 }
268
269 pub fn into_parts(self) -> (Option<InetSocketAddress>, Vec<u8>) {
270 (self.source_address, self.payload)
271 }
272}
273
/// Aggregate counters describing a socket table at one point in time.
///
/// `Default` yields all-zero counters.
// NOTE(review): the code that fills these counters is not in this part of the
// file; the field descriptions below follow the field names — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SocketTableSnapshot {
    /// Socket count.
    pub sockets: usize,
    /// Listener count.
    pub listeners: usize,
    /// Connection count.
    pub connections: usize,
    /// Buffered byte count.
    pub buffered_bytes: usize,
    /// Datagram queue length.
    pub datagram_queue_len: usize,
}
282
/// A multicast group membership: the group plus an optional interface.
///
/// Both addresses are stored as the strings given to
/// [`SocketMulticastMembership::new`]; no validation happens here. The derived
/// ordering compares `group_address` first, then `interface_address`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SocketMulticastMembership {
    group_address: String,
    interface_address: Option<String>,
}

impl SocketMulticastMembership {
    /// Creates a membership for `group_address`, optionally tied to an interface.
    pub fn new(group_address: impl Into<String>, interface_address: Option<String>) -> Self {
        let group_address = group_address.into();
        Self {
            group_address,
            interface_address,
        }
    }

    /// The multicast group address.
    pub fn group_address(&self) -> &str {
        self.group_address.as_str()
    }

    /// The interface address, when one was given.
    pub fn interface_address(&self) -> Option<&str> {
        match &self.interface_address {
            Some(interface) => Some(interface.as_str()),
            None => None,
        }
    }
}
305
306#[derive(Debug, Clone, PartialEq, Eq)]
307pub struct SocketTableError {
308 code: &'static str,
309 message: String,
310}
311
312impl SocketTableError {
313 pub fn code(&self) -> &'static str {
314 self.code
315 }
316
317 fn not_found(socket_id: SocketId) -> Self {
318 Self {
319 code: "ENOENT",
320 message: format!("no such socket {socket_id}"),
321 }
322 }
323
324 fn invalid_argument(message: impl Into<String>) -> Self {
325 Self {
326 code: "EINVAL",
327 message: message.into(),
328 }
329 }
330
331 fn address_in_use(message: impl Into<String>) -> Self {
332 Self {
333 code: "EADDRINUSE",
334 message: message.into(),
335 }
336 }
337
338 fn address_not_available(message: impl Into<String>) -> Self {
339 Self {
340 code: "EADDRNOTAVAIL",
341 message: message.into(),
342 }
343 }
344
345 fn not_found_address(message: impl Into<String>) -> Self {
346 Self {
347 code: "ECONNREFUSED",
348 message: message.into(),
349 }
350 }
351
352 fn would_block(message: impl Into<String>) -> Self {
353 Self {
354 code: "EAGAIN",
355 message: message.into(),
356 }
357 }
358
359 fn not_connected(message: impl Into<String>) -> Self {
360 Self {
361 code: "ENOTCONN",
362 message: message.into(),
363 }
364 }
365
366 fn broken_pipe(message: impl Into<String>) -> Self {
367 Self {
368 code: "EPIPE",
369 message: message.into(),
370 }
371 }
372}
373
374impl fmt::Display for SocketTableError {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 write!(f, "{}: {}", self.code, self.message)
377 }
378}
379
380impl Error for SocketTableError {}
381
382#[derive(Debug, Default)]
383struct SocketTableState {
384 sockets: BTreeMap<SocketId, SocketRecord>,
385 by_owner: BTreeMap<u32, BTreeSet<SocketId>>,
386 bound_inet_streams: BTreeMap<InetSocketAddress, SocketId>,
387 bound_inet_datagrams: BTreeMap<InetSocketAddress, BTreeSet<SocketId>>,
388 bound_unix_streams: BTreeMap<String, SocketId>,
389 multicast_groups: BTreeMap<SocketMulticastMembership, BTreeSet<SocketId>>,
390 next_socket_id: SocketId,
391}
392
393#[derive(Debug, Clone, PartialEq, Eq)]
394struct ListenerState {
395 backlog: usize,
396 pending_accepts: VecDeque<PendingConnection>,
397}
398
399#[derive(Debug, Clone, PartialEq, Eq, Default)]
400struct ConnectionState {
401 peer_socket_id: Option<SocketId>,
402 recv_buffer: VecDeque<u8>,
403 read_shutdown: bool,
404 write_shutdown: bool,
405 peer_write_shutdown: bool,
406}
407
408#[derive(Debug, Clone, PartialEq, Eq)]
409struct PendingConnection {
410 peer_address: Option<InetSocketAddress>,
411 peer_unix_path: Option<String>,
412 accepted_socket_id: Option<SocketId>,
413}
414
415#[derive(Debug, Clone, PartialEq, Eq, Default)]
416struct DatagramState {
417 recv_queue: VecDeque<QueuedDatagram>,
418 reuse_addr: bool,
419 reuse_port: bool,
420 broadcast: bool,
421 multicast_memberships: BTreeSet<SocketMulticastMembership>,
422}
423
424#[derive(Debug, Clone, PartialEq, Eq)]
425struct QueuedDatagram {
426 source_address: Option<InetSocketAddress>,
427 payload: Vec<u8>,
428}
429
430#[derive(Debug, Default)]
431struct SocketTableInner {
432 state: Mutex<SocketTableState>,
433}
434
435#[derive(Debug, Clone, Default)]
436pub struct SocketTable {
437 inner: Arc<SocketTableInner>,
438}
439
440impl SocketTable {
441 pub fn new() -> Self {
442 Self::default()
443 }
444
445 pub fn allocate(&self, owner_pid: u32, spec: SocketSpec) -> SocketRecord {
446 self.allocate_with_state(owner_pid, spec, SocketState::Created)
447 }
448
449 pub fn allocate_with_state(
450 &self,
451 owner_pid: u32,
452 spec: SocketSpec,
453 state: SocketState,
454 ) -> SocketRecord {
455 let mut table = lock_or_recover(&self.inner.state);
456 let socket_id = next_socket_id(&mut table);
457 let record = SocketRecord {
458 id: socket_id,
459 owner_pid,
460 spec,
461 state,
462 local_address: None,
463 peer_address: None,
464 local_unix_path: None,
465 peer_unix_path: None,
466 listener_state: None,
467 connection_state: default_connection_state(spec, state),
468 datagram_state: default_datagram_state(spec),
469 };
470 table.sockets.insert(socket_id, record.clone());
471 table
472 .by_owner
473 .entry(owner_pid)
474 .or_default()
475 .insert(socket_id);
476 record
477 }
478
479 pub fn get(&self, socket_id: SocketId) -> Option<SocketRecord> {
480 lock_or_recover(&self.inner.state)
481 .sockets
482 .get(&socket_id)
483 .cloned()
484 }
485
486 pub fn update_state(
487 &self,
488 socket_id: SocketId,
489 new_state: SocketState,
490 ) -> SocketResult<SocketRecord> {
491 let mut table = lock_or_recover(&self.inner.state);
492 let record = table
493 .sockets
494 .get_mut(&socket_id)
495 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
496 validate_state_transition(record.state, new_state)?;
497 record.state = new_state;
498 if new_state != SocketState::Listening {
499 record.listener_state = None;
500 }
501 if new_state == SocketState::Connected && supports_connection_lifecycle(record.spec) {
502 record
503 .connection_state
504 .get_or_insert_with(ConnectionState::default);
505 } else if new_state != SocketState::Connected {
506 record.connection_state = None;
507 }
508 Ok(record.clone())
509 }
510
511 pub fn bind_inet(
512 &self,
513 socket_id: SocketId,
514 address: InetSocketAddress,
515 ) -> SocketResult<SocketRecord> {
516 let address = normalize_inet_address(address);
517 let mut table = lock_or_recover(&self.inner.state);
518 let existing = table
519 .sockets
520 .get(&socket_id)
521 .cloned()
522 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
523 if !supports_inet_bind(existing.spec) {
524 return Err(SocketTableError::invalid_argument(format!(
525 "socket {socket_id} is not an INET socket"
526 )));
527 }
528 let conflicting_ids =
529 lookup_conflicting_bound_inet_socket_ids(&table, existing.spec, &address);
530 if has_incompatible_inet_bind_conflict(&table, &existing, &conflicting_ids) {
531 return Err(SocketTableError::address_in_use(format!(
532 "address {}:{} is already bound",
533 address.host(),
534 address.port()
535 )));
536 }
537 let cloned = {
538 let record = table
539 .sockets
540 .get_mut(&socket_id)
541 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
542
543 match record.state {
544 SocketState::Created => {}
545 SocketState::Bound if record.local_address.as_ref() == Some(&address) => {
546 return Ok(record.clone());
547 }
548 SocketState::Bound | SocketState::Listening | SocketState::Connected => {
549 return Err(SocketTableError::invalid_argument(format!(
550 "socket {socket_id} cannot bind in state {:?}",
551 record.state
552 )));
553 }
554 }
555
556 record.local_address = Some(address.clone());
557 record.peer_address = None;
558 record.local_unix_path = None;
559 record.peer_unix_path = None;
560 record.listener_state = None;
561 record.connection_state = None;
562 record.state = SocketState::Bound;
563 record.clone()
564 };
565 register_bound_inet_socket(&mut table, cloned.spec, address, socket_id);
566 Ok(cloned)
567 }
568
569 pub fn set_datagram_socket_option(
570 &self,
571 socket_id: SocketId,
572 option: DatagramSocketOption,
573 enabled: bool,
574 ) -> SocketResult<SocketRecord> {
575 let mut table = lock_or_recover(&self.inner.state);
576 let record = table
577 .sockets
578 .get_mut(&socket_id)
579 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
580 let datagram_state = datagram_state_mut(record)?;
581
582 match option {
583 DatagramSocketOption::ReuseAddr => datagram_state.reuse_addr = enabled,
584 DatagramSocketOption::ReusePort => datagram_state.reuse_port = enabled,
585 DatagramSocketOption::Broadcast => datagram_state.broadcast = enabled,
586 }
587
588 Ok(record.clone())
589 }
590
591 pub fn add_multicast_membership(
592 &self,
593 socket_id: SocketId,
594 membership: SocketMulticastMembership,
595 ) -> SocketResult<SocketRecord> {
596 let mut table = lock_or_recover(&self.inner.state);
597 let normalized_membership = {
598 let record = table
599 .sockets
600 .get(&socket_id)
601 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
602 validate_multicast_socket(record)?;
603 normalize_multicast_membership(record.spec, membership)?
604 };
605
606 let cloned = {
607 let record = table
608 .sockets
609 .get_mut(&socket_id)
610 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
611 let datagram_state = datagram_state_mut(record)?;
612 datagram_state
613 .multicast_memberships
614 .insert(normalized_membership.clone());
615 record.clone()
616 };
617
618 table
619 .multicast_groups
620 .entry(normalized_membership)
621 .or_default()
622 .insert(socket_id);
623 Ok(cloned)
624 }
625
626 pub fn drop_multicast_membership(
627 &self,
628 socket_id: SocketId,
629 membership: SocketMulticastMembership,
630 ) -> SocketResult<SocketRecord> {
631 let mut table = lock_or_recover(&self.inner.state);
632 let normalized_membership = {
633 let record = table
634 .sockets
635 .get(&socket_id)
636 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
637 validate_multicast_socket(record)?;
638 normalize_multicast_membership(record.spec, membership)?
639 };
640
641 let cloned = {
642 let record = table
643 .sockets
644 .get_mut(&socket_id)
645 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
646 let datagram_state = datagram_state_mut(record)?;
647 if !datagram_state
648 .multicast_memberships
649 .remove(&normalized_membership)
650 {
651 return Err(SocketTableError::address_not_available(format!(
652 "socket {socket_id} has not joined multicast group {}",
653 normalized_membership.group_address()
654 )));
655 }
656 record.clone()
657 };
658
659 if let Some(members) = table.multicast_groups.get_mut(&normalized_membership) {
660 members.remove(&socket_id);
661 if members.is_empty() {
662 table.multicast_groups.remove(&normalized_membership);
663 }
664 }
665
666 Ok(cloned)
667 }
668
669 pub fn bind_unix(
670 &self,
671 socket_id: SocketId,
672 path: impl Into<String>,
673 ) -> SocketResult<SocketRecord> {
674 let path = normalize_unix_socket_path(path.into())?;
675 let mut table = lock_or_recover(&self.inner.state);
676 let existing = table
677 .sockets
678 .get(&socket_id)
679 .cloned()
680 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
681 if !supports_unix_stream_lifecycle(existing.spec) {
682 return Err(SocketTableError::invalid_argument(format!(
683 "socket {socket_id} is not a Unix stream socket"
684 )));
685 }
686 let existing_id = table.bound_unix_streams.get(&path).copied();
687 let cloned = {
688 let record = table
689 .sockets
690 .get_mut(&socket_id)
691 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
692
693 if let Some(bound_socket_id) = existing_id {
694 if bound_socket_id != socket_id {
695 return Err(SocketTableError::address_in_use(format!(
696 "path {path} is already bound"
697 )));
698 }
699 }
700
701 match record.state {
702 SocketState::Created => {}
703 SocketState::Bound if record.local_unix_path.as_deref() == Some(path.as_str()) => {
704 return Ok(record.clone());
705 }
706 SocketState::Bound | SocketState::Listening | SocketState::Connected => {
707 return Err(SocketTableError::invalid_argument(format!(
708 "socket {socket_id} cannot bind in state {:?}",
709 record.state
710 )));
711 }
712 }
713
714 record.local_address = None;
715 record.peer_address = None;
716 record.local_unix_path = Some(path.clone());
717 record.peer_unix_path = None;
718 record.listener_state = None;
719 record.connection_state = None;
720 record.state = SocketState::Bound;
721 record.clone()
722 };
723 table.bound_unix_streams.insert(path, socket_id);
724 Ok(cloned)
725 }
726
727 pub fn listen(&self, socket_id: SocketId, backlog: usize) -> SocketResult<SocketRecord> {
728 if backlog == 0 {
729 return Err(SocketTableError::invalid_argument(
730 "listener backlog must be greater than zero",
731 ));
732 }
733
734 let mut table = lock_or_recover(&self.inner.state);
735 let record = table
736 .sockets
737 .get_mut(&socket_id)
738 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
739
740 if !supports_listener_lifecycle(record.spec) {
741 return Err(SocketTableError::invalid_argument(format!(
742 "socket {socket_id} is not a stream socket"
743 )));
744 }
745 if record.state != SocketState::Bound || !has_bound_endpoint(record) {
746 return Err(SocketTableError::invalid_argument(format!(
747 "socket {socket_id} must be bound before listen"
748 )));
749 }
750
751 record.state = SocketState::Listening;
752 record.listener_state = Some(ListenerState {
753 backlog,
754 pending_accepts: VecDeque::new(),
755 });
756 Ok(record.clone())
757 }
758
759 pub fn enqueue_incoming_tcp_connection(
760 &self,
761 listener_socket_id: SocketId,
762 peer_address: InetSocketAddress,
763 ) -> SocketResult<()> {
764 let mut table = lock_or_recover(&self.inner.state);
765 let record = table
766 .sockets
767 .get_mut(&listener_socket_id)
768 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
769
770 if record.state != SocketState::Listening {
771 return Err(SocketTableError::invalid_argument(format!(
772 "socket {listener_socket_id} is not listening"
773 )));
774 }
775
776 let listener_state = record.listener_state.as_mut().ok_or_else(|| {
777 SocketTableError::invalid_argument(format!(
778 "socket {listener_socket_id} has no listener state"
779 ))
780 })?;
781
782 if listener_state.pending_accepts.len() >= listener_state.backlog {
783 return Err(SocketTableError::would_block(format!(
784 "listener {listener_socket_id} backlog is full"
785 )));
786 }
787
788 listener_state.pending_accepts.push_back(PendingConnection {
789 peer_address: Some(peer_address),
790 peer_unix_path: None,
791 accepted_socket_id: None,
792 });
793 Ok(())
794 }
795
796 pub fn accept(&self, listener_socket_id: SocketId) -> SocketResult<SocketRecord> {
797 let mut table = lock_or_recover(&self.inner.state);
798 let (owner_pid, spec, local_address, local_unix_path, pending) = {
799 let record = table
800 .sockets
801 .get_mut(&listener_socket_id)
802 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
803
804 if record.state != SocketState::Listening {
805 return Err(SocketTableError::invalid_argument(format!(
806 "socket {listener_socket_id} is not listening"
807 )));
808 }
809
810 let listener_state = record.listener_state.as_mut().ok_or_else(|| {
811 SocketTableError::invalid_argument(format!(
812 "socket {listener_socket_id} has no listener state"
813 ))
814 })?;
815 let pending = listener_state.pending_accepts.pop_front().ok_or_else(|| {
816 SocketTableError::would_block(format!(
817 "listener {listener_socket_id} has no pending connections"
818 ))
819 })?;
820
821 (
822 record.owner_pid,
823 record.spec,
824 record.local_address.clone(),
825 record.local_unix_path.clone(),
826 pending,
827 )
828 };
829
830 if let Some(accepted_socket_id) = pending.accepted_socket_id {
831 return table
832 .sockets
833 .get(&accepted_socket_id)
834 .cloned()
835 .ok_or_else(|| SocketTableError::not_found(accepted_socket_id));
836 }
837
838 let socket_id = next_socket_id(&mut table);
839 let record = SocketRecord {
840 id: socket_id,
841 owner_pid,
842 spec,
843 state: SocketState::Connected,
844 local_address,
845 peer_address: pending.peer_address,
846 local_unix_path,
847 peer_unix_path: pending.peer_unix_path,
848 listener_state: None,
849 connection_state: default_connection_state(spec, SocketState::Connected),
850 datagram_state: default_datagram_state(spec),
851 };
852 table.sockets.insert(socket_id, record.clone());
853 table
854 .by_owner
855 .entry(owner_pid)
856 .or_default()
857 .insert(socket_id);
858 Ok(record)
859 }
860
861 pub fn connect_pair(
862 &self,
863 socket_id: SocketId,
864 peer_socket_id: SocketId,
865 ) -> SocketResult<(SocketRecord, SocketRecord)> {
866 if socket_id == peer_socket_id {
867 return Err(SocketTableError::invalid_argument(
868 "socket cannot connect to itself",
869 ));
870 }
871
872 let mut table = lock_or_recover(&self.inner.state);
873 let mut socket = table
874 .sockets
875 .remove(&socket_id)
876 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
877 let Some(mut peer) = table.sockets.remove(&peer_socket_id) else {
878 table.sockets.insert(socket_id, socket);
879 return Err(SocketTableError::not_found(peer_socket_id));
880 };
881
882 if let Err(error) = validate_connect_pair(&socket, &peer) {
883 table.sockets.insert(socket_id, socket);
884 table.sockets.insert(peer_socket_id, peer);
885 return Err(error);
886 }
887
888 socket.state = SocketState::Connected;
889 socket.peer_address = peer.local_address.clone();
890 socket.peer_unix_path = peer.local_unix_path.clone();
891 socket.listener_state = None;
892 socket.connection_state = Some(ConnectionState {
893 peer_socket_id: Some(peer_socket_id),
894 ..ConnectionState::default()
895 });
896
897 peer.state = SocketState::Connected;
898 peer.peer_address = socket.local_address.clone();
899 peer.peer_unix_path = socket.local_unix_path.clone();
900 peer.listener_state = None;
901 peer.connection_state = Some(ConnectionState {
902 peer_socket_id: Some(socket_id),
903 ..ConnectionState::default()
904 });
905
906 let socket_clone = socket.clone();
907 let peer_clone = peer.clone();
908 table.sockets.insert(socket_id, socket);
909 table.sockets.insert(peer_socket_id, peer);
910 Ok((socket_clone, peer_clone))
911 }
912
913 pub fn find_bound_inet_socket(
914 &self,
915 spec: SocketSpec,
916 address: &InetSocketAddress,
917 ) -> Option<SocketRecord> {
918 let address = normalize_inet_address(address.clone());
919 let table = lock_or_recover(&self.inner.state);
920 let socket_id = lookup_bound_inet_socket(&table, spec, &address)?;
921 table.sockets.get(&socket_id).cloned()
922 }
923
924 pub fn connect_to_bound_inet_stream(
925 &self,
926 socket_id: SocketId,
927 target_address: InetSocketAddress,
928 ) -> SocketResult<()> {
929 let target_address = normalize_inet_address(target_address);
930 let mut table = lock_or_recover(&self.inner.state);
931 let listener_socket_id =
932 lookup_bound_inet_socket_in_table(&table.bound_inet_streams, &target_address)
933 .ok_or_else(|| {
934 SocketTableError::not_found_address(format!(
935 "no listening socket bound at {}:{}",
936 target_address.host(),
937 target_address.port()
938 ))
939 })?;
940
941 if socket_id == listener_socket_id {
942 return Err(SocketTableError::invalid_argument(
943 "socket cannot connect to its own listening endpoint",
944 ));
945 }
946
947 let mut client = table
948 .sockets
949 .remove(&socket_id)
950 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
951 let result = (|| {
952 {
957 let listener = table
958 .sockets
959 .get(&listener_socket_id)
960 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
961 validate_connect_to_listener(&client, listener)?;
962
963 let listener_state = listener.listener_state.as_ref().ok_or_else(|| {
964 SocketTableError::invalid_argument(format!(
965 "socket {listener_socket_id} has no listener state"
966 ))
967 })?;
968 if listener_state.pending_accepts.len() >= listener_state.backlog {
969 return Err(SocketTableError::would_block(format!(
970 "listener {listener_socket_id} backlog is full"
971 )));
972 }
973 }
974
975 let accepted_socket_id = next_socket_id(&mut table);
977 let listener = table
978 .sockets
979 .get_mut(&listener_socket_id)
980 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
981 let listener_state = listener.listener_state.as_mut().ok_or_else(|| {
982 SocketTableError::invalid_argument(format!(
983 "socket {listener_socket_id} has no listener state"
984 ))
985 })?;
986
987 let accepted = SocketRecord {
988 id: accepted_socket_id,
989 owner_pid: listener.owner_pid,
990 spec: listener.spec,
991 state: SocketState::Connected,
992 local_address: listener.local_address.clone(),
993 peer_address: client.local_address.clone(),
994 local_unix_path: None,
995 peer_unix_path: None,
996 listener_state: None,
997 connection_state: Some(ConnectionState {
998 peer_socket_id: Some(socket_id),
999 ..ConnectionState::default()
1000 }),
1001 datagram_state: default_datagram_state(listener.spec),
1002 };
1003
1004 listener_state.pending_accepts.push_back(PendingConnection {
1005 peer_address: client.local_address.clone(),
1006 peer_unix_path: None,
1007 accepted_socket_id: Some(accepted_socket_id),
1008 });
1009
1010 client.state = SocketState::Connected;
1011 client.peer_address = listener.local_address.clone();
1012 client.peer_unix_path = None;
1013 client.listener_state = None;
1014 client.connection_state = Some(ConnectionState {
1015 peer_socket_id: Some(accepted_socket_id),
1016 ..ConnectionState::default()
1017 });
1018
1019 Ok(accepted)
1020 })();
1021
1022 match result {
1023 Ok(accepted) => {
1024 let accepted_socket_id = accepted.id;
1025 table.sockets.insert(socket_id, client);
1026 table.sockets.insert(accepted_socket_id, accepted.clone());
1027 table
1028 .by_owner
1029 .entry(accepted.owner_pid)
1030 .or_default()
1031 .insert(accepted_socket_id);
1032 Ok(())
1033 }
1034 Err(error) => {
1035 table.sockets.insert(socket_id, client);
1036 Err(error)
1037 }
1038 }
1039 }
1040
1041 pub fn find_bound_unix_socket(&self, path: &str) -> Option<SocketRecord> {
1042 let path = normalize_unix_socket_path(path).ok()?;
1043 let table = lock_or_recover(&self.inner.state);
1044 let socket_id = table.bound_unix_streams.get(&path).copied()?;
1045 table.sockets.get(&socket_id).cloned()
1046 }
1047
1048 pub fn connect_to_bound_unix_stream(
1049 &self,
1050 socket_id: SocketId,
1051 target_path: impl Into<String>,
1052 ) -> SocketResult<()> {
1053 let target_path = normalize_unix_socket_path(target_path.into())?;
1054 let mut table = lock_or_recover(&self.inner.state);
1055 let listener_socket_id = table
1056 .bound_unix_streams
1057 .get(&target_path)
1058 .copied()
1059 .ok_or_else(|| {
1060 SocketTableError::not_found_address(format!(
1061 "no listening socket bound at path {target_path}"
1062 ))
1063 })?;
1064
1065 if socket_id == listener_socket_id {
1066 return Err(SocketTableError::invalid_argument(
1067 "socket cannot connect to its own listening endpoint",
1068 ));
1069 }
1070
1071 let mut client = table
1072 .sockets
1073 .remove(&socket_id)
1074 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1075 let result = (|| {
1076 {
1081 let listener = table
1082 .sockets
1083 .get(&listener_socket_id)
1084 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
1085 validate_connect_to_listener(&client, listener)?;
1086
1087 let listener_state = listener.listener_state.as_ref().ok_or_else(|| {
1088 SocketTableError::invalid_argument(format!(
1089 "socket {listener_socket_id} has no listener state"
1090 ))
1091 })?;
1092 if listener_state.pending_accepts.len() >= listener_state.backlog {
1093 return Err(SocketTableError::would_block(format!(
1094 "listener {listener_socket_id} backlog is full"
1095 )));
1096 }
1097 }
1098
1099 let accepted_socket_id = next_socket_id(&mut table);
1101 let listener = table
1102 .sockets
1103 .get_mut(&listener_socket_id)
1104 .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
1105 let listener_state = listener.listener_state.as_mut().ok_or_else(|| {
1106 SocketTableError::invalid_argument(format!(
1107 "socket {listener_socket_id} has no listener state"
1108 ))
1109 })?;
1110
1111 let accepted = SocketRecord {
1112 id: accepted_socket_id,
1113 owner_pid: listener.owner_pid,
1114 spec: listener.spec,
1115 state: SocketState::Connected,
1116 local_address: None,
1117 peer_address: None,
1118 local_unix_path: listener.local_unix_path.clone(),
1119 peer_unix_path: client.local_unix_path.clone(),
1120 listener_state: None,
1121 connection_state: Some(ConnectionState {
1122 peer_socket_id: Some(socket_id),
1123 ..ConnectionState::default()
1124 }),
1125 datagram_state: default_datagram_state(listener.spec),
1126 };
1127
1128 listener_state.pending_accepts.push_back(PendingConnection {
1129 peer_address: None,
1130 peer_unix_path: client.local_unix_path.clone(),
1131 accepted_socket_id: Some(accepted_socket_id),
1132 });
1133
1134 client.state = SocketState::Connected;
1135 client.peer_address = None;
1136 client.peer_unix_path = listener.local_unix_path.clone();
1137 client.listener_state = None;
1138 client.connection_state = Some(ConnectionState {
1139 peer_socket_id: Some(accepted_socket_id),
1140 ..ConnectionState::default()
1141 });
1142
1143 Ok(accepted)
1144 })();
1145
1146 match result {
1147 Ok(accepted) => {
1148 let accepted_socket_id = accepted.id;
1149 table.sockets.insert(socket_id, client);
1150 table.sockets.insert(accepted_socket_id, accepted.clone());
1151 table
1152 .by_owner
1153 .entry(accepted.owner_pid)
1154 .or_default()
1155 .insert(accepted_socket_id);
1156 Ok(())
1157 }
1158 Err(error) => {
1159 table.sockets.insert(socket_id, client);
1160 Err(error)
1161 }
1162 }
1163 }
1164
1165 pub fn send_to_bound_udp_socket(
1166 &self,
1167 socket_id: SocketId,
1168 target_address: InetSocketAddress,
1169 data: &[u8],
1170 ) -> SocketResult<usize> {
1171 let target_address = normalize_inet_address(target_address);
1172 let mut table = lock_or_recover(&self.inner.state);
1173 let sender = table
1174 .sockets
1175 .get(&socket_id)
1176 .cloned()
1177 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1178 validate_bound_udp_sender(&sender)?;
1179
1180 let receiver_socket_id = lookup_bound_inet_datagram_socket_in_table(
1181 &table.bound_inet_datagrams,
1182 &target_address,
1183 )
1184 .ok_or_else(|| {
1185 SocketTableError::not_found_address(format!(
1186 "no UDP socket bound at {}:{}",
1187 target_address.host(),
1188 target_address.port()
1189 ))
1190 })?;
1191 let receiver = table
1192 .sockets
1193 .get_mut(&receiver_socket_id)
1194 .ok_or_else(|| SocketTableError::not_found(receiver_socket_id))?;
1195 validate_bound_udp_receiver(receiver)?;
1196
1197 let datagram_state = receiver.datagram_state.as_mut().ok_or_else(|| {
1198 SocketTableError::invalid_argument(format!(
1199 "socket {receiver_socket_id} does not support datagrams"
1200 ))
1201 })?;
1202 datagram_state.recv_queue.push_back(QueuedDatagram {
1203 source_address: sender.local_address.clone(),
1204 payload: data.to_vec(),
1205 });
1206 Ok(data.len())
1207 }
1208
1209 pub fn check_send_to_bound_udp_socket(
1210 &self,
1211 socket_id: SocketId,
1212 target_address: InetSocketAddress,
1213 ) -> SocketResult<()> {
1214 let target_address = normalize_inet_address(target_address);
1215 let table = lock_or_recover(&self.inner.state);
1216 let sender = table
1217 .sockets
1218 .get(&socket_id)
1219 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1220 validate_bound_udp_sender(sender)?;
1221
1222 let receiver_socket_id = lookup_bound_inet_datagram_socket_in_table(
1223 &table.bound_inet_datagrams,
1224 &target_address,
1225 )
1226 .ok_or_else(|| {
1227 SocketTableError::not_found_address(format!(
1228 "no UDP socket bound at {}:{}",
1229 target_address.host(),
1230 target_address.port()
1231 ))
1232 })?;
1233 let receiver = table
1234 .sockets
1235 .get(&receiver_socket_id)
1236 .ok_or_else(|| SocketTableError::not_found(receiver_socket_id))?;
1237 validate_bound_udp_receiver(receiver)?;
1238 Ok(())
1239 }
1240
1241 pub fn recv_datagram(
1242 &self,
1243 socket_id: SocketId,
1244 max_bytes: usize,
1245 ) -> SocketResult<Option<ReceivedDatagram>> {
1246 let mut table = lock_or_recover(&self.inner.state);
1247 let record = table
1248 .sockets
1249 .get_mut(&socket_id)
1250 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1251 validate_bound_udp_receiver(record)?;
1252
1253 let datagram_state = record.datagram_state.as_mut().ok_or_else(|| {
1254 SocketTableError::invalid_argument(format!(
1255 "socket {socket_id} does not support datagrams"
1256 ))
1257 })?;
1258 let Some(datagram) = datagram_state.recv_queue.pop_front() else {
1259 return Err(SocketTableError::would_block(format!(
1260 "socket {socket_id} has no queued datagrams"
1261 )));
1262 };
1263
1264 let payload = if datagram.payload.len() > max_bytes {
1265 datagram.payload[..max_bytes].to_vec()
1266 } else {
1267 datagram.payload
1268 };
1269 Ok(Some(ReceivedDatagram {
1270 source_address: datagram.source_address,
1271 payload,
1272 }))
1273 }
1274
1275 pub fn poll(&self, socket_id: SocketId, requested: PollEvents) -> SocketResult<PollEvents> {
1276 let table = lock_or_recover(&self.inner.state);
1277 let record = table
1278 .sockets
1279 .get(&socket_id)
1280 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1281
1282 let mut events = PollEvents::empty();
1283 match record.state {
1284 SocketState::Listening => {
1285 if requested.intersects(POLLIN) && record.pending_accept_count() > 0 {
1286 events |= POLLIN;
1287 }
1288 }
1289 SocketState::Connected => {
1290 let connection = record.connection_state.as_ref().ok_or_else(|| {
1291 SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1292 })?;
1293 let peer = connection
1294 .peer_socket_id
1295 .and_then(|peer_socket_id| table.sockets.get(&peer_socket_id));
1296
1297 if requested.intersects(POLLIN) && !connection.recv_buffer.is_empty() {
1298 events |= POLLIN;
1299 }
1300 if connection.peer_write_shutdown || peer.is_none() {
1301 events |= POLLHUP;
1302 }
1303
1304 if requested.intersects(POLLOUT) && !connection.write_shutdown {
1305 if peer
1306 .and_then(|peer| peer.connection_state.as_ref())
1307 .map(|peer_connection| peer_connection.read_shutdown)
1308 .unwrap_or(true)
1309 {
1310 events |= POLLERR;
1311 } else {
1312 events |= POLLOUT;
1313 }
1314 }
1315 }
1316 SocketState::Bound if supports_inet_datagram_lifecycle(record.spec) => {
1317 let datagram_state = record.datagram_state.as_ref().ok_or_else(|| {
1318 SocketTableError::invalid_argument(format!(
1319 "socket {socket_id} does not support datagrams"
1320 ))
1321 })?;
1322 if requested.intersects(POLLIN) && !datagram_state.recv_queue.is_empty() {
1323 events |= POLLIN;
1324 }
1325 if requested.intersects(POLLOUT) {
1326 events |= POLLOUT;
1327 }
1328 }
1329 SocketState::Created | SocketState::Bound => {}
1330 }
1331
1332 Ok(events)
1333 }
1334
1335 pub fn write(&self, socket_id: SocketId, data: &[u8]) -> SocketResult<usize> {
1336 let mut table = lock_or_recover(&self.inner.state);
1337 let record = table
1338 .sockets
1339 .get(&socket_id)
1340 .cloned()
1341 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1342 let connection = record.connection_state.as_ref().ok_or_else(|| {
1343 SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1344 })?;
1345 if record.state != SocketState::Connected {
1346 return Err(SocketTableError::not_connected(format!(
1347 "socket {socket_id} is not connected"
1348 )));
1349 }
1350 if connection.write_shutdown {
1351 return Err(SocketTableError::broken_pipe(format!(
1352 "socket {socket_id} write side is shut down"
1353 )));
1354 }
1355
1356 let peer_socket_id = connection.peer_socket_id.ok_or_else(|| {
1357 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1358 })?;
1359 let peer = table.sockets.get_mut(&peer_socket_id).ok_or_else(|| {
1360 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1361 })?;
1362 let peer_connection = peer.connection_state.as_mut().ok_or_else(|| {
1363 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1364 })?;
1365 if peer_connection.read_shutdown {
1366 return Err(SocketTableError::broken_pipe(format!(
1367 "socket {peer_socket_id} read side is shut down"
1368 )));
1369 }
1370
1371 peer_connection.recv_buffer.extend(data.iter().copied());
1372 Ok(data.len())
1373 }
1374
1375 pub fn check_write(&self, socket_id: SocketId) -> SocketResult<()> {
1376 let table = lock_or_recover(&self.inner.state);
1377 let record = table
1378 .sockets
1379 .get(&socket_id)
1380 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1381 let connection = record.connection_state.as_ref().ok_or_else(|| {
1382 SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1383 })?;
1384 if record.state != SocketState::Connected {
1385 return Err(SocketTableError::not_connected(format!(
1386 "socket {socket_id} is not connected"
1387 )));
1388 }
1389 if connection.write_shutdown {
1390 return Err(SocketTableError::broken_pipe(format!(
1391 "socket {socket_id} write side is shut down"
1392 )));
1393 }
1394
1395 let peer_socket_id = connection.peer_socket_id.ok_or_else(|| {
1396 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1397 })?;
1398 let peer = table.sockets.get(&peer_socket_id).ok_or_else(|| {
1399 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1400 })?;
1401 let peer_connection = peer.connection_state.as_ref().ok_or_else(|| {
1402 SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1403 })?;
1404 if peer_connection.read_shutdown {
1405 return Err(SocketTableError::broken_pipe(format!(
1406 "socket {peer_socket_id} read side is shut down"
1407 )));
1408 }
1409
1410 Ok(())
1411 }
1412
1413 pub fn read(&self, socket_id: SocketId, max_bytes: usize) -> SocketResult<Option<Vec<u8>>> {
1414 if max_bytes == 0 {
1415 return Ok(Some(Vec::new()));
1416 }
1417
1418 let mut table = lock_or_recover(&self.inner.state);
1419 let record = table
1420 .sockets
1421 .get(&socket_id)
1422 .cloned()
1423 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1424 if record.state != SocketState::Connected {
1425 return Err(SocketTableError::not_connected(format!(
1426 "socket {socket_id} is not connected"
1427 )));
1428 }
1429
1430 let connection = record.connection_state.as_ref().ok_or_else(|| {
1431 SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1432 })?;
1433 if connection.read_shutdown {
1434 return Ok(None);
1435 }
1436 if !connection.recv_buffer.is_empty() {
1437 let record = table
1438 .sockets
1439 .get_mut(&socket_id)
1440 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1441 let connection = record.connection_state.as_mut().ok_or_else(|| {
1442 SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1443 })?;
1444 let read_len = connection.recv_buffer.len().min(max_bytes);
1445 let bytes = connection.recv_buffer.drain(..read_len).collect::<Vec<_>>();
1446 return Ok(Some(bytes));
1447 }
1448
1449 let peer_open = connection
1450 .peer_socket_id
1451 .map(|peer_socket_id| table.sockets.contains_key(&peer_socket_id))
1452 .unwrap_or(false);
1453 if connection.peer_write_shutdown || !peer_open {
1454 return Ok(None);
1455 }
1456
1457 Err(SocketTableError::would_block(format!(
1458 "socket {socket_id} has no readable data"
1459 )))
1460 }
1461
1462 pub fn shutdown(&self, socket_id: SocketId, how: SocketShutdown) -> SocketResult<SocketRecord> {
1463 let mut table = lock_or_recover(&self.inner.state);
1464 let record = table
1465 .sockets
1466 .remove(&socket_id)
1467 .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1468
1469 if record.state != SocketState::Connected {
1470 table.sockets.insert(socket_id, record);
1471 return Err(SocketTableError::not_connected(format!(
1472 "socket {socket_id} is not connected"
1473 )));
1474 }
1475
1476 let Some(mut connection) = record.connection_state.clone() else {
1477 table.sockets.insert(socket_id, record);
1478 return Err(SocketTableError::not_connected(format!(
1479 "socket {socket_id} is not connected"
1480 )));
1481 };
1482
1483 if matches!(how, SocketShutdown::Read | SocketShutdown::Both) {
1484 connection.recv_buffer.clear();
1485 connection.read_shutdown = true;
1486 }
1487 if matches!(how, SocketShutdown::Write | SocketShutdown::Both) {
1488 connection.write_shutdown = true;
1489 if let Some(peer_socket_id) = connection.peer_socket_id {
1490 if let Some(peer) = table.sockets.get_mut(&peer_socket_id) {
1491 if let Some(peer_connection) = peer.connection_state.as_mut() {
1492 peer_connection.peer_write_shutdown = true;
1493 }
1494 }
1495 }
1496 }
1497
1498 let mut record = record;
1499 record.connection_state = Some(connection);
1500 let cloned = record.clone();
1501 table.sockets.insert(socket_id, record);
1502 Ok(cloned)
1503 }
1504
1505 pub fn remove(&self, socket_id: SocketId) -> SocketResult<SocketRecord> {
1506 let mut table = lock_or_recover(&self.inner.state);
1507 remove_socket(&mut table, socket_id).ok_or_else(|| SocketTableError::not_found(socket_id))
1508 }
1509
1510 pub fn remove_all_for_pid(&self, owner_pid: u32) -> Vec<SocketRecord> {
1511 let mut table = lock_or_recover(&self.inner.state);
1512 let Some(socket_ids) = table.by_owner.remove(&owner_pid) else {
1513 return Vec::new();
1514 };
1515
1516 socket_ids
1517 .into_iter()
1518 .filter_map(|socket_id| remove_socket(&mut table, socket_id))
1519 .collect()
1520 }
1521
1522 pub fn snapshot(&self) -> SocketTableSnapshot {
1523 let table = lock_or_recover(&self.inner.state);
1524 let mut snapshot = SocketTableSnapshot {
1525 sockets: table.sockets.len(),
1526 ..SocketTableSnapshot::default()
1527 };
1528 for record in table.sockets.values() {
1529 if record.state.counts_as_listener() {
1530 snapshot.listeners += 1;
1531 }
1532 if record.state.counts_as_connection() {
1533 snapshot.connections += 1;
1534 }
1535 if let Some(connection) = &record.connection_state {
1536 snapshot.buffered_bytes = snapshot
1537 .buffered_bytes
1538 .saturating_add(connection.recv_buffer.len());
1539 }
1540 if let Some(datagram_state) = &record.datagram_state {
1541 snapshot.datagram_queue_len = snapshot
1542 .datagram_queue_len
1543 .saturating_add(datagram_state.recv_queue.len());
1544 snapshot.buffered_bytes = snapshot
1545 .buffered_bytes
1546 .saturating_add(datagram_queue_bytes(&datagram_state.recv_queue));
1547 }
1548 }
1549 snapshot
1550 }
1551}
1552
1553fn datagram_queue_bytes(queue: &VecDeque<QueuedDatagram>) -> usize {
1554 queue
1555 .iter()
1556 .map(|datagram| datagram.payload.len())
1557 .sum::<usize>()
1558}
1559
1560fn next_socket_id(table: &mut SocketTableState) -> SocketId {
1561 if table.next_socket_id == 0 {
1562 table.next_socket_id = 1;
1563 }
1564 let socket_id = table.next_socket_id;
1565 table.next_socket_id = table.next_socket_id.saturating_add(1);
1566 socket_id
1567}
1568
1569fn validate_state_transition(current: SocketState, next: SocketState) -> SocketResult<()> {
1570 if current == SocketState::Connected && next != SocketState::Connected {
1571 return Err(SocketTableError::invalid_argument(format!(
1572 "invalid socket state transition from {current:?} to {next:?}"
1573 )));
1574 }
1575 Ok(())
1576}
1577
1578fn validate_connect_pair(socket: &SocketRecord, peer: &SocketRecord) -> SocketResult<()> {
1579 if !supports_connection_lifecycle(socket.spec) {
1580 return Err(SocketTableError::invalid_argument(format!(
1581 "socket {} does not support stream connections",
1582 socket.id
1583 )));
1584 }
1585 if !supports_connection_lifecycle(peer.spec) {
1586 return Err(SocketTableError::invalid_argument(format!(
1587 "socket {} does not support stream connections",
1588 peer.id
1589 )));
1590 }
1591 if !matches!(socket.state, SocketState::Created | SocketState::Bound) {
1592 return Err(SocketTableError::invalid_argument(format!(
1593 "socket {} cannot connect in state {:?}",
1594 socket.id, socket.state
1595 )));
1596 }
1597 if !matches!(peer.state, SocketState::Created | SocketState::Bound) {
1598 return Err(SocketTableError::invalid_argument(format!(
1599 "socket {} cannot connect in state {:?}",
1600 peer.id, peer.state
1601 )));
1602 }
1603 Ok(())
1604}
1605
1606fn default_connection_state(spec: SocketSpec, state: SocketState) -> Option<ConnectionState> {
1607 if state == SocketState::Connected && supports_connection_lifecycle(spec) {
1608 Some(ConnectionState::default())
1609 } else {
1610 None
1611 }
1612}
1613
1614fn default_datagram_state(spec: SocketSpec) -> Option<DatagramState> {
1615 if supports_inet_datagram_lifecycle(spec) {
1616 Some(DatagramState::default())
1617 } else {
1618 None
1619 }
1620}
1621
1622fn supports_connection_lifecycle(spec: SocketSpec) -> bool {
1623 matches!(spec.socket_type, SocketType::Stream)
1624}
1625
1626fn supports_listener_lifecycle(spec: SocketSpec) -> bool {
1627 matches!(spec.socket_type, SocketType::Stream)
1628 && matches!(
1629 spec.domain,
1630 SocketDomain::Inet | SocketDomain::Inet6 | SocketDomain::Unix
1631 )
1632}
1633
1634fn supports_inet_bind(spec: SocketSpec) -> bool {
1635 matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1636 && matches!(spec.socket_type, SocketType::Stream | SocketType::Datagram)
1637}
1638
1639fn supports_unix_stream_lifecycle(spec: SocketSpec) -> bool {
1640 matches!(spec.socket_type, SocketType::Stream) && matches!(spec.domain, SocketDomain::Unix)
1641}
1642
1643fn supports_inet_stream_lifecycle(spec: SocketSpec) -> bool {
1644 matches!(spec.socket_type, SocketType::Stream)
1645 && matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1646}
1647
1648fn supports_inet_datagram_lifecycle(spec: SocketSpec) -> bool {
1649 matches!(spec.socket_type, SocketType::Datagram)
1650 && matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1651}
1652
1653fn lookup_conflicting_bound_inet_socket_ids(
1654 table: &SocketTableState,
1655 spec: SocketSpec,
1656 address: &InetSocketAddress,
1657) -> Vec<SocketId> {
1658 if supports_inet_stream_lifecycle(spec) {
1659 table
1660 .bound_inet_streams
1661 .iter()
1662 .find_map(|(bound_address, socket_id)| {
1663 inet_stream_bind_addresses_overlap(bound_address, address).then_some(*socket_id)
1664 })
1665 .into_iter()
1666 .collect()
1667 } else if supports_inet_datagram_lifecycle(spec) {
1668 table
1669 .bound_inet_datagrams
1670 .iter()
1671 .filter(|(bound_address, _)| inet_stream_bind_addresses_overlap(bound_address, address))
1672 .flat_map(|(_, socket_ids)| socket_ids.iter().copied())
1673 .collect()
1674 } else {
1675 Vec::new()
1676 }
1677}
1678
1679fn lookup_bound_inet_socket(
1680 table: &SocketTableState,
1681 spec: SocketSpec,
1682 address: &InetSocketAddress,
1683) -> Option<SocketId> {
1684 if supports_inet_stream_lifecycle(spec) {
1685 lookup_bound_inet_socket_in_table(&table.bound_inet_streams, address)
1686 } else if supports_inet_datagram_lifecycle(spec) {
1687 lookup_bound_inet_datagram_socket_in_table(&table.bound_inet_datagrams, address)
1688 } else {
1689 None
1690 }
1691}
1692
1693fn inet_stream_bind_addresses_overlap(
1694 existing: &InetSocketAddress,
1695 requested: &InetSocketAddress,
1696) -> bool {
1697 if existing == requested {
1698 return true;
1699 }
1700
1701 wildcard_inet_address(existing).as_ref() == Some(requested)
1702 || wildcard_inet_address(requested).as_ref() == Some(existing)
1703}
1704
1705fn lookup_bound_inet_socket_in_table(
1706 sockets: &BTreeMap<InetSocketAddress, SocketId>,
1707 address: &InetSocketAddress,
1708) -> Option<SocketId> {
1709 sockets.get(address).copied().or_else(|| {
1710 wildcard_inet_address(address).and_then(|wildcard| sockets.get(&wildcard).copied())
1711 })
1712}
1713
1714fn lookup_bound_inet_datagram_socket_in_table(
1715 sockets: &BTreeMap<InetSocketAddress, BTreeSet<SocketId>>,
1716 address: &InetSocketAddress,
1717) -> Option<SocketId> {
1718 sockets
1719 .get(address)
1720 .and_then(|socket_ids| socket_ids.first().copied())
1721 .or_else(|| {
1722 wildcard_inet_address(address).and_then(|wildcard| {
1723 sockets
1724 .get(&wildcard)
1725 .and_then(|socket_ids| socket_ids.first().copied())
1726 })
1727 })
1728}
1729
1730fn register_bound_inet_socket(
1731 table: &mut SocketTableState,
1732 spec: SocketSpec,
1733 address: InetSocketAddress,
1734 socket_id: SocketId,
1735) {
1736 if supports_inet_stream_lifecycle(spec) {
1737 table.bound_inet_streams.insert(address, socket_id);
1738 } else if supports_inet_datagram_lifecycle(spec) {
1739 table
1740 .bound_inet_datagrams
1741 .entry(address)
1742 .or_default()
1743 .insert(socket_id);
1744 }
1745}
1746
1747fn validate_connect_to_listener(
1748 client: &SocketRecord,
1749 listener: &SocketRecord,
1750) -> SocketResult<()> {
1751 if !supports_connection_lifecycle(client.spec) {
1752 return Err(SocketTableError::invalid_argument(format!(
1753 "socket {} does not support stream connections",
1754 client.id
1755 )));
1756 }
1757 if !supports_listener_lifecycle(listener.spec) {
1758 return Err(SocketTableError::invalid_argument(format!(
1759 "socket {} is not a stream listener",
1760 listener.id
1761 )));
1762 }
1763 if !matches!(client.state, SocketState::Created | SocketState::Bound) {
1764 return Err(SocketTableError::invalid_argument(format!(
1765 "socket {} cannot connect in state {:?}",
1766 client.id, client.state
1767 )));
1768 }
1769 if listener.state != SocketState::Listening {
1770 return Err(SocketTableError::invalid_argument(format!(
1771 "socket {} is not listening",
1772 listener.id
1773 )));
1774 }
1775 Ok(())
1776}
1777
1778fn has_bound_endpoint(record: &SocketRecord) -> bool {
1779 record.local_address.is_some() || record.local_unix_path.is_some()
1780}
1781
1782fn validate_bound_udp_sender(sender: &SocketRecord) -> SocketResult<()> {
1783 if !supports_inet_datagram_lifecycle(sender.spec) {
1784 return Err(SocketTableError::invalid_argument(format!(
1785 "socket {} is not an INET datagram socket",
1786 sender.id
1787 )));
1788 }
1789 if sender.state != SocketState::Bound || sender.local_address.is_none() {
1790 return Err(SocketTableError::invalid_argument(format!(
1791 "socket {} must be bound before sending datagrams",
1792 sender.id
1793 )));
1794 }
1795 Ok(())
1796}
1797
1798fn validate_bound_udp_receiver(receiver: &SocketRecord) -> SocketResult<()> {
1799 if !supports_inet_datagram_lifecycle(receiver.spec) {
1800 return Err(SocketTableError::invalid_argument(format!(
1801 "socket {} is not an INET datagram socket",
1802 receiver.id
1803 )));
1804 }
1805 if receiver.state != SocketState::Bound || receiver.local_address.is_none() {
1806 return Err(SocketTableError::invalid_argument(format!(
1807 "socket {} must be bound to receive datagrams",
1808 receiver.id
1809 )));
1810 }
1811 Ok(())
1812}
1813
1814fn datagram_state_mut(record: &mut SocketRecord) -> SocketResult<&mut DatagramState> {
1815 if !supports_inet_datagram_lifecycle(record.spec) {
1816 return Err(SocketTableError::invalid_argument(format!(
1817 "socket {} is not an INET datagram socket",
1818 record.id
1819 )));
1820 }
1821 record.datagram_state.as_mut().ok_or_else(|| {
1822 SocketTableError::invalid_argument(format!(
1823 "socket {} does not support datagrams",
1824 record.id
1825 ))
1826 })
1827}
1828
1829fn validate_multicast_socket(record: &SocketRecord) -> SocketResult<()> {
1830 validate_bound_udp_receiver(record)?;
1831 if record.spec.domain != SocketDomain::Inet {
1832 return Err(SocketTableError::invalid_argument(format!(
1833 "socket {} multicast membership is only implemented for IPv4 datagrams",
1834 record.id
1835 )));
1836 }
1837 Ok(())
1838}
1839
1840fn normalize_multicast_membership(
1841 spec: SocketSpec,
1842 membership: SocketMulticastMembership,
1843) -> SocketResult<SocketMulticastMembership> {
1844 let group_address = membership.group_address.trim().to_ascii_lowercase();
1845 let interface_address = membership
1846 .interface_address
1847 .map(|value| value.trim().to_ascii_lowercase())
1848 .filter(|value| !value.is_empty());
1849
1850 match spec.domain {
1851 SocketDomain::Inet => {
1852 let parsed = group_address.parse::<Ipv4Addr>().map_err(|_| {
1853 SocketTableError::invalid_argument(format!(
1854 "invalid IPv4 multicast address {group_address}"
1855 ))
1856 })?;
1857 if !parsed.is_multicast() {
1858 return Err(SocketTableError::invalid_argument(format!(
1859 "address {group_address} is not an IPv4 multicast group"
1860 )));
1861 }
1862 }
1863 SocketDomain::Inet6 => {
1864 let parsed = group_address.parse::<Ipv6Addr>().map_err(|_| {
1865 SocketTableError::invalid_argument(format!(
1866 "invalid IPv6 multicast address {group_address}"
1867 ))
1868 })?;
1869 if !parsed.is_multicast() {
1870 return Err(SocketTableError::invalid_argument(format!(
1871 "address {group_address} is not an IPv6 multicast group"
1872 )));
1873 }
1874 }
1875 SocketDomain::Unix => {
1876 return Err(SocketTableError::invalid_argument(
1877 "unix sockets do not support multicast membership",
1878 ));
1879 }
1880 }
1881
1882 Ok(SocketMulticastMembership::new(
1883 group_address,
1884 interface_address,
1885 ))
1886}
1887
1888fn has_incompatible_inet_bind_conflict(
1889 table: &SocketTableState,
1890 record: &SocketRecord,
1891 conflicting_ids: &[SocketId],
1892) -> bool {
1893 conflicting_ids.iter().any(|conflicting_id| {
1894 if *conflicting_id == record.id {
1895 return false;
1896 }
1897
1898 let Some(existing) = table.sockets.get(conflicting_id) else {
1899 return false;
1900 };
1901
1902 if supports_inet_datagram_lifecycle(record.spec) {
1903 !inet_datagram_bind_shares_port(record, existing)
1904 } else {
1905 true
1906 }
1907 })
1908}
1909
1910fn inet_datagram_bind_shares_port(requested: &SocketRecord, existing: &SocketRecord) -> bool {
1911 (requested.reuse_port() && existing.reuse_port())
1912 || (requested.reuse_address() && existing.reuse_address())
1913}
1914
1915fn remove_socket(table: &mut SocketTableState, socket_id: SocketId) -> Option<SocketRecord> {
1916 let record = table.sockets.remove(&socket_id)?;
1917 unregister_bound_socket(table, &record);
1918 unregister_multicast_memberships(table, &record);
1919 if let Some(listener_state) = record.listener_state.as_ref() {
1920 let pending_socket_ids = listener_state
1921 .pending_accepts
1922 .iter()
1923 .filter_map(|pending| pending.accepted_socket_id)
1924 .collect::<Vec<_>>();
1925 for pending_socket_id in pending_socket_ids {
1926 let _ = remove_socket(table, pending_socket_id);
1927 }
1928 }
1929 if let Some(connection) = record.connection_state.as_ref() {
1930 if let Some(peer_socket_id) = connection.peer_socket_id {
1931 if let Some(peer) = table.sockets.get_mut(&peer_socket_id) {
1932 if let Some(peer_connection) = peer.connection_state.as_mut() {
1933 if peer_connection.peer_socket_id == Some(socket_id) {
1934 peer_connection.peer_socket_id = None;
1935 }
1936 peer_connection.peer_write_shutdown = true;
1937 }
1938 }
1939 }
1940 }
1941 if let Some(owner_sockets) = table.by_owner.get_mut(&record.owner_pid) {
1942 owner_sockets.remove(&socket_id);
1943 if owner_sockets.is_empty() {
1944 table.by_owner.remove(&record.owner_pid);
1945 }
1946 }
1947 Some(record)
1948}
1949
1950fn unregister_bound_socket(table: &mut SocketTableState, record: &SocketRecord) {
1951 let Some(address) = record.local_address.as_ref() else {
1952 if supports_unix_stream_lifecycle(record.spec) {
1953 if let Some(path) = record.local_unix_path.as_ref() {
1954 if table.bound_unix_streams.get(path).copied() == Some(record.id) {
1955 table.bound_unix_streams.remove(path);
1956 }
1957 }
1958 }
1959 return;
1960 };
1961 if supports_inet_stream_lifecycle(record.spec)
1962 && table.bound_inet_streams.get(address).copied() == Some(record.id)
1963 {
1964 table.bound_inet_streams.remove(address);
1965 }
1966 if supports_inet_datagram_lifecycle(record.spec) {
1967 if let Some(socket_ids) = table.bound_inet_datagrams.get_mut(address) {
1968 socket_ids.remove(&record.id);
1969 if socket_ids.is_empty() {
1970 table.bound_inet_datagrams.remove(address);
1971 }
1972 }
1973 }
1974}
1975
1976fn unregister_multicast_memberships(table: &mut SocketTableState, record: &SocketRecord) {
1977 let Some(datagram_state) = record.datagram_state.as_ref() else {
1978 return;
1979 };
1980
1981 for membership in &datagram_state.multicast_memberships {
1982 if let Some(socket_ids) = table.multicast_groups.get_mut(membership) {
1983 socket_ids.remove(&record.id);
1984 if socket_ids.is_empty() {
1985 table.multicast_groups.remove(membership);
1986 }
1987 }
1988 }
1989}
1990
1991fn normalize_inet_address(address: InetSocketAddress) -> InetSocketAddress {
1992 match address.host().to_ascii_lowercase().as_str() {
1993 "localhost" => InetSocketAddress::new("127.0.0.1", address.port()),
1994 _ => address,
1995 }
1996}
1997
1998fn wildcard_inet_address(address: &InetSocketAddress) -> Option<InetSocketAddress> {
1999 match address.host() {
2000 "127.0.0.1" => Some(InetSocketAddress::new("0.0.0.0", address.port())),
2001 "::1" => Some(InetSocketAddress::new("::", address.port())),
2002 _ => None,
2003 }
2004}
2005
2006fn normalize_unix_socket_path(path: impl AsRef<str>) -> SocketResult<String> {
2007 let normalized = normalize_path(path.as_ref());
2008 if normalized == "/" {
2009 return Err(SocketTableError::invalid_argument(
2010 "unix socket path must not be empty or root",
2011 ));
2012 }
2013 Ok(normalized)
2014}
2015
/// Locks `mutex`, handing back the guard even when the mutex is poisoned.
fn lock_or_recover<'a, T>(mutex: &'a Mutex<T>) -> MutexGuard<'a, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
2022
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the id the table would issue next, without allocating it.
    fn peek_next_socket_id(table: &SocketTable) -> SocketId {
        let state = lock_or_recover(&table.inner.state);
        state.next_socket_id
    }

    /// A Unix connect rejected because the backlog is full must leave the id counter
    /// where it was.
    #[test]
    fn full_backlog_unix_connect_does_not_consume_socket_id() {
        let table = SocketTable::new();
        let path = "/tmp/leak-test/server.sock";

        let listener_id = table.allocate(1, SocketSpec::unix_stream()).id;
        table
            .bind_unix(listener_id, path)
            .expect("bind unix listener");
        table.listen(listener_id, 1).expect("listen with backlog 1");

        let first_client_id = table.allocate(2, SocketSpec::unix_stream()).id;
        table
            .connect_to_bound_unix_stream(first_client_id, path)
            .expect("first connect fills the backlog");

        let second_client_id = table.allocate(2, SocketSpec::unix_stream()).id;
        let before = peek_next_socket_id(&table);
        let error = table
            .connect_to_bound_unix_stream(second_client_id, path)
            .expect_err("full-backlog connect must fail");
        assert_eq!(error.code(), "EAGAIN");
        let after = peek_next_socket_id(&table);

        assert_eq!(
            before, after,
            "full-backlog unix connect leaked a socket id (counter advanced from {before} to {after})"
        );
    }

    /// An INET connect rejected because the backlog is full must leave the id counter
    /// where it was.
    #[test]
    fn full_backlog_inet_connect_does_not_consume_socket_id() {
        let table = SocketTable::new();
        let target = InetSocketAddress::new("127.0.0.1", 49222);

        let listener_id = table.allocate(1, SocketSpec::tcp()).id;
        table
            .bind_inet(listener_id, target.clone())
            .expect("bind inet listener");
        table.listen(listener_id, 1).expect("listen with backlog 1");

        let first_client_id = table.allocate(2, SocketSpec::tcp()).id;
        table
            .connect_to_bound_inet_stream(first_client_id, target.clone())
            .expect("first connect fills the backlog");

        let second_client_id = table.allocate(2, SocketSpec::tcp()).id;
        let before = peek_next_socket_id(&table);
        let error = table
            .connect_to_bound_inet_stream(second_client_id, target)
            .expect_err("full-backlog connect must fail");
        assert_eq!(error.code(), "EAGAIN");
        let after = peek_next_socket_id(&table);

        assert_eq!(
            before, after,
            "full-backlog inet connect leaked a socket id (counter advanced from {before} to {after})"
        );
    }
}