Skip to main content

tokio_multicast/
socket.rs

1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, SocketAddr};
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::time::SystemTime;
6
7use bytes::Bytes;
8use tokio::net::UdpSocket;
9
10use crate::{
11    raw, Datagram, Interface, Membership, MulticastConfig, MulticastError,
12    MulticastSocketBuilder, RecvMeta, Result,
13};
14
15#[derive(Debug)]
16pub struct MulticastSocket {
17    socket: UdpSocket,
18    config: MulticastConfig,
19    memberships: Arc<Mutex<HashSet<Membership>>>,
20}
21
22impl MulticastSocket {
23    pub fn builder() -> MulticastSocketBuilder {
24        MulticastSocketBuilder::new()
25    }
26
27    pub(crate) async fn from_config(config: MulticastConfig) -> Result<Self> {
28        let std_socket = raw::build_std_socket(&config)?;
29        let socket = UdpSocket::from_std(std_socket)?;
30        let memberships = config.memberships.iter().cloned().collect();
31
32        Ok(Self {
33            socket,
34            config,
35            memberships: Arc::new(Mutex::new(memberships)),
36        })
37    }
38
39    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
40        self.socket.local_addr()
41    }
42
43    pub fn config(&self) -> &MulticastConfig {
44        &self.config
45    }
46
47    pub fn memberships(&self) -> HashSet<Membership> {
48        self.memberships.lock().unwrap().clone()
49    }
50
51    pub async fn join(&self, membership: Membership) -> Result<()> {
52        let mut state = self.memberships.lock().unwrap();
53        if state.contains(&membership) {
54            return Ok(());
55        }
56
57        self.apply_join(&membership)?;
58        state.insert(membership);
59        Ok(())
60    }
61
62    pub async fn leave(&self, membership: &Membership) -> Result<()> {
63        let mut state = self.memberships.lock().unwrap();
64        if !state.contains(membership) {
65            return Ok(());
66        }
67
68        self.apply_leave(membership)?;
69        state.remove(membership);
70        Ok(())
71    }
72
73    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
74        self.socket.recv_from(buf).await
75    }
76
77    pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
78        self.socket.recv(buf).await
79    }
80
81    pub async fn recv_datagram(&self, buf_size: usize) -> std::io::Result<Datagram> {
82        let mut buf = vec![0_u8; buf_size];
83        let (size, peer) = self.socket.recv_from(&mut buf).await?;
84        buf.truncate(size);
85
86        Ok(Datagram {
87            payload: Bytes::from(buf),
88            meta: RecvMeta {
89                peer,
90                local_addr: self.local_addr().ok(),
91                interface: None,
92                group: self.primary_group(),
93                timestamp: Some(SystemTime::now()),
94            },
95        })
96    }
97
98    pub async fn send_to(&self, payload: &[u8], target: SocketAddr) -> std::io::Result<usize> {
99        self.socket.send_to(payload, target).await
100    }
101
102    pub async fn send_to_group(&self, payload: &[u8]) -> Result<usize> {
103        let group = self
104            .primary_group()
105            .ok_or(MulticastError::NoMembershipsConfigured)?;
106        let target = match group {
107            IpAddr::V4(group) => SocketAddr::from((group, self.config.bind_addr.port())),
108            IpAddr::V6(group) => {
109                let scope_id = self.ipv6_scope_id();
110                raw::group_as_v6_socket(group, self.config.bind_addr.port(), scope_id).into()
111            }
112        };
113        Ok(self.socket.send_to(payload, target).await?)
114    }
115
116    fn ipv6_scope_id(&self) -> u32 {
117        match self.config.outbound_interface.as_ref() {
118            Some(Interface::V6(index)) => *index,
119            _ => match self.config.inbound_interface.as_ref() {
120                Some(Interface::V6(index)) => *index,
121                _ => 0,
122            },
123        }
124    }
125
126    fn primary_group(&self) -> Option<IpAddr> {
127        self.config.memberships.first().map(Membership::group)
128    }
129
130    fn apply_join(&self, membership: &Membership) -> Result<()> {
131        match membership {
132            Membership::AnySource {
133                group: IpAddr::V4(group),
134            } => {
135                let interface = match &self.config.inbound_interface {
136                    Some(crate::Interface::V4(addr)) => *addr,
137                    _ => Ipv4Addr::UNSPECIFIED,
138                };
139                self.socket.join_multicast_v4(*group, interface)?;
140                Ok(())
141            }
142            Membership::AnySource {
143                group: IpAddr::V6(group),
144            } => {
145                let index = match &self.config.inbound_interface {
146                    Some(crate::Interface::V6(index)) => *index,
147                    _ => 0,
148                };
149                self.socket.join_multicast_v6(group, index)?;
150                Ok(())
151            }
152            Membership::SourceSpecific { .. } => Err(MulticastError::UnsupportedOption(
153                "dynamic source-specific membership",
154            )),
155        }
156    }
157
158    fn apply_leave(&self, membership: &Membership) -> Result<()> {
159        match membership {
160            Membership::AnySource {
161                group: IpAddr::V4(group),
162            } => {
163                let interface = match &self.config.inbound_interface {
164                    Some(crate::Interface::V4(addr)) => *addr,
165                    _ => Ipv4Addr::UNSPECIFIED,
166                };
167                self.socket.leave_multicast_v4(*group, interface)?;
168                Ok(())
169            }
170            Membership::AnySource {
171                group: IpAddr::V6(group),
172            } => {
173                let index = match &self.config.inbound_interface {
174                    Some(crate::Interface::V6(index)) => *index,
175                    _ => 0,
176                };
177                self.socket.leave_multicast_v6(group, index)?;
178                Ok(())
179            }
180            Membership::SourceSpecific { .. } => Err(MulticastError::UnsupportedOption(
181                "dynamic source-specific membership",
182            )),
183        }
184    }
185}