Skip to main content

secure_exec_kernel/
socket_table.rs

1use crate::poll::{PollEvents, POLLERR, POLLHUP, POLLIN, POLLOUT};
2use crate::vfs::normalize_path;
3use std::collections::{BTreeMap, BTreeSet, VecDeque};
4use std::error::Error;
5use std::fmt;
6use std::net::{Ipv4Addr, Ipv6Addr};
7use std::sync::{Arc, Mutex, MutexGuard};
8
/// Identifier under which a socket record is stored in the socket table.
pub type SocketId = u64;
/// Result alias for socket-table operations; failures carry a `SocketTableError`.
pub type SocketResult<T> = Result<T, SocketTableError>;
11
/// Host/port pair naming an INET socket endpoint.
///
/// The derived ordering compares `host` first and `port` second, which lets
/// the address be used as an ordered-map key.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct InetSocketAddress {
    host: String,
    port: u16,
}

impl InetSocketAddress {
    /// Builds an address from anything convertible into an owned host string.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host: String = host.into();
        InetSocketAddress { host, port }
    }

    /// Borrows the host component exactly as stored.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Returns the port component.
    pub const fn port(&self) -> u16 {
        self.port
    }
}
34
/// Address family of a socket.
///
/// The derived `Ord` follows declaration order (`Inet < Inet6 < Unix`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketDomain {
    /// Domain selected by the `SocketSpec::tcp` and `SocketSpec::udp` presets.
    Inet,
    /// NOTE(review): presumably the IPv6 counterpart of `Inet` — confirm against callers.
    Inet6,
    /// Domain selected by the `SocketSpec::unix_stream` and `SocketSpec::unix_datagram` presets.
    Unix,
}
41
/// Transfer semantics of a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketType {
    /// Type selected by the `SocketSpec::tcp` and `SocketSpec::unix_stream` presets.
    Stream,
    /// Type selected by the `SocketSpec::udp` and `SocketSpec::unix_datagram` presets.
    Datagram,
}
47
/// Lifecycle stage recorded for a socket.
///
/// The derived `Ord` follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SocketState {
    Created,
    Bound,
    Listening,
    Connected,
}

impl SocketState {
    /// Returns `true` only for [`SocketState::Listening`].
    pub const fn counts_as_listener(self) -> bool {
        match self {
            Self::Listening => true,
            Self::Created | Self::Bound | Self::Connected => false,
        }
    }

    /// Returns `true` only for [`SocketState::Connected`].
    pub const fn counts_as_connection(self) -> bool {
        match self {
            Self::Connected => true,
            Self::Created | Self::Bound | Self::Listening => false,
        }
    }
}
65
/// Selects which direction(s) of a socket a shutdown request applies to.
///
/// NOTE(review): the code consuming this enum is outside the visible part of
/// the file; the variant names suggest the usual read / write / both split.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketShutdown {
    Read,
    Write,
    Both,
}
72
/// Boolean options that `SocketTable::set_datagram_socket_option` can toggle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatagramSocketOption {
    /// Toggles the `reuse_addr` flag of the socket's datagram state.
    ReuseAddr,
    /// Toggles the `reuse_port` flag of the socket's datagram state.
    ReusePort,
    /// Toggles the `broadcast` flag of the socket's datagram state.
    Broadcast,
}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq)]
81pub struct SocketSpec {
82    pub domain: SocketDomain,
83    pub socket_type: SocketType,
84}
85
86impl SocketSpec {
87    pub const fn new(domain: SocketDomain, socket_type: SocketType) -> Self {
88        Self {
89            domain,
90            socket_type,
91        }
92    }
93
94    pub const fn tcp() -> Self {
95        Self::new(SocketDomain::Inet, SocketType::Stream)
96    }
97
98    pub const fn udp() -> Self {
99        Self::new(SocketDomain::Inet, SocketType::Datagram)
100    }
101
102    pub const fn unix_stream() -> Self {
103        Self::new(SocketDomain::Unix, SocketType::Stream)
104    }
105
106    pub const fn unix_datagram() -> Self {
107        Self::new(SocketDomain::Unix, SocketType::Datagram)
108    }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq)]
112pub struct SocketRecord {
113    id: SocketId,
114    owner_pid: u32,
115    spec: SocketSpec,
116    state: SocketState,
117    local_address: Option<InetSocketAddress>,
118    peer_address: Option<InetSocketAddress>,
119    local_unix_path: Option<String>,
120    peer_unix_path: Option<String>,
121    listener_state: Option<ListenerState>,
122    connection_state: Option<ConnectionState>,
123    datagram_state: Option<DatagramState>,
124}
125
126impl SocketRecord {
127    pub const fn id(&self) -> SocketId {
128        self.id
129    }
130
131    pub const fn owner_pid(&self) -> u32 {
132        self.owner_pid
133    }
134
135    pub const fn spec(&self) -> SocketSpec {
136        self.spec
137    }
138
139    pub const fn state(&self) -> SocketState {
140        self.state
141    }
142
143    pub fn local_address(&self) -> Option<&InetSocketAddress> {
144        self.local_address.as_ref()
145    }
146
147    pub fn peer_address(&self) -> Option<&InetSocketAddress> {
148        self.peer_address.as_ref()
149    }
150
151    pub fn local_unix_path(&self) -> Option<&str> {
152        self.local_unix_path.as_deref()
153    }
154
155    pub fn peer_unix_path(&self) -> Option<&str> {
156        self.peer_unix_path.as_deref()
157    }
158
159    pub fn listen_backlog(&self) -> Option<usize> {
160        self.listener_state.as_ref().map(|state| state.backlog)
161    }
162
163    pub fn pending_accept_count(&self) -> usize {
164        self.listener_state
165            .as_ref()
166            .map(|state| state.pending_accepts.len())
167            .unwrap_or(0)
168    }
169
170    pub fn peer_socket_id(&self) -> Option<SocketId> {
171        self.connection_state
172            .as_ref()
173            .and_then(|state| state.peer_socket_id)
174    }
175
176    pub fn buffered_read_bytes(&self) -> usize {
177        self.connection_state
178            .as_ref()
179            .map(|state| state.recv_buffer.len())
180            .unwrap_or(0)
181    }
182
183    pub fn read_shutdown(&self) -> bool {
184        self.connection_state
185            .as_ref()
186            .map(|state| state.read_shutdown)
187            .unwrap_or(false)
188    }
189
190    pub fn write_shutdown(&self) -> bool {
191        self.connection_state
192            .as_ref()
193            .map(|state| state.write_shutdown)
194            .unwrap_or(false)
195    }
196
197    pub fn peer_write_shutdown(&self) -> bool {
198        self.connection_state
199            .as_ref()
200            .map(|state| state.peer_write_shutdown)
201            .unwrap_or(false)
202    }
203
204    pub fn queued_datagrams(&self) -> usize {
205        self.datagram_state
206            .as_ref()
207            .map(|state| state.recv_queue.len())
208            .unwrap_or(0)
209    }
210
211    pub fn queued_datagram_bytes(&self) -> usize {
212        self.datagram_state
213            .as_ref()
214            .map(|state| datagram_queue_bytes(&state.recv_queue))
215            .unwrap_or(0)
216    }
217
218    pub fn reuse_address(&self) -> bool {
219        self.datagram_state
220            .as_ref()
221            .map(|state| state.reuse_addr)
222            .unwrap_or(false)
223    }
224
225    pub fn reuse_port(&self) -> bool {
226        self.datagram_state
227            .as_ref()
228            .map(|state| state.reuse_port)
229            .unwrap_or(false)
230    }
231
232    pub fn broadcast_enabled(&self) -> bool {
233        self.datagram_state
234            .as_ref()
235            .map(|state| state.broadcast)
236            .unwrap_or(false)
237    }
238
239    pub fn multicast_membership_count(&self) -> usize {
240        self.datagram_state
241            .as_ref()
242            .map(|state| state.multicast_memberships.len())
243            .unwrap_or(0)
244    }
245
246    pub fn has_multicast_membership(&self, membership: &SocketMulticastMembership) -> bool {
247        self.datagram_state
248            .as_ref()
249            .map(|state| state.multicast_memberships.contains(membership))
250            .unwrap_or(false)
251    }
252}
253
254#[derive(Debug, Clone, PartialEq, Eq)]
255pub struct ReceivedDatagram {
256    source_address: Option<InetSocketAddress>,
257    payload: Vec<u8>,
258}
259
260impl ReceivedDatagram {
261    pub fn source_address(&self) -> Option<&InetSocketAddress> {
262        self.source_address.as_ref()
263    }
264
265    pub fn payload(&self) -> &[u8] {
266        &self.payload
267    }
268
269    pub fn into_parts(self) -> (Option<InetSocketAddress>, Vec<u8>) {
270        (self.source_address, self.payload)
271    }
272}
273
/// Aggregate counters describing a socket table.
///
/// NOTE(review): the code that fills this struct is outside the visible part
/// of the file; the field notes below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SocketTableSnapshot {
    /// Presumably the total number of sockets.
    pub sockets: usize,
    /// Presumably the number of sockets whose state counts as a listener.
    pub listeners: usize,
    /// Presumably the number of sockets whose state counts as a connection.
    pub connections: usize,
    /// Presumably the total number of buffered bytes.
    pub buffered_bytes: usize,
    /// Presumably the total number of queued datagrams.
    pub datagram_queue_len: usize,
}
282
/// A multicast group address plus the optional interface it is joined on.
///
/// The derived ordering compares `group_address` first and then
/// `interface_address`, so the value can be used as an ordered-map or set key.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SocketMulticastMembership {
    group_address: String,
    interface_address: Option<String>,
}

