simple_mdns/
instance_information.rs1use std::{
2 collections::{HashMap, HashSet},
3 net::{IpAddr, SocketAddr},
4};
5
6use crate::conversion_utils::{hashmap_to_txt, ip_addr_to_resource_record, port_to_srv_record};
7use simple_dns::{Name, ResourceRecord};
8
/// Describes one service instance: its name, the IP addresses and ports it is reachable on,
/// and its key/value attributes.
///
/// Equality (derived) compares all four fields.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InstanceInformation {
    /// Name of the instance, as passed to `new` or recovered by `from_records`.
    instance_name: String,
    /// Every IP address the instance is reachable on.
    pub ip_addresses: HashSet<IpAddr>,
    /// Every port the instance listens on; addresses and ports are stored independently.
    pub ports: HashSet<u16>,
    /// Attributes carried in the instance's TXT record; a key may have no value.
    pub attributes: HashMap<String, Option<String>>,
}
21
22impl<'a> InstanceInformation {
23 pub fn new(instance_name: String) -> Self {
25 Self {
26 instance_name,
27 ip_addresses: Default::default(),
28 ports: Default::default(),
29 attributes: HashMap::new(),
30 }
31 }
32
33 pub fn with_socket_address(mut self, socket_address: SocketAddr) -> Self {
36 self.ip_addresses.insert(socket_address.ip());
37 self.ports.insert(socket_address.port());
38
39 self
40 }
41
42 pub fn with_ip_address(mut self, ip_address: IpAddr) -> Self {
44 self.ip_addresses.insert(ip_address);
45 self
46 }
47
48 pub fn with_port(mut self, port: u16) -> Self {
50 self.ports.insert(port);
51 self
52 }
53
54 pub fn with_attribute(mut self, key: String, value: Option<String>) -> Self {
56 self.attributes.insert(key, value);
57 self
58 }
59
60 pub fn escaped_instance_name(&self) -> String {
65 escaped_instance_name(self.instance_name.as_str())
66 }
67
68 pub fn unescaped_instance_name(&self) -> String {
73 unescaped_instance_name(self.instance_name.as_str())
74 }
75
76 #[allow(unused)]
77 pub(crate) fn from_records<'b>(
78 service_name: &Name<'b>,
79 records: impl Iterator<Item = &'b ResourceRecord<'b>>,
80 ) -> Option<Self> {
81 let mut ip_addresses: HashSet<IpAddr> = Default::default();
82 let mut ports = HashSet::new();
83 let mut attributes = HashMap::new();
84
85 let mut instance_name: Option<String> = Default::default();
86 for resource in records {
87 if instance_name.is_none() {
88 instance_name = resource
89 .name
90 .without(service_name)
91 .map(|sub_domain| sub_domain.to_string());
92 }
93
94 match &resource.rdata {
95 simple_dns::rdata::RData::A(a) => {
96 ip_addresses.insert(std::net::Ipv4Addr::from(a.address).into());
97 }
98 simple_dns::rdata::RData::AAAA(aaaa) => {
99 ip_addresses.insert(std::net::Ipv6Addr::from(aaaa.address).into());
100 }
101 simple_dns::rdata::RData::TXT(txt) => attributes.extend(txt.attributes()),
102 simple_dns::rdata::RData::SRV(srv) => {
103 ports.insert(srv.port);
104 }
105 _ => {}
106 }
107 }
108
109 instance_name.map(|instance_name| InstanceInformation {
110 instance_name,
111 ip_addresses,
112 ports,
113 attributes,
114 })
115 }
116
117 pub fn into_records(
119 self,
120 service_name: &Name<'a>,
121 ttl: u32,
122 ) -> Result<Vec<ResourceRecord<'a>>, crate::SimpleMdnsError> {
123 let mut records = Vec::new();
124
125 for ip_address in self.ip_addresses {
126 records.push(ip_addr_to_resource_record(service_name, ip_address, ttl));
127 }
128
129 for port in self.ports {
130 records.push(port_to_srv_record(service_name, port, ttl));
131 }
132
133 records.push(hashmap_to_txt(service_name, self.attributes, ttl)?);
134
135 Ok(records)
136 }
137
138 pub fn get_socket_addresses(&'_ self) -> impl Iterator<Item = SocketAddr> + '_ {
140 self.ip_addresses.iter().copied().flat_map(move |addr| {
141 self.ports
142 .iter()
143 .copied()
144 .map(move |port| SocketAddr::new(addr, port))
145 })
146 }
147}
148
149impl std::hash::Hash for InstanceInformation {
150 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
151 self.instance_name.hash(state);
152 self.ip_addresses.iter().for_each(|v| v.hash(state));
153 self.ports.iter().for_each(|v| v.hash(state));
154 }
155}
156
/// Escapes `instance_name` by inserting a backslash before every `.` and every `\`.
///
/// All other characters are copied through unchanged.
fn escaped_instance_name(instance_name: &str) -> String {
    let mut out = String::with_capacity(instance_name.len());

    for ch in instance_name.chars() {
        // Dots and backslashes are the only characters that need a leading escape.
        if ch == '.' || ch == '\\' {
            out.push('\\');
        }
        out.push(ch);
    }

    out
}
170
/// Removes backslash escapes from `instance_name`.
///
/// Each `\` is dropped and the character right after it is kept literally, so `\\` becomes `\`
/// and `\.` becomes `.`. A lone backslash at the very end of the input is dropped.
fn unescaped_instance_name(instance_name: &str) -> String {
    let mut out = String::with_capacity(instance_name.len());
    // True when the previous character was an unconsumed escaping backslash.
    let mut pending_escape = false;

    for ch in instance_name.chars() {
        if pending_escape {
            out.push(ch);
            pending_escape = false;
        } else if ch == '\\' {
            pending_escape = true;
        } else {
            out.push(ch);
        }
    }

    out
}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escaped_instance_name_simple() {
        let instance_name = "example.com";
        let expected_escaped_name = "example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_backslash() {
        let instance_name = "\\example.com";
        let expected_escaped_name = "\\\\example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_multiple_dots() {
        let instance_name = "foo.bar.baz";
        let expected_escaped_name = "foo\\.bar\\.baz";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_simple() {
        let instance_name = "example\\.com";
        let expected_unescaped_name = "example.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_with_multiple_slashes() {
        let instance_name = r#"example\\\.com"#;
        let expected_unescaped_name = "example\\.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    /// An escaped backslash at the start must come back as a single literal backslash.
    #[test]
    fn test_unescaped_instance_name_with_leading_escaped_backslash() {
        let instance_name = "\\\\example\\.com";
        let expected_unescaped_name = "\\example.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    /// Unescaping must exactly undo escaping, including names made of only special characters.
    #[test]
    fn test_escape_unescape_round_trip() {
        let names = [
            "",
            "plain",
            "a.b.c",
            "\\",
            "\\\\.",
            "trailing.",
            "My Printer\\Office.local",
            "café.local",
        ];

        for name in names {
            assert_eq!(unescaped_instance_name(&escaped_instance_name(name)), name);
        }
    }
}