simple_mdns/
instance_information.rs

1use std::{
2    collections::{HashMap, HashSet},
3    net::{IpAddr, SocketAddr},
4};
5
6use crate::conversion_utils::{hashmap_to_txt, ip_addr_to_resource_record, port_to_srv_record};
7use simple_dns::{Name, ResourceRecord};
8
/// Describes one instance of a service: its name, the ip addresses and ports it is reachable
/// on, and its attributes.
///
/// Ip addresses and ports are stored as two independent sets. It is not possible to associate
/// a port to a single ip address, due to limitations of the DNS protocol.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InstanceInformation {
    // Name of the instance; private, exposed through the escaped/unescaped accessors.
    instance_name: String,
    /// Set of ip addresses of this instance
    pub ip_addresses: HashSet<IpAddr>,
    /// Set of ports of this instance
    pub ports: HashSet<u16>,
    /// Attributes of this instance, keyed by attribute name; the value is optional
    pub attributes: HashMap<String, Option<String>>,
}
21
22impl<'a> InstanceInformation {
23    /// Creates an empty InstanceInformation
24    pub fn new(instance_name: String) -> Self {
25        Self {
26            instance_name,
27            ip_addresses: Default::default(),
28            ports: Default::default(),
29            attributes: HashMap::new(),
30        }
31    }
32
33    /// Adds the `ip_address` and `port` to this instance information. This is the equivalent of
34    /// `with_ip_address(ip_address).with_port(port)`
35    pub fn with_socket_address(mut self, socket_address: SocketAddr) -> Self {
36        self.ip_addresses.insert(socket_address.ip());
37        self.ports.insert(socket_address.port());
38
39        self
40    }
41
42    /// Adds `ip_address` to the list of ip addresses for this instance
43    pub fn with_ip_address(mut self, ip_address: IpAddr) -> Self {
44        self.ip_addresses.insert(ip_address);
45        self
46    }
47
48    /// Adds `port` to the list of ports for this instance
49    pub fn with_port(mut self, port: u16) -> Self {
50        self.ports.insert(port);
51        self
52    }
53
54    /// Add and attribute to the list of attributes
55    pub fn with_attribute(mut self, key: String, value: Option<String>) -> Self {
56        self.attributes.insert(key, value);
57        self
58    }
59
60    /// Escape the instance name
61    ///
62    /// . will be replaced with \.
63    /// \ will be replaced with \\
64    pub fn escaped_instance_name(&self) -> String {
65        escaped_instance_name(self.instance_name.as_str())
66    }
67
68    /// Unescape the instance name
69    ///
70    /// \. will be replaced with .
71    /// \\ will be replaced with \
72    pub fn unescaped_instance_name(&self) -> String {
73        unescaped_instance_name(self.instance_name.as_str())
74    }
75
76    pub(crate) fn from_records<'b>(
77        service_name: &Name<'b>,
78        records: impl Iterator<Item = &'b ResourceRecord<'b>>,
79    ) -> Option<Self> {
80        let mut ip_addresses: HashSet<IpAddr> = Default::default();
81        let mut ports = HashSet::new();
82        let mut attributes = HashMap::new();
83
84        let mut instance_name: Option<String> = Default::default();
85        for resource in records {
86            if instance_name.is_none() {
87                instance_name = resource
88                    .name
89                    .without(service_name)
90                    .map(|sub_domain| sub_domain.to_string());
91            }
92
93            match &resource.rdata {
94                simple_dns::rdata::RData::A(a) => {
95                    ip_addresses.insert(std::net::Ipv4Addr::from(a.address).into());
96                }
97                simple_dns::rdata::RData::AAAA(aaaa) => {
98                    ip_addresses.insert(std::net::Ipv6Addr::from(aaaa.address).into());
99                }
100                simple_dns::rdata::RData::TXT(txt) => attributes.extend(txt.attributes()),
101                simple_dns::rdata::RData::SRV(srv) => {
102                    ports.insert(srv.port);
103                }
104                _ => {}
105            }
106        }
107
108        instance_name.map(|instance_name| InstanceInformation {
109            instance_name,
110            ip_addresses,
111            ports,
112            attributes,
113        })
114    }
115
116    /// Transform into a [Vec<ResourceRecord>](`Vec<ResourceRecord>`)
117    pub fn into_records(
118        self,
119        service_name: &Name<'a>,
120        ttl: u32,
121    ) -> Result<Vec<ResourceRecord<'a>>, crate::SimpleMdnsError> {
122        let mut records = Vec::new();
123
124        for ip_address in self.ip_addresses {
125            records.push(ip_addr_to_resource_record(service_name, ip_address, ttl));
126        }
127
128        for port in self.ports {
129            records.push(port_to_srv_record(service_name, port, ttl));
130        }
131
132        records.push(hashmap_to_txt(service_name, self.attributes, ttl)?);
133
134        Ok(records)
135    }
136
137    /// Creates a Iterator of [`SocketAddr`](`std::net::SocketAddr`) for each ip address and port combination
138    pub fn get_socket_addresses(&'_ self) -> impl Iterator<Item = SocketAddr> + '_ {
139        self.ip_addresses.iter().copied().flat_map(move |addr| {
140            self.ports
141                .iter()
142                .copied()
143                .map(move |port| SocketAddr::new(addr, port))
144        })
145    }
146}
147
148impl std::hash::Hash for InstanceInformation {
149    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
150        self.instance_name.hash(state);
151        self.ip_addresses.iter().for_each(|v| v.hash(state));
152        self.ports.iter().for_each(|v| v.hash(state));
153    }
154}
155
/// Returns a copy of `instance_name` in which every `.` and every `\` is preceded by a `\`.
/// All other characters are copied unchanged.
fn escaped_instance_name(instance_name: &str) -> String {
    instance_name
        .chars()
        .fold(String::new(), |mut escaped, character| {
            if matches!(character, '.' | '\\') {
                escaped.push('\\');
            }
            escaped.push(character);
            escaped
        })
}
169
/// Returns a copy of `instance_name` with one level of `\` escaping removed.
///
/// A `\` is dropped and the character right after it is copied as is, even when that
/// character is itself a `\`. A `\` that is the very last character has nothing to escape
/// and is dropped.
fn unescaped_instance_name(instance_name: &str) -> String {
    let mut result = String::new();
    // True when the previous character was an escaping backslash.
    let mut pending_escape = false;

    for character in instance_name.chars() {
        if character == '\\' && !pending_escape {
            pending_escape = true;
            continue;
        }

        result.push(character);
        pending_escape = false;
    }

    result
}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// A single dot gets escaped.
    #[test]
    fn test_escaped_instance_name_simple() {
        assert_eq!(escaped_instance_name("example.com"), "example\\.com");
    }

    /// Backslashes and dots are both escaped.
    #[test]
    fn test_escaped_instance_name_with_backslash() {
        assert_eq!(
            escaped_instance_name("\\example.com"),
            "\\\\example\\.com"
        );
    }

    /// Every dot of the name is escaped, not only the first one.
    #[test]
    fn test_escaped_instance_name_with_multiple_dots() {
        assert_eq!(escaped_instance_name("foo.bar.baz"), "foo\\.bar\\.baz");
    }

    /// An escaped dot is turned back into a plain dot.
    #[test]
    fn test_unescaped_instance_name_simple() {
        assert_eq!(unescaped_instance_name("example\\.com"), "example.com");
    }

    /// An escaped backslash followed by an escaped dot yields a backslash and a dot.
    #[test]
    fn test_unescaped_instance_name_with_multiple_slashes() {
        assert_eq!(
            unescaped_instance_name(r#"example\\\.com"#),
            "example\\.com"
        );
    }
}
241}