impl SocketMulticastMembership {
    /// Builds a membership; `interface_address` may be `None`.
    pub fn new(group_address: impl Into<String>, interface_address: Option<String>) -> Self {
        let group_address: String = group_address.into();
        SocketMulticastMembership {
            group_address,
            interface_address,
        }
    }

    /// Borrows the group address exactly as stored.
    pub fn group_address(&self) -> &str {
        self.group_address.as_str()
    }

    /// Borrows the interface address, when one was given.
    pub fn interface_address(&self) -> Option<&str> {
        let interface = self.interface_address.as_ref()?;
        Some(interface.as_str())
    }
}
305
306#[derive(Debug, Clone, PartialEq, Eq)]
307pub struct SocketTableError {
308    code: &'static str,
309    message: String,
310}
311
312impl SocketTableError {
313    pub fn code(&self) -> &'static str {
314        self.code
315    }
316
317    fn not_found(socket_id: SocketId) -> Self {
318        Self {
319            code: "ENOENT",
320            message: format!("no such socket {socket_id}"),
321        }
322    }
323
324    fn invalid_argument(message: impl Into<String>) -> Self {
325        Self {
326            code: "EINVAL",
327            message: message.into(),
328        }
329    }
330
331    fn address_in_use(message: impl Into<String>) -> Self {
332        Self {
333            code: "EADDRINUSE",
334            message: message.into(),
335        }
336    }
337
338    fn address_not_available(message: impl Into<String>) -> Self {
339        Self {
340            code: "EADDRNOTAVAIL",
341            message: message.into(),
342        }
343    }
344
345    fn not_found_address(message: impl Into<String>) -> Self {
346        Self {
347            code: "ECONNREFUSED",
348            message: message.into(),
349        }
350    }
351
352    fn would_block(message: impl Into<String>) -> Self {
353        Self {
354            code: "EAGAIN",
355            message: message.into(),
356        }
357    }
358
359    fn not_connected(message: impl Into<String>) -> Self {
360        Self {
361            code: "ENOTCONN",
362            message: message.into(),
363        }
364    }
365
366    fn broken_pipe(message: impl Into<String>) -> Self {
367        Self {
368            code: "EPIPE",
369            message: message.into(),
370        }
371    }
372}
373
374impl fmt::Display for SocketTableError {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        write!(f, "{}: {}", self.code, self.message)
377    }
378}
379
380impl Error for SocketTableError {}
381
/// Mutable bookkeeping behind a `SocketTable`; `SocketTableInner` wraps it in a `Mutex`.
#[derive(Debug, Default)]
struct SocketTableState {
    /// Socket records keyed by id.
    sockets: BTreeMap<SocketId, SocketRecord>,
    /// Socket ids grouped by owning pid; extended on allocation and on accept.
    by_owner: BTreeMap<u32, BTreeSet<SocketId>>,
    /// INET stream endpoint -> socket id; consulted when connecting to a bound stream address.
    bound_inet_streams: BTreeMap<InetSocketAddress, SocketId>,
    /// INET datagram endpoint -> socket ids.
    /// NOTE(review): maintained by helpers outside the visible part of the file.
    bound_inet_datagrams: BTreeMap<InetSocketAddress, BTreeSet<SocketId>>,
    /// Unix stream path -> socket id; written by `bind_unix`.
    bound_unix_streams: BTreeMap<String, SocketId>,
    /// Multicast membership -> ids of the sockets that joined it.
    multicast_groups: BTreeMap<SocketMulticastMembership, BTreeSet<SocketId>>,
    /// Counter used by the `next_socket_id` helper when handing out ids.
    next_socket_id: SocketId,
}
392
/// Extra state carried by a listening socket.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ListenerState {
    /// Maximum number of entries `pending_accepts` may hold before new connections are refused.
    backlog: usize,
    /// Connections waiting to be handed out by `accept`, oldest first.
    pending_accepts: VecDeque<PendingConnection>,
}
398
/// Extra state carried by a connected socket.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct ConnectionState {
    /// Id of the socket at the other end, when the peer also lives in this table.
    peer_socket_id: Option<SocketId>,
    /// Received bytes not yet consumed; its length is reported by `buffered_read_bytes`.
    recv_buffer: VecDeque<u8>,
    /// NOTE(review): presumably set once the read side has been shut down — confirm.
    read_shutdown: bool,
    /// NOTE(review): presumably set once the write side has been shut down — confirm.
    write_shutdown: bool,
    /// NOTE(review): presumably set once the peer has shut down its write side — confirm.
    peer_write_shutdown: bool,
}
407
/// One entry of a listener's accept queue.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PendingConnection {
    /// INET address of the connecting peer, if any.
    peer_address: Option<InetSocketAddress>,
    /// Unix path of the connecting peer, if any.
    peer_unix_path: Option<String>,
    /// When set, `accept` returns this already-existing socket instead of allocating a new one.
    accepted_socket_id: Option<SocketId>,
}
414
/// Extra state carried by a datagram socket.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct DatagramState {
    /// Datagrams received but not yet consumed.
    recv_queue: VecDeque<QueuedDatagram>,
    /// Flag toggled by `DatagramSocketOption::ReuseAddr`.
    reuse_addr: bool,
    /// Flag toggled by `DatagramSocketOption::ReusePort`.
    reuse_port: bool,
    /// Flag toggled by `DatagramSocketOption::Broadcast`.
    broadcast: bool,
    /// Memberships this socket has joined; updated by the add/drop multicast operations.
    multicast_memberships: BTreeSet<SocketMulticastMembership>,
}
423
/// A datagram stored in a socket's receive queue.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueuedDatagram {
    /// Address the datagram came from, if recorded.
    source_address: Option<InetSocketAddress>,
    /// Raw payload bytes.
    payload: Vec<u8>,
}
429
/// Shared interior of a `SocketTable`: the table state behind a mutex.
#[derive(Debug, Default)]
struct SocketTableInner {
    /// Every table operation locks this before reading or mutating state.
    state: Mutex<SocketTableState>,
}
434
/// Handle to a socket table.
///
/// Cloning the handle clones the `Arc`, not the state, so all clones operate
/// on the same underlying table.
#[derive(Debug, Clone, Default)]
pub struct SocketTable {
    inner: Arc<SocketTableInner>,
}
439
440impl SocketTable {
441    pub fn new() -> Self {
442        Self::default()
443    }
444
445    pub fn allocate(&self, owner_pid: u32, spec: SocketSpec) -> SocketRecord {
446        self.allocate_with_state(owner_pid, spec, SocketState::Created)
447    }
448
449    pub fn allocate_with_state(
450        &self,
451        owner_pid: u32,
452        spec: SocketSpec,
453        state: SocketState,
454    ) -> SocketRecord {
455        let mut table = lock_or_recover(&self.inner.state);
456        let socket_id = next_socket_id(&mut table);
457        let record = SocketRecord {
458            id: socket_id,
459            owner_pid,
460            spec,
461            state,
462            local_address: None,
463            peer_address: None,
464            local_unix_path: None,
465            peer_unix_path: None,
466            listener_state: None,
467            connection_state: default_connection_state(spec, state),
468            datagram_state: default_datagram_state(spec),
469        };
470        table.sockets.insert(socket_id, record.clone());
471        table
472            .by_owner
473            .entry(owner_pid)
474            .or_default()
475            .insert(socket_id);
476        record
477    }
478
479    pub fn get(&self, socket_id: SocketId) -> Option<SocketRecord> {
480        lock_or_recover(&self.inner.state)
481            .sockets
482            .get(&socket_id)
483            .cloned()
484    }
485
486    pub fn update_state(
487        &self,
488        socket_id: SocketId,
489        new_state: SocketState,
490    ) -> SocketResult<SocketRecord> {
491        let mut table = lock_or_recover(&self.inner.state);
492        let record = table
493            .sockets
494            .get_mut(&socket_id)
495            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
496        validate_state_transition(record.state, new_state)?;
497        record.state = new_state;
498        if new_state != SocketState::Listening {
499            record.listener_state = None;
500        }
501        if new_state == SocketState::Connected && supports_connection_lifecycle(record.spec) {
502            record
503                .connection_state
504                .get_or_insert_with(ConnectionState::default);
505        } else if new_state != SocketState::Connected {
506            record.connection_state = None;
507        }
508        Ok(record.clone())
509    }
510
511    pub fn bind_inet(
512        &self,
513        socket_id: SocketId,
514        address: InetSocketAddress,
515    ) -> SocketResult<SocketRecord> {
516        let address = normalize_inet_address(address);
517        let mut table = lock_or_recover(&self.inner.state);
518        let existing = table
519            .sockets
520            .get(&socket_id)
521            .cloned()
522            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
523        if !supports_inet_bind(existing.spec) {
524            return Err(SocketTableError::invalid_argument(format!(
525                "socket {socket_id} is not an INET socket"
526            )));
527        }
528        let conflicting_ids =
529            lookup_conflicting_bound_inet_socket_ids(&table, existing.spec, &address);
530        if has_incompatible_inet_bind_conflict(&table, &existing, &conflicting_ids) {
531            return Err(SocketTableError::address_in_use(format!(
532                "address {}:{} is already bound",
533                address.host(),
534                address.port()
535            )));
536        }
537        let cloned = {
538            let record = table
539                .sockets
540                .get_mut(&socket_id)
541                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
542
543            match record.state {
544                SocketState::Created => {}
545                SocketState::Bound if record.local_address.as_ref() == Some(&address) => {
546                    return Ok(record.clone());
547                }
548                SocketState::Bound | SocketState::Listening | SocketState::Connected => {
549                    return Err(SocketTableError::invalid_argument(format!(
550                        "socket {socket_id} cannot bind in state {:?}",
551                        record.state
552                    )));
553                }
554            }
555
556            record.local_address = Some(address.clone());
557            record.peer_address = None;
558            record.local_unix_path = None;
559            record.peer_unix_path = None;
560            record.listener_state = None;
561            record.connection_state = None;
562            record.state = SocketState::Bound;
563            record.clone()
564        };
565        register_bound_inet_socket(&mut table, cloned.spec, address, socket_id);
566        Ok(cloned)
567    }
568
569    pub fn set_datagram_socket_option(
570        &self,
571        socket_id: SocketId,
572        option: DatagramSocketOption,
573        enabled: bool,
574    ) -> SocketResult<SocketRecord> {
575        let mut table = lock_or_recover(&self.inner.state);
576        let record = table
577            .sockets
578            .get_mut(&socket_id)
579            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
580        let datagram_state = datagram_state_mut(record)?;
581
582        match option {
583            DatagramSocketOption::ReuseAddr => datagram_state.reuse_addr = enabled,
584            DatagramSocketOption::ReusePort => datagram_state.reuse_port = enabled,
585            DatagramSocketOption::Broadcast => datagram_state.broadcast = enabled,
586        }
587
588        Ok(record.clone())
589    }
590
591    pub fn add_multicast_membership(
592        &self,
593        socket_id: SocketId,
594        membership: SocketMulticastMembership,
595    ) -> SocketResult<SocketRecord> {
596        let mut table = lock_or_recover(&self.inner.state);
597        let normalized_membership = {
598            let record = table
599                .sockets
600                .get(&socket_id)
601                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
602            validate_multicast_socket(record)?;
603            normalize_multicast_membership(record.spec, membership)?
604        };
605
606        let cloned = {
607            let record = table
608                .sockets
609                .get_mut(&socket_id)
610                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
611            let datagram_state = datagram_state_mut(record)?;
612            datagram_state
613                .multicast_memberships
614                .insert(normalized_membership.clone());
615            record.clone()
616        };
617
618        table
619            .multicast_groups
620            .entry(normalized_membership)
621            .or_default()
622            .insert(socket_id);
623        Ok(cloned)
624    }
625
626    pub fn drop_multicast_membership(
627        &self,
628        socket_id: SocketId,
629        membership: SocketMulticastMembership,
630    ) -> SocketResult<SocketRecord> {
631        let mut table = lock_or_recover(&self.inner.state);
632        let normalized_membership = {
633            let record = table
634                .sockets
635                .get(&socket_id)
636                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
637            validate_multicast_socket(record)?;
638            normalize_multicast_membership(record.spec, membership)?
639        };
640
641        let cloned = {
642            let record = table
643                .sockets
644                .get_mut(&socket_id)
645                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
646            let datagram_state = datagram_state_mut(record)?;
647            if !datagram_state
648                .multicast_memberships
649                .remove(&normalized_membership)
650            {
651                return Err(SocketTableError::address_not_available(format!(
652                    "socket {socket_id} has not joined multicast group {}",
653                    normalized_membership.group_address()
654                )));
655            }
656            record.clone()
657        };
658
659        if let Some(members) = table.multicast_groups.get_mut(&normalized_membership) {
660            members.remove(&socket_id);
661            if members.is_empty() {
662                table.multicast_groups.remove(&normalized_membership);
663            }
664        }
665
666        Ok(cloned)
667    }
668
669    pub fn bind_unix(
670        &self,
671        socket_id: SocketId,
672        path: impl Into<String>,
673    ) -> SocketResult<SocketRecord> {
674        let path = normalize_unix_socket_path(path.into())?;
675        let mut table = lock_or_recover(&self.inner.state);
676        let existing = table
677            .sockets
678            .get(&socket_id)
679            .cloned()
680            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
681        if !supports_unix_stream_lifecycle(existing.spec) {
682            return Err(SocketTableError::invalid_argument(format!(
683                "socket {socket_id} is not a Unix stream socket"
684            )));
685        }
686        let existing_id = table.bound_unix_streams.get(&path).copied();
687        let cloned = {
688            let record = table
689                .sockets
690                .get_mut(&socket_id)
691                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
692
693            if let Some(bound_socket_id) = existing_id {
694                if bound_socket_id != socket_id {
695                    return Err(SocketTableError::address_in_use(format!(
696                        "path {path} is already bound"
697                    )));
698                }
699            }
700
701            match record.state {
702                SocketState::Created => {}
703                SocketState::Bound if record.local_unix_path.as_deref() == Some(path.as_str()) => {
704                    return Ok(record.clone());
705                }
706                SocketState::Bound | SocketState::Listening | SocketState::Connected => {
707                    return Err(SocketTableError::invalid_argument(format!(
708                        "socket {socket_id} cannot bind in state {:?}",
709                        record.state
710                    )));
711                }
712            }
713
714            record.local_address = None;
715            record.peer_address = None;
716            record.local_unix_path = Some(path.clone());
717            record.peer_unix_path = None;
718            record.listener_state = None;
719            record.connection_state = None;
720            record.state = SocketState::Bound;
721            record.clone()
722        };
723        table.bound_unix_streams.insert(path, socket_id);
724        Ok(cloned)
725    }
726
727    pub fn listen(&self, socket_id: SocketId, backlog: usize) -> SocketResult<SocketRecord> {
728        if backlog == 0 {
729            return Err(SocketTableError::invalid_argument(
730                "listener backlog must be greater than zero",
731            ));
732        }
733
734        let mut table = lock_or_recover(&self.inner.state);
735        let record = table
736            .sockets
737            .get_mut(&socket_id)
738            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
739
740        if !supports_listener_lifecycle(record.spec) {
741            return Err(SocketTableError::invalid_argument(format!(
742                "socket {socket_id} is not a stream socket"
743            )));
744        }
745        if record.state != SocketState::Bound || !has_bound_endpoint(record) {
746            return Err(SocketTableError::invalid_argument(format!(
747                "socket {socket_id} must be bound before listen"
748            )));
749        }
750
751        record.state = SocketState::Listening;
752        record.listener_state = Some(ListenerState {
753            backlog,
754            pending_accepts: VecDeque::new(),
755        });
756        Ok(record.clone())
757    }
758
759    pub fn enqueue_incoming_tcp_connection(
760        &self,
761        listener_socket_id: SocketId,
762        peer_address: InetSocketAddress,
763    ) -> SocketResult<()> {
764        let mut table = lock_or_recover(&self.inner.state);
765        let record = table
766            .sockets
767            .get_mut(&listener_socket_id)
768            .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
769
770        if record.state != SocketState::Listening {
771            return Err(SocketTableError::invalid_argument(format!(
772                "socket {listener_socket_id} is not listening"
773            )));
774        }
775
776        let listener_state = record.listener_state.as_mut().ok_or_else(|| {
777            SocketTableError::invalid_argument(format!(
778                "socket {listener_socket_id} has no listener state"
779            ))
780        })?;
781
782        if listener_state.pending_accepts.len() >= listener_state.backlog {
783            return Err(SocketTableError::would_block(format!(
784                "listener {listener_socket_id} backlog is full"
785            )));
786        }
787
788        listener_state.pending_accepts.push_back(PendingConnection {
789            peer_address: Some(peer_address),
790            peer_unix_path: None,
791            accepted_socket_id: None,
792        });
793        Ok(())
794    }
795
796    pub fn accept(&self, listener_socket_id: SocketId) -> SocketResult<SocketRecord> {
797        let mut table = lock_or_recover(&self.inner.state);
798        let (owner_pid, spec, local_address, local_unix_path, pending) = {
799            let record = table
800                .sockets
801                .get_mut(&listener_socket_id)
802                .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
803
804            if record.state != SocketState::Listening {
805                return Err(SocketTableError::invalid_argument(format!(
806                    "socket {listener_socket_id} is not listening"
807                )));
808            }
809
810            let listener_state = record.listener_state.as_mut().ok_or_else(|| {
811                SocketTableError::invalid_argument(format!(
812                    "socket {listener_socket_id} has no listener state"
813                ))
814            })?;
815            let pending = listener_state.pending_accepts.pop_front().ok_or_else(|| {
816                SocketTableError::would_block(format!(
817                    "listener {listener_socket_id} has no pending connections"
818                ))
819            })?;
820
821            (
822                record.owner_pid,
823                record.spec,
824                record.local_address.clone(),
825                record.local_unix_path.clone(),
826                pending,
827            )
828        };
829
830        if let Some(accepted_socket_id) = pending.accepted_socket_id {
831            return table
832                .sockets
833                .get(&accepted_socket_id)
834                .cloned()
835                .ok_or_else(|| SocketTableError::not_found(accepted_socket_id));
836        }
837
838        let socket_id = next_socket_id(&mut table);
839        let record = SocketRecord {
840            id: socket_id,
841            owner_pid,
842            spec,
843            state: SocketState::Connected,
844            local_address,
845            peer_address: pending.peer_address,
846            local_unix_path,
847            peer_unix_path: pending.peer_unix_path,
848            listener_state: None,
849            connection_state: default_connection_state(spec, SocketState::Connected),
850            datagram_state: default_datagram_state(spec),
851        };
852        table.sockets.insert(socket_id, record.clone());
853        table
854            .by_owner
855            .entry(owner_pid)
856            .or_default()
857            .insert(socket_id);
858        Ok(record)
859    }
860
861    pub fn connect_pair(
862        &self,
863        socket_id: SocketId,
864        peer_socket_id: SocketId,
865    ) -> SocketResult<(SocketRecord, SocketRecord)> {
866        if socket_id == peer_socket_id {
867            return Err(SocketTableError::invalid_argument(
868                "socket cannot connect to itself",
869            ));
870        }
871
872        let mut table = lock_or_recover(&self.inner.state);
873        let mut socket = table
874            .sockets
875            .remove(&socket_id)
876            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
877        let Some(mut peer) = table.sockets.remove(&peer_socket_id) else {
878            table.sockets.insert(socket_id, socket);
879            return Err(SocketTableError::not_found(peer_socket_id));
880        };
881
882        if let Err(error) = validate_connect_pair(&socket, &peer) {
883            table.sockets.insert(socket_id, socket);
884            table.sockets.insert(peer_socket_id, peer);
885            return Err(error);
886        }
887
888        socket.state = SocketState::Connected;
889        socket.peer_address = peer.local_address.clone();
890        socket.peer_unix_path = peer.local_unix_path.clone();
891        socket.listener_state = None;
892        socket.connection_state = Some(ConnectionState {
893            peer_socket_id: Some(peer_socket_id),
894            ..ConnectionState::default()
895        });
896
897        peer.state = SocketState::Connected;
898        peer.peer_address = socket.local_address.clone();
899        peer.peer_unix_path = socket.local_unix_path.clone();
900        peer.listener_state = None;
901        peer.connection_state = Some(ConnectionState {
902            peer_socket_id: Some(socket_id),
903            ..ConnectionState::default()
904        });
905
906        let socket_clone = socket.clone();
907        let peer_clone = peer.clone();
908        table.sockets.insert(socket_id, socket);
909        table.sockets.insert(peer_socket_id, peer);
910        Ok((socket_clone, peer_clone))
911    }
912
913    pub fn find_bound_inet_socket(
914        &self,
915        spec: SocketSpec,
916        address: &InetSocketAddress,
917    ) -> Option<SocketRecord> {
918        let address = normalize_inet_address(address.clone());
919        let table = lock_or_recover(&self.inner.state);
920        let socket_id = lookup_bound_inet_socket(&table, spec, &address)?;
921        table.sockets.get(&socket_id).cloned()
922    }
923
924    pub fn connect_to_bound_inet_stream(
925        &self,
926        socket_id: SocketId,
927        target_address: InetSocketAddress,
928    ) -> SocketResult<()> {
929        let target_address = normalize_inet_address(target_address);
930        let mut table = lock_or_recover(&self.inner.state);
931        let listener_socket_id =
932            lookup_bound_inet_socket_in_table(&table.bound_inet_streams, &target_address)
933                .ok_or_else(|| {
934                    SocketTableError::not_found_address(format!(
935                        "no listening socket bound at {}:{}",
936                        target_address.host(),
937                        target_address.port()
938                    ))
939                })?;
940
941        if socket_id == listener_socket_id {
942            return Err(SocketTableError::invalid_argument(
943                "socket cannot connect to its own listening endpoint",
944            ));
945        }
946
947        let mut client = table
948            .sockets
949            .remove(&socket_id)
950            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
951        let result = (|| {
952            // Validate the listener and confirm backlog capacity BEFORE consuming a
953            // socket id. The id counter is monotonic (saturating_add) and never
954            // reclaims, so allocating an id before this check leaks one on every
955            // rejected connect (for example when the backlog is full).
956            {
957                let listener = table
958                    .sockets
959                    .get(&listener_socket_id)
960                    .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
961                validate_connect_to_listener(&client, listener)?;
962
963                let listener_state = listener.listener_state.as_ref().ok_or_else(|| {
964                    SocketTableError::invalid_argument(format!(
965                        "socket {listener_socket_id} has no listener state"
966                    ))
967                })?;
968                if listener_state.pending_accepts.len() >= listener_state.backlog {
969                    return Err(SocketTableError::would_block(format!(
970                        "listener {listener_socket_id} backlog is full"
971                    )));
972                }
973            }
974
975            // Capacity confirmed: only now is it safe to consume a socket id.
976            let accepted_socket_id = next_socket_id(&mut table);
977            let listener = table
978                .sockets
979                .get_mut(&listener_socket_id)
980                .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
981            let listener_state = listener.listener_state.as_mut().ok_or_else(|| {
982                SocketTableError::invalid_argument(format!(
983                    "socket {listener_socket_id} has no listener state"
984                ))
985            })?;
986
987            let accepted = SocketRecord {
988                id: accepted_socket_id,
989                owner_pid: listener.owner_pid,
990                spec: listener.spec,
991                state: SocketState::Connected,
992                local_address: listener.local_address.clone(),
993                peer_address: client.local_address.clone(),
994                local_unix_path: None,
995                peer_unix_path: None,
996                listener_state: None,
997                connection_state: Some(ConnectionState {
998                    peer_socket_id: Some(socket_id),
999                    ..ConnectionState::default()
1000                }),
1001                datagram_state: default_datagram_state(listener.spec),
1002            };
1003
1004            listener_state.pending_accepts.push_back(PendingConnection {
1005                peer_address: client.local_address.clone(),
1006                peer_unix_path: None,
1007                accepted_socket_id: Some(accepted_socket_id),
1008            });
1009
1010            client.state = SocketState::Connected;
1011            client.peer_address = listener.local_address.clone();
1012            client.peer_unix_path = None;
1013            client.listener_state = None;
1014            client.connection_state = Some(ConnectionState {
1015                peer_socket_id: Some(accepted_socket_id),
1016                ..ConnectionState::default()
1017            });
1018
1019            Ok(accepted)
1020        })();
1021
1022        match result {
1023            Ok(accepted) => {
1024                let accepted_socket_id = accepted.id;
1025                table.sockets.insert(socket_id, client);
1026                table.sockets.insert(accepted_socket_id, accepted.clone());
1027                table
1028                    .by_owner
1029                    .entry(accepted.owner_pid)
1030                    .or_default()
1031                    .insert(accepted_socket_id);
1032                Ok(())
1033            }
1034            Err(error) => {
1035                table.sockets.insert(socket_id, client);
1036                Err(error)
1037            }
1038        }
1039    }
1040
1041    pub fn find_bound_unix_socket(&self, path: &str) -> Option<SocketRecord> {
1042        let path = normalize_unix_socket_path(path).ok()?;
1043        let table = lock_or_recover(&self.inner.state);
1044        let socket_id = table.bound_unix_streams.get(&path).copied()?;
1045        table.sockets.get(&socket_id).cloned()
1046    }
1047
1048    pub fn connect_to_bound_unix_stream(
1049        &self,
1050        socket_id: SocketId,
1051        target_path: impl Into<String>,
1052    ) -> SocketResult<()> {
1053        let target_path = normalize_unix_socket_path(target_path.into())?;
1054        let mut table = lock_or_recover(&self.inner.state);
1055        let listener_socket_id = table
1056            .bound_unix_streams
1057            .get(&target_path)
1058            .copied()
1059            .ok_or_else(|| {
1060                SocketTableError::not_found_address(format!(
1061                    "no listening socket bound at path {target_path}"
1062                ))
1063            })?;
1064
1065        if socket_id == listener_socket_id {
1066            return Err(SocketTableError::invalid_argument(
1067                "socket cannot connect to its own listening endpoint",
1068            ));
1069        }
1070
1071        let mut client = table
1072            .sockets
1073            .remove(&socket_id)
1074            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1075        let result = (|| {
1076            // Validate the listener and confirm backlog capacity BEFORE consuming a
1077            // socket id. The id counter is monotonic (saturating_add) and never
1078            // reclaims, so allocating an id before this check leaks one on every
1079            // rejected connect (for example when the backlog is full).
1080            {
1081                let listener = table
1082                    .sockets
1083                    .get(&listener_socket_id)
1084                    .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
1085                validate_connect_to_listener(&client, listener)?;
1086
1087                let listener_state = listener.listener_state.as_ref().ok_or_else(|| {
1088                    SocketTableError::invalid_argument(format!(
1089                        "socket {listener_socket_id} has no listener state"
1090                    ))
1091                })?;
1092                if listener_state.pending_accepts.len() >= listener_state.backlog {
1093                    return Err(SocketTableError::would_block(format!(
1094                        "listener {listener_socket_id} backlog is full"
1095                    )));
1096                }
1097            }
1098
1099            // Capacity confirmed: only now is it safe to consume a socket id.
1100            let accepted_socket_id = next_socket_id(&mut table);
1101            let listener = table
1102                .sockets
1103                .get_mut(&listener_socket_id)
1104                .ok_or_else(|| SocketTableError::not_found(listener_socket_id))?;
1105            let listener_state = listener.listener_state.as_mut().ok_or_else(|| {
1106                SocketTableError::invalid_argument(format!(
1107                    "socket {listener_socket_id} has no listener state"
1108                ))
1109            })?;
1110
1111            let accepted = SocketRecord {
1112                id: accepted_socket_id,
1113                owner_pid: listener.owner_pid,
1114                spec: listener.spec,
1115                state: SocketState::Connected,
1116                local_address: None,
1117                peer_address: None,
1118                local_unix_path: listener.local_unix_path.clone(),
1119                peer_unix_path: client.local_unix_path.clone(),
1120                listener_state: None,
1121                connection_state: Some(ConnectionState {
1122                    peer_socket_id: Some(socket_id),
1123                    ..ConnectionState::default()
1124                }),
1125                datagram_state: default_datagram_state(listener.spec),
1126            };
1127
1128            listener_state.pending_accepts.push_back(PendingConnection {
1129                peer_address: None,
1130                peer_unix_path: client.local_unix_path.clone(),
1131                accepted_socket_id: Some(accepted_socket_id),
1132            });
1133
1134            client.state = SocketState::Connected;
1135            client.peer_address = None;
1136            client.peer_unix_path = listener.local_unix_path.clone();
1137            client.listener_state = None;
1138            client.connection_state = Some(ConnectionState {
1139                peer_socket_id: Some(accepted_socket_id),
1140                ..ConnectionState::default()
1141            });
1142
1143            Ok(accepted)
1144        })();
1145
1146        match result {
1147            Ok(accepted) => {
1148                let accepted_socket_id = accepted.id;
1149                table.sockets.insert(socket_id, client);
1150                table.sockets.insert(accepted_socket_id, accepted.clone());
1151                table
1152                    .by_owner
1153                    .entry(accepted.owner_pid)
1154                    .or_default()
1155                    .insert(accepted_socket_id);
1156                Ok(())
1157            }
1158            Err(error) => {
1159                table.sockets.insert(socket_id, client);
1160                Err(error)
1161            }
1162        }
1163    }
1164
1165    pub fn send_to_bound_udp_socket(
1166        &self,
1167        socket_id: SocketId,
1168        target_address: InetSocketAddress,
1169        data: &[u8],
1170    ) -> SocketResult<usize> {
1171        let target_address = normalize_inet_address(target_address);
1172        let mut table = lock_or_recover(&self.inner.state);
1173        let sender = table
1174            .sockets
1175            .get(&socket_id)
1176            .cloned()
1177            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1178        validate_bound_udp_sender(&sender)?;
1179
1180        let receiver_socket_id = lookup_bound_inet_datagram_socket_in_table(
1181            &table.bound_inet_datagrams,
1182            &target_address,
1183        )
1184        .ok_or_else(|| {
1185            SocketTableError::not_found_address(format!(
1186                "no UDP socket bound at {}:{}",
1187                target_address.host(),
1188                target_address.port()
1189            ))
1190        })?;
1191        let receiver = table
1192            .sockets
1193            .get_mut(&receiver_socket_id)
1194            .ok_or_else(|| SocketTableError::not_found(receiver_socket_id))?;
1195        validate_bound_udp_receiver(receiver)?;
1196
1197        let datagram_state = receiver.datagram_state.as_mut().ok_or_else(|| {
1198            SocketTableError::invalid_argument(format!(
1199                "socket {receiver_socket_id} does not support datagrams"
1200            ))
1201        })?;
1202        datagram_state.recv_queue.push_back(QueuedDatagram {
1203            source_address: sender.local_address.clone(),
1204            payload: data.to_vec(),
1205        });
1206        Ok(data.len())
1207    }
1208
1209    pub fn check_send_to_bound_udp_socket(
1210        &self,
1211        socket_id: SocketId,
1212        target_address: InetSocketAddress,
1213    ) -> SocketResult<()> {
1214        let target_address = normalize_inet_address(target_address);
1215        let table = lock_or_recover(&self.inner.state);
1216        let sender = table
1217            .sockets
1218            .get(&socket_id)
1219            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1220        validate_bound_udp_sender(sender)?;
1221
1222        let receiver_socket_id = lookup_bound_inet_datagram_socket_in_table(
1223            &table.bound_inet_datagrams,
1224            &target_address,
1225        )
1226        .ok_or_else(|| {
1227            SocketTableError::not_found_address(format!(
1228                "no UDP socket bound at {}:{}",
1229                target_address.host(),
1230                target_address.port()
1231            ))
1232        })?;
1233        let receiver = table
1234            .sockets
1235            .get(&receiver_socket_id)
1236            .ok_or_else(|| SocketTableError::not_found(receiver_socket_id))?;
1237        validate_bound_udp_receiver(receiver)?;
1238        Ok(())
1239    }
1240
1241    pub fn recv_datagram(
1242        &self,
1243        socket_id: SocketId,
1244        max_bytes: usize,
1245    ) -> SocketResult<Option<ReceivedDatagram>> {
1246        let mut table = lock_or_recover(&self.inner.state);
1247        let record = table
1248            .sockets
1249            .get_mut(&socket_id)
1250            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1251        validate_bound_udp_receiver(record)?;
1252
1253        let datagram_state = record.datagram_state.as_mut().ok_or_else(|| {
1254            SocketTableError::invalid_argument(format!(
1255                "socket {socket_id} does not support datagrams"
1256            ))
1257        })?;
1258        let Some(datagram) = datagram_state.recv_queue.pop_front() else {
1259            return Err(SocketTableError::would_block(format!(
1260                "socket {socket_id} has no queued datagrams"
1261            )));
1262        };
1263
1264        let payload = if datagram.payload.len() > max_bytes {
1265            datagram.payload[..max_bytes].to_vec()
1266        } else {
1267            datagram.payload
1268        };
1269        Ok(Some(ReceivedDatagram {
1270            source_address: datagram.source_address,
1271            payload,
1272        }))
1273    }
1274
1275    pub fn poll(&self, socket_id: SocketId, requested: PollEvents) -> SocketResult<PollEvents> {
1276        let table = lock_or_recover(&self.inner.state);
1277        let record = table
1278            .sockets
1279            .get(&socket_id)
1280            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1281
1282        let mut events = PollEvents::empty();
1283        match record.state {
1284            SocketState::Listening => {
1285                if requested.intersects(POLLIN) && record.pending_accept_count() > 0 {
1286                    events |= POLLIN;
1287                }
1288            }
1289            SocketState::Connected => {
1290                let connection = record.connection_state.as_ref().ok_or_else(|| {
1291                    SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1292                })?;
1293                let peer = connection
1294                    .peer_socket_id
1295                    .and_then(|peer_socket_id| table.sockets.get(&peer_socket_id));
1296
1297                if requested.intersects(POLLIN) && !connection.recv_buffer.is_empty() {
1298                    events |= POLLIN;
1299                }
1300                if connection.peer_write_shutdown || peer.is_none() {
1301                    events |= POLLHUP;
1302                }
1303
1304                if requested.intersects(POLLOUT) && !connection.write_shutdown {
1305                    if peer
1306                        .and_then(|peer| peer.connection_state.as_ref())
1307                        .map(|peer_connection| peer_connection.read_shutdown)
1308                        .unwrap_or(true)
1309                    {
1310                        events |= POLLERR;
1311                    } else {
1312                        events |= POLLOUT;
1313                    }
1314                }
1315            }
1316            SocketState::Bound if supports_inet_datagram_lifecycle(record.spec) => {
1317                let datagram_state = record.datagram_state.as_ref().ok_or_else(|| {
1318                    SocketTableError::invalid_argument(format!(
1319                        "socket {socket_id} does not support datagrams"
1320                    ))
1321                })?;
1322                if requested.intersects(POLLIN) && !datagram_state.recv_queue.is_empty() {
1323                    events |= POLLIN;
1324                }
1325                if requested.intersects(POLLOUT) {
1326                    events |= POLLOUT;
1327                }
1328            }
1329            SocketState::Created | SocketState::Bound => {}
1330        }
1331
1332        Ok(events)
1333    }
1334
1335    pub fn write(&self, socket_id: SocketId, data: &[u8]) -> SocketResult<usize> {
1336        let mut table = lock_or_recover(&self.inner.state);
1337        let record = table
1338            .sockets
1339            .get(&socket_id)
1340            .cloned()
1341            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1342        let connection = record.connection_state.as_ref().ok_or_else(|| {
1343            SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1344        })?;
1345        if record.state != SocketState::Connected {
1346            return Err(SocketTableError::not_connected(format!(
1347                "socket {socket_id} is not connected"
1348            )));
1349        }
1350        if connection.write_shutdown {
1351            return Err(SocketTableError::broken_pipe(format!(
1352                "socket {socket_id} write side is shut down"
1353            )));
1354        }
1355
1356        let peer_socket_id = connection.peer_socket_id.ok_or_else(|| {
1357            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1358        })?;
1359        let peer = table.sockets.get_mut(&peer_socket_id).ok_or_else(|| {
1360            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1361        })?;
1362        let peer_connection = peer.connection_state.as_mut().ok_or_else(|| {
1363            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1364        })?;
1365        if peer_connection.read_shutdown {
1366            return Err(SocketTableError::broken_pipe(format!(
1367                "socket {peer_socket_id} read side is shut down"
1368            )));
1369        }
1370
1371        peer_connection.recv_buffer.extend(data.iter().copied());
1372        Ok(data.len())
1373    }
1374
1375    pub fn check_write(&self, socket_id: SocketId) -> SocketResult<()> {
1376        let table = lock_or_recover(&self.inner.state);
1377        let record = table
1378            .sockets
1379            .get(&socket_id)
1380            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1381        let connection = record.connection_state.as_ref().ok_or_else(|| {
1382            SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1383        })?;
1384        if record.state != SocketState::Connected {
1385            return Err(SocketTableError::not_connected(format!(
1386                "socket {socket_id} is not connected"
1387            )));
1388        }
1389        if connection.write_shutdown {
1390            return Err(SocketTableError::broken_pipe(format!(
1391                "socket {socket_id} write side is shut down"
1392            )));
1393        }
1394
1395        let peer_socket_id = connection.peer_socket_id.ok_or_else(|| {
1396            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1397        })?;
1398        let peer = table.sockets.get(&peer_socket_id).ok_or_else(|| {
1399            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1400        })?;
1401        let peer_connection = peer.connection_state.as_ref().ok_or_else(|| {
1402            SocketTableError::broken_pipe(format!("socket {socket_id} peer is closed"))
1403        })?;
1404        if peer_connection.read_shutdown {
1405            return Err(SocketTableError::broken_pipe(format!(
1406                "socket {peer_socket_id} read side is shut down"
1407            )));
1408        }
1409
1410        Ok(())
1411    }
1412
1413    pub fn read(&self, socket_id: SocketId, max_bytes: usize) -> SocketResult<Option<Vec<u8>>> {
1414        if max_bytes == 0 {
1415            return Ok(Some(Vec::new()));
1416        }
1417
1418        let mut table = lock_or_recover(&self.inner.state);
1419        let record = table
1420            .sockets
1421            .get(&socket_id)
1422            .cloned()
1423            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1424        if record.state != SocketState::Connected {
1425            return Err(SocketTableError::not_connected(format!(
1426                "socket {socket_id} is not connected"
1427            )));
1428        }
1429
1430        let connection = record.connection_state.as_ref().ok_or_else(|| {
1431            SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1432        })?;
1433        if connection.read_shutdown {
1434            return Ok(None);
1435        }
1436        if !connection.recv_buffer.is_empty() {
1437            let record = table
1438                .sockets
1439                .get_mut(&socket_id)
1440                .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1441            let connection = record.connection_state.as_mut().ok_or_else(|| {
1442                SocketTableError::not_connected(format!("socket {socket_id} is not connected"))
1443            })?;
1444            let read_len = connection.recv_buffer.len().min(max_bytes);
1445            let bytes = connection.recv_buffer.drain(..read_len).collect::<Vec<_>>();
1446            return Ok(Some(bytes));
1447        }
1448
1449        let peer_open = connection
1450            .peer_socket_id
1451            .map(|peer_socket_id| table.sockets.contains_key(&peer_socket_id))
1452            .unwrap_or(false);
1453        if connection.peer_write_shutdown || !peer_open {
1454            return Ok(None);
1455        }
1456
1457        Err(SocketTableError::would_block(format!(
1458            "socket {socket_id} has no readable data"
1459        )))
1460    }
1461
1462    pub fn shutdown(&self, socket_id: SocketId, how: SocketShutdown) -> SocketResult<SocketRecord> {
1463        let mut table = lock_or_recover(&self.inner.state);
1464        let record = table
1465            .sockets
1466            .remove(&socket_id)
1467            .ok_or_else(|| SocketTableError::not_found(socket_id))?;
1468
1469        if record.state != SocketState::Connected {
1470            table.sockets.insert(socket_id, record);
1471            return Err(SocketTableError::not_connected(format!(
1472                "socket {socket_id} is not connected"
1473            )));
1474        }
1475
1476        let Some(mut connection) = record.connection_state.clone() else {
1477            table.sockets.insert(socket_id, record);
1478            return Err(SocketTableError::not_connected(format!(
1479                "socket {socket_id} is not connected"
1480            )));
1481        };
1482
1483        if matches!(how, SocketShutdown::Read | SocketShutdown::Both) {
1484            connection.recv_buffer.clear();
1485            connection.read_shutdown = true;
1486        }
1487        if matches!(how, SocketShutdown::Write | SocketShutdown::Both) {
1488            connection.write_shutdown = true;
1489            if let Some(peer_socket_id) = connection.peer_socket_id {
1490                if let Some(peer) = table.sockets.get_mut(&peer_socket_id) {
1491                    if let Some(peer_connection) = peer.connection_state.as_mut() {
1492                        peer_connection.peer_write_shutdown = true;
1493                    }
1494                }
1495            }
1496        }
1497
1498        let mut record = record;
1499        record.connection_state = Some(connection);
1500        let cloned = record.clone();
1501        table.sockets.insert(socket_id, record);
1502        Ok(cloned)
1503    }
1504
1505    pub fn remove(&self, socket_id: SocketId) -> SocketResult<SocketRecord> {
1506        let mut table = lock_or_recover(&self.inner.state);
1507        remove_socket(&mut table, socket_id).ok_or_else(|| SocketTableError::not_found(socket_id))
1508    }
1509
1510    pub fn remove_all_for_pid(&self, owner_pid: u32) -> Vec<SocketRecord> {
1511        let mut table = lock_or_recover(&self.inner.state);
1512        let Some(socket_ids) = table.by_owner.remove(&owner_pid) else {
1513            return Vec::new();
1514        };
1515
1516        socket_ids
1517            .into_iter()
1518            .filter_map(|socket_id| remove_socket(&mut table, socket_id))
1519            .collect()
1520    }
1521
1522    pub fn snapshot(&self) -> SocketTableSnapshot {
1523        let table = lock_or_recover(&self.inner.state);
1524        let mut snapshot = SocketTableSnapshot {
1525            sockets: table.sockets.len(),
1526            ..SocketTableSnapshot::default()
1527        };
1528        for record in table.sockets.values() {
1529            if record.state.counts_as_listener() {
1530                snapshot.listeners += 1;
1531            }
1532            if record.state.counts_as_connection() {
1533                snapshot.connections += 1;
1534            }
1535            if let Some(connection) = &record.connection_state {
1536                snapshot.buffered_bytes = snapshot
1537                    .buffered_bytes
1538                    .saturating_add(connection.recv_buffer.len());
1539            }
1540            if let Some(datagram_state) = &record.datagram_state {
1541                snapshot.datagram_queue_len = snapshot
1542                    .datagram_queue_len
1543                    .saturating_add(datagram_state.recv_queue.len());
1544                snapshot.buffered_bytes = snapshot
1545                    .buffered_bytes
1546                    .saturating_add(datagram_queue_bytes(&datagram_state.recv_queue));
1547            }
1548        }
1549        snapshot
1550    }
1551}
1552
1553fn datagram_queue_bytes(queue: &VecDeque<QueuedDatagram>) -> usize {
1554    queue
1555        .iter()
1556        .map(|datagram| datagram.payload.len())
1557        .sum::<usize>()
1558}
1559
1560fn next_socket_id(table: &mut SocketTableState) -> SocketId {
1561    if table.next_socket_id == 0 {
1562        table.next_socket_id = 1;
1563    }
1564    let socket_id = table.next_socket_id;
1565    table.next_socket_id = table.next_socket_id.saturating_add(1);
1566    socket_id
1567}
1568
1569fn validate_state_transition(current: SocketState, next: SocketState) -> SocketResult<()> {
1570    if current == SocketState::Connected && next != SocketState::Connected {
1571        return Err(SocketTableError::invalid_argument(format!(
1572            "invalid socket state transition from {current:?} to {next:?}"
1573        )));
1574    }
1575    Ok(())
1576}
1577
1578fn validate_connect_pair(socket: &SocketRecord, peer: &SocketRecord) -> SocketResult<()> {
1579    if !supports_connection_lifecycle(socket.spec) {
1580        return Err(SocketTableError::invalid_argument(format!(
1581            "socket {} does not support stream connections",
1582            socket.id
1583        )));
1584    }
1585    if !supports_connection_lifecycle(peer.spec) {
1586        return Err(SocketTableError::invalid_argument(format!(
1587            "socket {} does not support stream connections",
1588            peer.id
1589        )));
1590    }
1591    if !matches!(socket.state, SocketState::Created | SocketState::Bound) {
1592        return Err(SocketTableError::invalid_argument(format!(
1593            "socket {} cannot connect in state {:?}",
1594            socket.id, socket.state
1595        )));
1596    }
1597    if !matches!(peer.state, SocketState::Created | SocketState::Bound) {
1598        return Err(SocketTableError::invalid_argument(format!(
1599            "socket {} cannot connect in state {:?}",
1600            peer.id, peer.state
1601        )));
1602    }
1603    Ok(())
1604}
1605
1606fn default_connection_state(spec: SocketSpec, state: SocketState) -> Option<ConnectionState> {
1607    if state == SocketState::Connected && supports_connection_lifecycle(spec) {
1608        Some(ConnectionState::default())
1609    } else {
1610        None
1611    }
1612}
1613
1614fn default_datagram_state(spec: SocketSpec) -> Option<DatagramState> {
1615    if supports_inet_datagram_lifecycle(spec) {
1616        Some(DatagramState::default())
1617    } else {
1618        None
1619    }
1620}
1621
1622fn supports_connection_lifecycle(spec: SocketSpec) -> bool {
1623    matches!(spec.socket_type, SocketType::Stream)
1624}
1625
1626fn supports_listener_lifecycle(spec: SocketSpec) -> bool {
1627    matches!(spec.socket_type, SocketType::Stream)
1628        && matches!(
1629            spec.domain,
1630            SocketDomain::Inet | SocketDomain::Inet6 | SocketDomain::Unix
1631        )
1632}
1633
1634fn supports_inet_bind(spec: SocketSpec) -> bool {
1635    matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1636        && matches!(spec.socket_type, SocketType::Stream | SocketType::Datagram)
1637}
1638
1639fn supports_unix_stream_lifecycle(spec: SocketSpec) -> bool {
1640    matches!(spec.socket_type, SocketType::Stream) && matches!(spec.domain, SocketDomain::Unix)
1641}
1642
1643fn supports_inet_stream_lifecycle(spec: SocketSpec) -> bool {
1644    matches!(spec.socket_type, SocketType::Stream)
1645        && matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1646}
1647
1648fn supports_inet_datagram_lifecycle(spec: SocketSpec) -> bool {
1649    matches!(spec.socket_type, SocketType::Datagram)
1650        && matches!(spec.domain, SocketDomain::Inet | SocketDomain::Inet6)
1651}
1652
1653fn lookup_conflicting_bound_inet_socket_ids(
1654    table: &SocketTableState,
1655    spec: SocketSpec,
1656    address: &InetSocketAddress,
1657) -> Vec<SocketId> {
1658    if supports_inet_stream_lifecycle(spec) {
1659        table
1660            .bound_inet_streams
1661            .iter()
1662            .find_map(|(bound_address, socket_id)| {
1663                inet_stream_bind_addresses_overlap(bound_address, address).then_some(*socket_id)
1664            })
1665            .into_iter()
1666            .collect()
1667    } else if supports_inet_datagram_lifecycle(spec) {
1668        table
1669            .bound_inet_datagrams
1670            .iter()
1671            .filter(|(bound_address, _)| inet_stream_bind_addresses_overlap(bound_address, address))
1672            .flat_map(|(_, socket_ids)| socket_ids.iter().copied())
1673            .collect()
1674    } else {
1675        Vec::new()
1676    }
1677}
1678
1679fn lookup_bound_inet_socket(
1680    table: &SocketTableState,
1681    spec: SocketSpec,
1682    address: &InetSocketAddress,
1683) -> Option<SocketId> {
1684    if supports_inet_stream_lifecycle(spec) {
1685        lookup_bound_inet_socket_in_table(&table.bound_inet_streams, address)
1686    } else if supports_inet_datagram_lifecycle(spec) {
1687        lookup_bound_inet_datagram_socket_in_table(&table.bound_inet_datagrams, address)
1688    } else {
1689        None
1690    }
1691}
1692
1693fn inet_stream_bind_addresses_overlap(
1694    existing: &InetSocketAddress,
1695    requested: &InetSocketAddress,
1696) -> bool {
1697    if existing == requested {
1698        return true;
1699    }
1700
1701    wildcard_inet_address(existing).as_ref() == Some(requested)
1702        || wildcard_inet_address(requested).as_ref() == Some(existing)
1703}
1704
1705fn lookup_bound_inet_socket_in_table(
1706    sockets: &BTreeMap<InetSocketAddress, SocketId>,
1707    address: &InetSocketAddress,
1708) -> Option<SocketId> {
1709    sockets.get(address).copied().or_else(|| {
1710        wildcard_inet_address(address).and_then(|wildcard| sockets.get(&wildcard).copied())
1711    })
1712}
1713
1714fn lookup_bound_inet_datagram_socket_in_table(
1715    sockets: &BTreeMap<InetSocketAddress, BTreeSet<SocketId>>,
1716    address: &InetSocketAddress,
1717) -> Option<SocketId> {
1718    sockets
1719        .get(address)
1720        .and_then(|socket_ids| socket_ids.first().copied())
1721        .or_else(|| {
1722            wildcard_inet_address(address).and_then(|wildcard| {
1723                sockets
1724                    .get(&wildcard)
1725                    .and_then(|socket_ids| socket_ids.first().copied())
1726            })
1727        })
1728}
1729
1730fn register_bound_inet_socket(
1731    table: &mut SocketTableState,
1732    spec: SocketSpec,
1733    address: InetSocketAddress,
1734    socket_id: SocketId,
1735) {
1736    if supports_inet_stream_lifecycle(spec) {
1737        table.bound_inet_streams.insert(address, socket_id);
1738    } else if supports_inet_datagram_lifecycle(spec) {
1739        table
1740            .bound_inet_datagrams
1741            .entry(address)
1742            .or_default()
1743            .insert(socket_id);
1744    }
1745}
1746
1747fn validate_connect_to_listener(
1748    client: &SocketRecord,
1749    listener: &SocketRecord,
1750) -> SocketResult<()> {
1751    if !supports_connection_lifecycle(client.spec) {
1752        return Err(SocketTableError::invalid_argument(format!(
1753            "socket {} does not support stream connections",
1754            client.id
1755        )));
1756    }
1757    if !supports_listener_lifecycle(listener.spec) {
1758        return Err(SocketTableError::invalid_argument(format!(
1759            "socket {} is not a stream listener",
1760            listener.id
1761        )));
1762    }
1763    if !matches!(client.state, SocketState::Created | SocketState::Bound) {
1764        return Err(SocketTableError::invalid_argument(format!(
1765            "socket {} cannot connect in state {:?}",
1766            client.id, client.state
1767        )));
1768    }
1769    if listener.state != SocketState::Listening {
1770        return Err(SocketTableError::invalid_argument(format!(
1771            "socket {} is not listening",
1772            listener.id
1773        )));
1774    }
1775    Ok(())
1776}
1777
1778fn has_bound_endpoint(record: &SocketRecord) -> bool {
1779    record.local_address.is_some() || record.local_unix_path.is_some()
1780}
1781
1782fn validate_bound_udp_sender(sender: &SocketRecord) -> SocketResult<()> {
1783    if !supports_inet_datagram_lifecycle(sender.spec) {
1784        return Err(SocketTableError::invalid_argument(format!(
1785            "socket {} is not an INET datagram socket",
1786            sender.id
1787        )));
1788    }
1789    if sender.state != SocketState::Bound || sender.local_address.is_none() {
1790        return Err(SocketTableError::invalid_argument(format!(
1791            "socket {} must be bound before sending datagrams",
1792            sender.id
1793        )));
1794    }
1795    Ok(())
1796}
1797
1798fn validate_bound_udp_receiver(receiver: &SocketRecord) -> SocketResult<()> {
1799    if !supports_inet_datagram_lifecycle(receiver.spec) {
1800        return Err(SocketTableError::invalid_argument(format!(
1801            "socket {} is not an INET datagram socket",
1802            receiver.id
1803        )));
1804    }
1805    if receiver.state != SocketState::Bound || receiver.local_address.is_none() {
1806        return Err(SocketTableError::invalid_argument(format!(
1807            "socket {} must be bound to receive datagrams",
1808            receiver.id
1809        )));
1810    }
1811    Ok(())
1812}
1813
1814fn datagram_state_mut(record: &mut SocketRecord) -> SocketResult<&mut DatagramState> {
1815    if !supports_inet_datagram_lifecycle(record.spec) {
1816        return Err(SocketTableError::invalid_argument(format!(
1817            "socket {} is not an INET datagram socket",
1818            record.id
1819        )));
1820    }
1821    record.datagram_state.as_mut().ok_or_else(|| {
1822        SocketTableError::invalid_argument(format!(
1823            "socket {} does not support datagrams",
1824            record.id
1825        ))
1826    })
1827}
1828
1829fn validate_multicast_socket(record: &SocketRecord) -> SocketResult<()> {
1830    validate_bound_udp_receiver(record)?;
1831    if record.spec.domain != SocketDomain::Inet {
1832        return Err(SocketTableError::invalid_argument(format!(
1833            "socket {} multicast membership is only implemented for IPv4 datagrams",
1834            record.id
1835        )));
1836    }
1837    Ok(())
1838}
1839
1840fn normalize_multicast_membership(
1841    spec: SocketSpec,
1842    membership: SocketMulticastMembership,
1843) -> SocketResult<SocketMulticastMembership> {
1844    let group_address = membership.group_address.trim().to_ascii_lowercase();
1845    let interface_address = membership
1846        .interface_address
1847        .map(|value| value.trim().to_ascii_lowercase())
1848        .filter(|value| !value.is_empty());
1849
1850    match spec.domain {
1851        SocketDomain::Inet => {
1852            let parsed = group_address.parse::<Ipv4Addr>().map_err(|_| {
1853                SocketTableError::invalid_argument(format!(
1854                    "invalid IPv4 multicast address {group_address}"
1855                ))
1856            })?;
1857            if !parsed.is_multicast() {
1858                return Err(SocketTableError::invalid_argument(format!(
1859                    "address {group_address} is not an IPv4 multicast group"
1860                )));
1861            }
1862        }
1863        SocketDomain::Inet6 => {
1864            let parsed = group_address.parse::<Ipv6Addr>().map_err(|_| {
1865                SocketTableError::invalid_argument(format!(
1866                    "invalid IPv6 multicast address {group_address}"
1867                ))
1868            })?;
1869            if !parsed.is_multicast() {
1870                return Err(SocketTableError::invalid_argument(format!(
1871                    "address {group_address} is not an IPv6 multicast group"
1872                )));
1873            }
1874        }
1875        SocketDomain::Unix => {
1876            return Err(SocketTableError::invalid_argument(
1877                "unix sockets do not support multicast membership",
1878            ));
1879        }
1880    }
1881
1882    Ok(SocketMulticastMembership::new(
1883        group_address,
1884        interface_address,
1885    ))
1886}
1887
1888fn has_incompatible_inet_bind_conflict(
1889    table: &SocketTableState,
1890    record: &SocketRecord,
1891    conflicting_ids: &[SocketId],
1892) -> bool {
1893    conflicting_ids.iter().any(|conflicting_id| {
1894        if *conflicting_id == record.id {
1895            return false;
1896        }
1897
1898        let Some(existing) = table.sockets.get(conflicting_id) else {
1899            return false;
1900        };
1901
1902        if supports_inet_datagram_lifecycle(record.spec) {
1903            !inet_datagram_bind_shares_port(record, existing)
1904        } else {
1905            true
1906        }
1907    })
1908}
1909
1910fn inet_datagram_bind_shares_port(requested: &SocketRecord, existing: &SocketRecord) -> bool {
1911    (requested.reuse_port() && existing.reuse_port())
1912        || (requested.reuse_address() && existing.reuse_address())
1913}
1914
1915fn remove_socket(table: &mut SocketTableState, socket_id: SocketId) -> Option<SocketRecord> {
1916    let record = table.sockets.remove(&socket_id)?;
1917    unregister_bound_socket(table, &record);
1918    unregister_multicast_memberships(table, &record);
1919    if let Some(listener_state) = record.listener_state.as_ref() {
1920        let pending_socket_ids = listener_state
1921            .pending_accepts
1922            .iter()
1923            .filter_map(|pending| pending.accepted_socket_id)
1924            .collect::<Vec<_>>();
1925        for pending_socket_id in pending_socket_ids {
1926            let _ = remove_socket(table, pending_socket_id);
1927        }
1928    }
1929    if let Some(connection) = record.connection_state.as_ref() {
1930        if let Some(peer_socket_id) = connection.peer_socket_id {
1931            if let Some(peer) = table.sockets.get_mut(&peer_socket_id) {
1932                if let Some(peer_connection) = peer.connection_state.as_mut() {
1933                    if peer_connection.peer_socket_id == Some(socket_id) {
1934                        peer_connection.peer_socket_id = None;
1935                    }
1936                    peer_connection.peer_write_shutdown = true;
1937                }
1938            }
1939        }
1940    }
1941    if let Some(owner_sockets) = table.by_owner.get_mut(&record.owner_pid) {
1942        owner_sockets.remove(&socket_id);
1943        if owner_sockets.is_empty() {
1944            table.by_owner.remove(&record.owner_pid);
1945        }
1946    }
1947    Some(record)
1948}
1949
1950fn unregister_bound_socket(table: &mut SocketTableState, record: &SocketRecord) {
1951    let Some(address) = record.local_address.as_ref() else {
1952        if supports_unix_stream_lifecycle(record.spec) {
1953            if let Some(path) = record.local_unix_path.as_ref() {
1954                if table.bound_unix_streams.get(path).copied() == Some(record.id) {
1955                    table.bound_unix_streams.remove(path);
1956                }
1957            }
1958        }
1959        return;
1960    };
1961    if supports_inet_stream_lifecycle(record.spec)
1962        && table.bound_inet_streams.get(address).copied() == Some(record.id)
1963    {
1964        table.bound_inet_streams.remove(address);
1965    }
1966    if supports_inet_datagram_lifecycle(record.spec) {
1967        if let Some(socket_ids) = table.bound_inet_datagrams.get_mut(address) {
1968            socket_ids.remove(&record.id);
1969            if socket_ids.is_empty() {
1970                table.bound_inet_datagrams.remove(address);
1971            }
1972        }
1973    }
1974}
1975
1976fn unregister_multicast_memberships(table: &mut SocketTableState, record: &SocketRecord) {
1977    let Some(datagram_state) = record.datagram_state.as_ref() else {
1978        return;
1979    };
1980
1981    for membership in &datagram_state.multicast_memberships {
1982        if let Some(socket_ids) = table.multicast_groups.get_mut(membership) {
1983            socket_ids.remove(&record.id);
1984            if socket_ids.is_empty() {
1985                table.multicast_groups.remove(membership);
1986            }
1987        }
1988    }
1989}
1990
1991fn normalize_inet_address(address: InetSocketAddress) -> InetSocketAddress {
1992    match address.host().to_ascii_lowercase().as_str() {
1993        "localhost" => InetSocketAddress::new("127.0.0.1", address.port()),
1994        _ => address,
1995    }
1996}
1997
1998fn wildcard_inet_address(address: &InetSocketAddress) -> Option<InetSocketAddress> {
1999    match address.host() {
2000        "127.0.0.1" => Some(InetSocketAddress::new("0.0.0.0", address.port())),
2001        "::1" => Some(InetSocketAddress::new("::", address.port())),
2002        _ => None,
2003    }
2004}
2005
2006fn normalize_unix_socket_path(path: impl AsRef<str>) -> SocketResult<String> {
2007    let normalized = normalize_path(path.as_ref());
2008    if normalized == "/" {
2009        return Err(SocketTableError::invalid_argument(
2010            "unix socket path must not be empty or root",
2011        ));
2012    }
2013    Ok(normalized)
2014}
2015
/// Locks `mutex`, returning the guard even when the mutex is poisoned.
///
/// A poisoned lock is recovered by taking the guard out of the poison error, so this
/// function never fails because of poisoning.
fn lock_or_recover<'a, T>(mutex: &'a Mutex<T>) -> MutexGuard<'a, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
2022
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the current value of the table's `next_socket_id` counter. Only the field is
    /// read, so a test can compare snapshots taken before and after an operation.
    fn peek_next_socket_id(table: &SocketTable) -> SocketId {
        let state = lock_or_recover(&table.inner.state);
        state.next_socket_id
    }

    #[test]
    fn full_backlog_unix_connect_does_not_consume_socket_id() {
        let path = "/tmp/leak-test/server.sock";
        let table = SocketTable::new();

        // Listener with room for exactly one pending connection.
        let listener_id = table.allocate(1, SocketSpec::unix_stream()).id;
        table
            .bind_unix(listener_id, path)
            .expect("bind unix listener");
        table.listen(listener_id, 1).expect("listen with backlog 1");

        // The first client occupies the single backlog slot.
        let queued_client_id = table.allocate(2, SocketSpec::unix_stream()).id;
        table
            .connect_to_bound_unix_stream(queued_client_id, path)
            .expect("first connect fills the backlog");

        // The next client must be turned away, and the rejected attempt must leave the id
        // counter where it was.
        let rejected_client_id = table.allocate(2, SocketSpec::unix_stream()).id;
        let before = peek_next_socket_id(&table);
        let error = table
            .connect_to_bound_unix_stream(rejected_client_id, path)
            .expect_err("full-backlog connect must fail");
        assert_eq!(error.code(), "EAGAIN");
        let after = peek_next_socket_id(&table);

        assert_eq!(
            before, after,
            "full-backlog unix connect leaked a socket id (counter advanced from {before} to {after})"
        );
    }

    #[test]
    fn full_backlog_inet_connect_does_not_consume_socket_id() {
        let target = InetSocketAddress::new("127.0.0.1", 49222);
        let table = SocketTable::new();

        // Listener with room for exactly one pending connection.
        let listener_id = table.allocate(1, SocketSpec::tcp()).id;
        table
            .bind_inet(listener_id, target.clone())
            .expect("bind inet listener");
        table.listen(listener_id, 1).expect("listen with backlog 1");

        // The first client occupies the single backlog slot.
        let queued_client_id = table.allocate(2, SocketSpec::tcp()).id;
        table
            .connect_to_bound_inet_stream(queued_client_id, target.clone())
            .expect("first connect fills the backlog");

        // The next client must be turned away, and the rejected attempt must leave the id
        // counter where it was.
        let rejected_client_id = table.allocate(2, SocketSpec::tcp()).id;
        let before = peek_next_socket_id(&table);
        let error = table
            .connect_to_bound_inet_stream(rejected_client_id, target)
            .expect_err("full-backlog connect must fail");
        assert_eq!(error.code(), "EAGAIN");
        let after = peek_next_socket_id(&table);

        assert_eq!(
            before, after,
            "full-backlog inet connect leaked a socket id (counter advanced from {before} to {after})"
        );
    }
}