Skip to main content

simple_mdns/
instance_information.rs

1use std::{
2    collections::{HashMap, HashSet},
3    net::{IpAddr, SocketAddr},
4};
5
6use crate::conversion_utils::{hashmap_to_txt, ip_addr_to_resource_record, port_to_srv_record};
7use simple_dns::{Name, ResourceRecord};
8
/// A single instance of a service: its name together with its ip addresses, ports and
/// attributes.
///
/// Ports are kept separately from ip addresses, so a port cannot be bound to one specific
/// ip address; the DNS protocol offers no way to express that association.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InstanceInformation {
    /// Name of this instance
    instance_name: String,
    /// Ip addresses of this instance
    pub ip_addresses: HashSet<IpAddr>,
    /// Ports of this instance
    pub ports: HashSet<u16>,
    /// Attributes of this instance; a key may be present without a value
    pub attributes: HashMap<String, Option<String>>,
}
21
22impl<'a> InstanceInformation {
23    /// Creates an empty InstanceInformation
24    pub fn new(instance_name: String) -> Self {
25        Self {
26            instance_name,
27            ip_addresses: Default::default(),
28            ports: Default::default(),
29            attributes: HashMap::new(),
30        }
31    }
32
33    /// Adds the `ip_address` and `port` to this instance information. This is the equivalent of
34    /// `with_ip_address(ip_address).with_port(port)`
35    pub fn with_socket_address(mut self, socket_address: SocketAddr) -> Self {
36        self.ip_addresses.insert(socket_address.ip());
37        self.ports.insert(socket_address.port());
38
39        self
40    }
41
42    /// Adds `ip_address` to the list of ip addresses for this instance
43    pub fn with_ip_address(mut self, ip_address: IpAddr) -> Self {
44        self.ip_addresses.insert(ip_address);
45        self
46    }
47
48    /// Adds `port` to the list of ports for this instance
49    pub fn with_port(mut self, port: u16) -> Self {
50        self.ports.insert(port);
51        self
52    }
53
54    /// Add and attribute to the list of attributes
55    pub fn with_attribute(mut self, key: String, value: Option<String>) -> Self {
56        self.attributes.insert(key, value);
57        self
58    }
59
60    /// Escape the instance name
61    ///
62    /// . will be replaced with \.
63    /// \ will be replaced with \\
64    pub fn escaped_instance_name(&self) -> String {
65        escaped_instance_name(self.instance_name.as_str())
66    }
67
68    /// Unescape the instance name
69    ///
70    /// \. will be replaced with .
71    /// \\ will be replaced with \
72    pub fn unescaped_instance_name(&self) -> String {
73        unescaped_instance_name(self.instance_name.as_str())
74    }
75
76    #[allow(unused)]
77    pub(crate) fn from_records<'b>(
78        service_name: &Name<'b>,
79        records: impl Iterator<Item = &'b ResourceRecord<'b>>,
80    ) -> Option<Self> {
81        let mut ip_addresses: HashSet<IpAddr> = Default::default();
82        let mut ports = HashSet::new();
83        let mut attributes = HashMap::new();
84
85        let mut instance_name: Option<String> = Default::default();
86        for resource in records {
87            if instance_name.is_none() {
88                instance_name = resource
89                    .name
90                    .without(service_name)
91                    .map(|sub_domain| sub_domain.to_string());
92            }
93
94            match &resource.rdata {
95                simple_dns::rdata::RData::A(a) => {
96                    ip_addresses.insert(std::net::Ipv4Addr::from(a.address).into());
97                }
98                simple_dns::rdata::RData::AAAA(aaaa) => {
99                    ip_addresses.insert(std::net::Ipv6Addr::from(aaaa.address).into());
100                }
101                simple_dns::rdata::RData::TXT(txt) => attributes.extend(txt.attributes()),
102                simple_dns::rdata::RData::SRV(srv) => {
103                    ports.insert(srv.port);
104                }
105                _ => {}
106            }
107        }
108
109        instance_name.map(|instance_name| InstanceInformation {
110            instance_name,
111            ip_addresses,
112            ports,
113            attributes,
114        })
115    }
116
117    /// Transform into a [`Vec<ResourceRecord>`](`Vec<ResourceRecord>`)
118    pub fn into_records(
119        self,
120        service_name: &Name<'a>,
121        ttl: u32,
122    ) -> Result<Vec<ResourceRecord<'a>>, crate::SimpleMdnsError> {
123        let mut records = Vec::new();
124
125        for ip_address in self.ip_addresses {
126            records.push(ip_addr_to_resource_record(service_name, ip_address, ttl));
127        }
128
129        for port in self.ports {
130            records.push(port_to_srv_record(service_name, port, ttl));
131        }
132
133        records.push(hashmap_to_txt(service_name, self.attributes, ttl)?);
134
135        Ok(records)
136    }
137
138    /// Creates a Iterator of [`SocketAddr`](`std::net::SocketAddr`) for each ip address and port combination
139    pub fn get_socket_addresses(&'_ self) -> impl Iterator<Item = SocketAddr> + '_ {
140        self.ip_addresses.iter().copied().flat_map(move |addr| {
141            self.ports
142                .iter()
143                .copied()
144                .map(move |port| SocketAddr::new(addr, port))
145        })
146    }
147}
148
149impl std::hash::Hash for InstanceInformation {
150    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
151        self.instance_name.hash(state);
152        self.ip_addresses.iter().for_each(|v| v.hash(state));
153        self.ports.iter().for_each(|v| v.hash(state));
154    }
155}
156
/// Escapes `instance_name` by prefixing every `.` and every `\` with a backslash, so `.`
/// becomes `\.` and `\` becomes `\\`. All other characters are copied unchanged.
fn escaped_instance_name(instance_name: &str) -> String {
    instance_name
        .chars()
        .fold(String::new(), |mut escaped, ch| {
            if matches!(ch, '.' | '\\') {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
170
/// Removes backslash escapes from `instance_name`.
///
/// A backslash followed by any character is replaced by that character, so `\.` becomes `.`
/// and `\\` becomes `\`. A dangling backslash at the very end has nothing to escape and is
/// kept literally instead of being silently dropped.
fn unescaped_instance_name(instance_name: &str) -> String {
    let mut unescaped_name = String::with_capacity(instance_name.len());
    let mut chars = instance_name.chars();

    // `while let` rather than `for`: an escape consumes the following char from the same iterator.
    while let Some(c) = chars.next() {
        if c == '\\' {
            unescaped_name.push(chars.next().unwrap_or('\\'));
        } else {
            unescaped_name.push(c);
        }
    }

    unescaped_name
}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escaped_instance_name_simple() {
        let instance_name = "example.com";
        let expected_escaped_name = "example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_backslash() {
        let instance_name = "\\example.com";
        let expected_escaped_name = "\\\\example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_multiple_dots() {
        let instance_name = "foo.bar.baz";
        let expected_escaped_name = "foo\\.bar\\.baz";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_simple() {
        let instance_name = "example\\.com";
        let expected_unescaped_name = "example.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_with_multiple_backslashes() {
        let instance_name = r#"example\\\.com"#;
        let expected_unescaped_name = "example\\.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_instance_name_escaping_of_empty_string() {
        assert_eq!(escaped_instance_name(""), "");
        assert_eq!(unescaped_instance_name(""), "");
    }

    /// Unescaping must exactly undo escaping, whatever mix of dots and backslashes the name has.
    #[test]
    fn test_unescaped_reverses_escaped() {
        for name in [
            "example.com",
            "\\example.com",
            "a\\.b",
            "..",
            "\\\\",
            "plain",
            "ünïcödé.name",
        ] {
            assert_eq!(unescaped_instance_name(&escaped_instance_name(name)), name);
        }
    }
}