ssh_commander_core/tools/
dns.rs1use crate::ssh::SshClient;
13use crate::tools::ToolsError;
14
15#[derive(Debug, Clone)]
17pub struct DnsQuery {
18 pub name: String,
19 pub record_type: DnsRecordType,
20}
21
/// DNS record types this module can look up.
///
/// `Hash` is derived (alongside `Eq`) so record types can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DnsRecordType {
    /// `A` record; the local resolver path keeps only IPv4 results for it.
    A,
    /// `AAAA` record; the local resolver path keeps only IPv6 results for it.
    AAAA,
    /// `CNAME` record.
    CNAME,
    /// `MX` record.
    MX,
    /// `TXT` record.
    TXT,
    /// `NS` record.
    NS,
}

impl DnsRecordType {
    /// The record-type mnemonic handed to `dig` (e.g. `"MX"`).
    fn dig_arg(self) -> &'static str {
        match self {
            Self::A => "A",
            Self::AAAA => "AAAA",
            Self::CNAME => "CNAME",
            Self::MX => "MX",
            Self::TXT => "TXT",
            Self::NS => "NS",
        }
    }
}
44
45#[derive(Debug, Clone)]
47pub struct DnsAnswer {
48 pub perspective: String,
49 pub query: String,
50 pub record_type: DnsRecordType,
51 pub answers: Vec<String>,
52 pub error: Option<String>,
53 pub elapsed_ms: u64,
54}
55
56pub async fn dns_resolve_local(query: &DnsQuery) -> DnsAnswer {
59 let started = std::time::Instant::now();
60 let perspective = "local".to_string();
61 match query.record_type {
62 DnsRecordType::A | DnsRecordType::AAAA => {
63 let target = format!("{}:0", query.name);
65 match tokio::net::lookup_host(target).await {
66 Ok(iter) => {
67 let want_v6 = query.record_type == DnsRecordType::AAAA;
68 let answers: Vec<String> = iter
69 .filter(|sa| sa.is_ipv6() == want_v6)
70 .map(|sa| sa.ip().to_string())
71 .collect();
72 DnsAnswer {
73 perspective,
74 query: query.name.clone(),
75 record_type: query.record_type,
76 answers,
77 error: None,
78 elapsed_ms: started.elapsed().as_millis() as u64,
79 }
80 }
81 Err(e) => DnsAnswer {
82 perspective,
83 query: query.name.clone(),
84 record_type: query.record_type,
85 answers: vec![],
86 error: Some(e.to_string()),
87 elapsed_ms: started.elapsed().as_millis() as u64,
88 },
89 }
90 }
91 _ => {
92 let dig_cmd = format!(
93 "dig +short {} {}",
94 shell_safe(&query.name),
95 query.record_type.dig_arg()
96 );
97 match tokio::process::Command::new("sh")
98 .arg("-c")
99 .arg(&dig_cmd)
100 .output()
101 .await
102 {
103 Ok(out) if out.status.success() => DnsAnswer {
104 perspective,
105 query: query.name.clone(),
106 record_type: query.record_type,
107 answers: parse_dig_lines(&String::from_utf8_lossy(&out.stdout)),
108 error: None,
109 elapsed_ms: started.elapsed().as_millis() as u64,
110 },
111 Ok(out) => DnsAnswer {
112 perspective,
113 query: query.name.clone(),
114 record_type: query.record_type,
115 answers: vec![],
116 error: Some(String::from_utf8_lossy(&out.stderr).to_string()),
117 elapsed_ms: started.elapsed().as_millis() as u64,
118 },
119 Err(e) => DnsAnswer {
120 perspective,
121 query: query.name.clone(),
122 record_type: query.record_type,
123 answers: vec![],
124 error: Some(e.to_string()),
125 elapsed_ms: started.elapsed().as_millis() as u64,
126 },
127 }
128 }
129 }
130}
131
132pub async fn dns_resolve_remote(
137 client: &SshClient,
138 perspective_label: &str,
139 query: &DnsQuery,
140) -> DnsAnswer {
141 let started = std::time::Instant::now();
142 let cmd = format!(
143 "dig +short {} {}",
144 shell_safe(&query.name),
145 query.record_type.dig_arg()
146 );
147 match client.execute_command_full(&cmd).await {
148 Ok(out) if out.is_success() => DnsAnswer {
149 perspective: perspective_label.to_string(),
150 query: query.name.clone(),
151 record_type: query.record_type,
152 answers: parse_dig_lines(&out.stdout),
153 error: None,
154 elapsed_ms: started.elapsed().as_millis() as u64,
155 },
156 Ok(out) => {
157 let answers = parse_dig_lines(&out.stdout);
162 let err = if answers.is_empty() && !out.stderr.trim().is_empty() {
163 Some(out.stderr.trim().to_string())
164 } else if answers.is_empty() && out.exit_code.unwrap_or(0) != 0 {
165 Some(format!(
166 "dig exited {}",
167 out.exit_code
168 .map(|c| c.to_string())
169 .unwrap_or_else(|| "?".into())
170 ))
171 } else {
172 None
173 };
174 DnsAnswer {
175 perspective: perspective_label.to_string(),
176 query: query.name.clone(),
177 record_type: query.record_type,
178 answers,
179 error: err,
180 elapsed_ms: started.elapsed().as_millis() as u64,
181 }
182 }
183 Err(e) => DnsAnswer {
184 perspective: perspective_label.to_string(),
185 query: query.name.clone(),
186 record_type: query.record_type,
187 answers: vec![],
188 error: Some(e.to_string()),
189 elapsed_ms: started.elapsed().as_millis() as u64,
190 },
191 }
192}
193
/// Extracts answer lines from `dig +short` output.
///
/// Every line is trimmed; blank lines and lines starting with `;` (dig's own
/// diagnostics, such as `;; Truncated, retrying in TCP mode.`) are dropped.
fn parse_dig_lines(out: &str) -> Vec<String> {
    let mut records = Vec::new();
    for raw in out.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with(';') {
            continue;
        }
        records.push(line.to_owned());
    }
    records
}
201
/// Reduces `name` to characters that are safe to splice unquoted into a shell
/// command: ASCII letters and digits, `.`, `-` and `_`.
///
/// Everything else is silently removed rather than escaped. That includes
/// non-ASCII characters, so internationalized names must be passed in their
/// ASCII (punycode) form.
fn shell_safe(name: &str) -> String {
    let mut cleaned = name.to_owned();
    cleaned.retain(|ch| ch.is_ascii_alphanumeric() || ch == '.' || ch == '-' || ch == '_');
    cleaned
}
209
210#[allow(dead_code)]
214fn _ensure_error_unused(e: ToolsError) -> ToolsError {
215 e
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_dig_short_output() {
        let raw = "1.2.3.4\n5.6.7.8\n;; Truncated, retrying in TCP mode.\n";
        let parsed = parse_dig_lines(raw);
        assert_eq!(parsed, vec!["1.2.3.4", "5.6.7.8"]);
    }

    #[test]
    fn parse_dig_lines_handles_empty_padded_and_crlf_output() {
        assert!(parse_dig_lines("").is_empty());
        assert!(parse_dig_lines("\n  \n;; comment only\n").is_empty());
        // TXT answers keep their quotes; MX answers keep preference + target.
        assert_eq!(
            parse_dig_lines("  10 mail.example.com.  \r\n\r\n\"v=spf1 -all\"\r\n"),
            vec!["10 mail.example.com.", "\"v=spf1 -all\""]
        );
    }

    #[test]
    fn shell_safe_strips_metachars() {
        assert_eq!(shell_safe("example.com"), "example.com");
        assert_eq!(shell_safe("ex; rm -rf /"), "exrm-rf");
        assert_eq!(shell_safe("foo$(bad)"), "foobad");
    }

    #[test]
    fn shell_safe_keeps_label_chars_and_drops_everything_else() {
        assert_eq!(shell_safe("_dmarc.my-site.example"), "_dmarc.my-site.example");
        assert_eq!(shell_safe("'quoted' name"), "quotedname");
        // Non-ASCII is removed, not transliterated: callers must pass punycode.
        assert_eq!(shell_safe("bücher.de"), "bcher.de");
        assert_eq!(shell_safe(""), "");
    }
}