// simple_mdns/instance_information.rs
use std::{
    collections::{HashMap, HashSet},
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};

use crate::conversion_utils::{hashmap_to_txt, ip_addr_to_resource_record, port_to_srv_record};
use simple_dns::{Name, ResourceRecord};

/// A single service instance: its name, the IP addresses and ports associated
/// with it, and its key/value attributes.
///
/// `InstanceInformation::into_records` turns the attributes into a TXT record, and
/// `InstanceInformation::from_records` fills them back from TXT records.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InstanceInformation {
    /// Instance name as given to `new` or extracted from record names.
    /// Read it through `escaped_instance_name` / `unescaped_instance_name`.
    instance_name: String,
    /// IP addresses of the instance (filled from A/AAAA records when parsed).
    pub ip_addresses: HashSet<IpAddr>,
    /// Ports of the instance (filled from SRV records when parsed).
    pub ports: HashSet<u16>,
    /// Key/value attributes; a `None` value presumably marks a key without a value.
    pub attributes: HashMap<String, Option<String>>,
}
22impl<'a> InstanceInformation {
23 pub fn new(instance_name: String) -> Self {
25 Self {
26 instance_name,
27 ip_addresses: Default::default(),
28 ports: Default::default(),
29 attributes: HashMap::new(),
30 }
31 }
32
33 pub fn with_socket_address(mut self, socket_address: SocketAddr) -> Self {
36 self.ip_addresses.insert(socket_address.ip());
37 self.ports.insert(socket_address.port());
38
39 self
40 }
41
42 pub fn with_ip_address(mut self, ip_address: IpAddr) -> Self {
44 self.ip_addresses.insert(ip_address);
45 self
46 }
47
48 pub fn with_port(mut self, port: u16) -> Self {
50 self.ports.insert(port);
51 self
52 }
53
54 pub fn with_attribute(mut self, key: String, value: Option<String>) -> Self {
56 self.attributes.insert(key, value);
57 self
58 }
59
60 pub fn escaped_instance_name(&self) -> String {
65 escaped_instance_name(self.instance_name.as_str())
66 }
67
68 pub fn unescaped_instance_name(&self) -> String {
73 unescaped_instance_name(self.instance_name.as_str())
74 }
75
76 pub(crate) fn from_records<'b>(
77 service_name: &Name<'b>,
78 records: impl Iterator<Item = &'b ResourceRecord<'b>>,
79 ) -> Option<Self> {
80 let mut ip_addresses: HashSet<IpAddr> = Default::default();
81 let mut ports = HashSet::new();
82 let mut attributes = HashMap::new();
83
84 let mut instance_name: Option<String> = Default::default();
85 for resource in records {
86 if instance_name.is_none() {
87 instance_name = resource
88 .name
89 .without(service_name)
90 .map(|sub_domain| sub_domain.to_string());
91 }
92
93 match &resource.rdata {
94 simple_dns::rdata::RData::A(a) => {
95 ip_addresses.insert(std::net::Ipv4Addr::from(a.address).into());
96 }
97 simple_dns::rdata::RData::AAAA(aaaa) => {
98 ip_addresses.insert(std::net::Ipv6Addr::from(aaaa.address).into());
99 }
100 simple_dns::rdata::RData::TXT(txt) => attributes.extend(txt.attributes()),
101 simple_dns::rdata::RData::SRV(srv) => {
102 ports.insert(srv.port);
103 }
104 _ => {}
105 }
106 }
107
108 instance_name.map(|instance_name| InstanceInformation {
109 instance_name,
110 ip_addresses,
111 ports,
112 attributes,
113 })
114 }
115
116 pub fn into_records(
118 self,
119 service_name: &Name<'a>,
120 ttl: u32,
121 ) -> Result<Vec<ResourceRecord<'a>>, crate::SimpleMdnsError> {
122 let mut records = Vec::new();
123
124 for ip_address in self.ip_addresses {
125 records.push(ip_addr_to_resource_record(service_name, ip_address, ttl));
126 }
127
128 for port in self.ports {
129 records.push(port_to_srv_record(service_name, port, ttl));
130 }
131
132 records.push(hashmap_to_txt(service_name, self.attributes, ttl)?);
133
134 Ok(records)
135 }
136
137 pub fn get_socket_addresses(&'_ self) -> impl Iterator<Item = SocketAddr> + '_ {
139 self.ip_addresses.iter().copied().flat_map(move |addr| {
140 self.ports
141 .iter()
142 .copied()
143 .map(move |port| SocketAddr::new(addr, port))
144 })
145 }
146}
147
148impl std::hash::Hash for InstanceInformation {
149 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
150 self.instance_name.hash(state);
151 self.ip_addresses.iter().for_each(|v| v.hash(state));
152 self.ports.iter().for_each(|v| v.hash(state));
153 }
154}
155
/// Escapes `instance_name` by prefixing every `.` and every `\` with a `\`.
///
/// All other characters are copied unchanged, e.g. `foo.bar` becomes `foo\.bar`.
fn escaped_instance_name(instance_name: &str) -> String {
    let mut out = String::with_capacity(instance_name.len());
    for ch in instance_name.chars() {
        // Both the dot and the escape character itself get a backslash prefix.
        if matches!(ch, '.' | '\\') {
            out.push('\\');
        }
        out.push(ch);
    }
    out
}
/// Removes backslash escapes from `instance_name`.
///
/// Each `\` is dropped and the character after it is kept literally, so this undoes
/// `escaped_instance_name`. A trailing `\` with nothing after it is discarded.
fn unescaped_instance_name(instance_name: &str) -> String {
    let mut out = String::with_capacity(instance_name.len());
    let mut pending_escape = false;
    for ch in instance_name.chars() {
        if pending_escape {
            // The character after a backslash is taken verbatim, even another backslash.
            out.push(ch);
            pending_escape = false;
        } else if ch == '\\' {
            pending_escape = true;
        } else {
            out.push(ch);
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escaped_instance_name_simple() {
        let instance_name = "example.com";
        let expected_escaped_name = "example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_backslash() {
        let instance_name = "\\example.com";
        let expected_escaped_name = "\\\\example\\.com";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_escaped_instance_name_with_multiple_dots() {
        let instance_name = "foo.bar.baz";
        let expected_escaped_name = "foo\\.bar\\.baz";

        let escaped_name = escaped_instance_name(instance_name);

        assert_eq!(escaped_name, expected_escaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_simple() {
        let instance_name = "example\\.com";
        let expected_unescaped_name = "example.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_with_multiple_slashes() {
        let instance_name = r#"example\\\.com"#;
        let expected_unescaped_name = "example\\.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_unescaped_instance_name_with_escaped_backslash() {
        // `\\` unescapes to a single literal backslash.
        let instance_name = "\\\\example\\.com";
        let expected_unescaped_name = "\\example.com";

        let unescaped_name = unescaped_instance_name(instance_name);

        assert_eq!(unescaped_name, expected_unescaped_name);
    }

    #[test]
    fn test_escape_unescape_round_trip() {
        // Unescaping an escaped name must give back the original, including
        // names that already contain backslashes and dots.
        for &name in &["", "plain", "foo.bar", "back\\slash", "mixed\\.both.\\"] {
            assert_eq!(unescaped_instance_name(&escaped_instance_name(name)), name);
        }
    }
}