Skip to main content

ssh_commander_core/tools/
dns.rs

1//! Multi-perspective DNS resolution.
2//!
3//! For a given hostname, query each connected SSH host's resolver plus
4//! the local Mac's resolver in parallel. Useful for spotting DNS
5//! drift — e.g. internal vs public split-horizon, or stale caches on a
6//! specific host.
7//!
//! Implementation: `dig +short <name> <type>` on the remote (cheap and
//! universally available). For A/AAAA, the local perspective uses
//! `tokio::net::lookup_host` so we don't shell out for the most common
//! path; other record types run `dig` locally as well.
11
12use crate::ssh::SshClient;
13use crate::tools::ToolsError;
14
/// What to resolve and how.
///
/// Passed by reference to `dns_resolve_local` and `dns_resolve_remote`.
#[derive(Debug, Clone)]
pub struct DnsQuery {
    /// Hostname to look up. The `dig`-based paths pass it through
    /// `shell_safe` first, so characters outside ASCII letters, digits,
    /// `.`, `-` and `_` are dropped before the lookup runs.
    pub name: String,
    /// Record type to request.
    pub record_type: DnsRecordType,
}
21
/// DNS record types the resolvers in this module can ask for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsRecordType {
    A,
    AAAA,
    CNAME,
    MX,
    TXT,
    NS,
}

impl DnsRecordType {
    /// Record-type token placed on the `dig` command line.
    ///
    /// The match is deliberately exhaustive so adding a variant fails to
    /// compile until its token is supplied here.
    fn dig_arg(self) -> &'static str {
        match self {
            Self::A => "A",
            Self::AAAA => "AAAA",
            Self::CNAME => "CNAME",
            Self::MX => "MX",
            Self::TXT => "TXT",
            Self::NS => "NS",
        }
    }
}
44
/// One perspective's answer.
///
/// Failures are carried in `error` instead of a `Result` so callers can
/// collect one `DnsAnswer` per perspective even when some of them fail.
#[derive(Debug, Clone)]
pub struct DnsAnswer {
    /// Where the lookup ran: `"local"` for `dns_resolve_local`, or the
    /// caller-supplied label for `dns_resolve_remote`.
    pub perspective: String,
    /// Copy of `DnsQuery::name` exactly as the caller supplied it.
    pub query: String,
    /// Record type that was requested.
    pub record_type: DnsRecordType,
    /// Resolved values: IP address strings from the local address lookup,
    /// or the non-empty, non-comment lines printed by `dig +short`.
    pub answers: Vec<String>,
    /// Human-readable failure description set by the resolver functions;
    /// `None` when no error was recorded.
    pub error: Option<String>,
    /// Wall-clock milliseconds from the start of the resolve call until
    /// this answer was built.
    pub elapsed_ms: u64,
}
55
56/// Resolve from the local Mac. `tokio::net::lookup_host` only handles A/AAAA;
57/// non-address record types fall back to spawning `dig` locally.
58pub async fn dns_resolve_local(query: &DnsQuery) -> DnsAnswer {
59    let started = std::time::Instant::now();
60    let perspective = "local".to_string();
61    match query.record_type {
62        DnsRecordType::A | DnsRecordType::AAAA => {
63            // lookup_host wants host:port — use a dummy port; we throw it away.
64            let target = format!("{}:0", query.name);
65            match tokio::net::lookup_host(target).await {
66                Ok(iter) => {
67                    let want_v6 = query.record_type == DnsRecordType::AAAA;
68                    let answers: Vec<String> = iter
69                        .filter(|sa| sa.is_ipv6() == want_v6)
70                        .map(|sa| sa.ip().to_string())
71                        .collect();
72                    DnsAnswer {
73                        perspective,
74                        query: query.name.clone(),
75                        record_type: query.record_type,
76                        answers,
77                        error: None,
78                        elapsed_ms: started.elapsed().as_millis() as u64,
79                    }
80                }
81                Err(e) => DnsAnswer {
82                    perspective,
83                    query: query.name.clone(),
84                    record_type: query.record_type,
85                    answers: vec![],
86                    error: Some(e.to_string()),
87                    elapsed_ms: started.elapsed().as_millis() as u64,
88                },
89            }
90        }
91        _ => {
92            let dig_cmd = format!(
93                "dig +short {} {}",
94                shell_safe(&query.name),
95                query.record_type.dig_arg()
96            );
97            match tokio::process::Command::new("sh")
98                .arg("-c")
99                .arg(&dig_cmd)
100                .output()
101                .await
102            {
103                Ok(out) if out.status.success() => DnsAnswer {
104                    perspective,
105                    query: query.name.clone(),
106                    record_type: query.record_type,
107                    answers: parse_dig_lines(&String::from_utf8_lossy(&out.stdout)),
108                    error: None,
109                    elapsed_ms: started.elapsed().as_millis() as u64,
110                },
111                Ok(out) => DnsAnswer {
112                    perspective,
113                    query: query.name.clone(),
114                    record_type: query.record_type,
115                    answers: vec![],
116                    error: Some(String::from_utf8_lossy(&out.stderr).to_string()),
117                    elapsed_ms: started.elapsed().as_millis() as u64,
118                },
119                Err(e) => DnsAnswer {
120                    perspective,
121                    query: query.name.clone(),
122                    record_type: query.record_type,
123                    answers: vec![],
124                    error: Some(e.to_string()),
125                    elapsed_ms: started.elapsed().as_millis() as u64,
126                },
127            }
128        }
129    }
130}
131
132/// Resolve from one remote SSH host using `dig +short`.
133///
134/// `perspective_label` is what the UI shows in the table — typically the
135/// connection's display name or `<user>@<host>`.
136pub async fn dns_resolve_remote(
137    client: &SshClient,
138    perspective_label: &str,
139    query: &DnsQuery,
140) -> DnsAnswer {
141    let started = std::time::Instant::now();
142    let cmd = format!(
143        "dig +short {} {}",
144        shell_safe(&query.name),
145        query.record_type.dig_arg()
146    );
147    match client.execute_command_full(&cmd).await {
148        Ok(out) if out.is_success() => DnsAnswer {
149            perspective: perspective_label.to_string(),
150            query: query.name.clone(),
151            record_type: query.record_type,
152            answers: parse_dig_lines(&out.stdout),
153            error: None,
154            elapsed_ms: started.elapsed().as_millis() as u64,
155        },
156        Ok(out) => {
157            // `dig` returns non-zero when it can't reach a resolver but
158            // also when the answer set is empty on certain platforms.
159            // Treat empty stdout + empty stderr as "no answer" rather
160            // than an error so the UI shows a clean blank cell.
161            let answers = parse_dig_lines(&out.stdout);
162            let err = if answers.is_empty() && !out.stderr.trim().is_empty() {
163                Some(out.stderr.trim().to_string())
164            } else if answers.is_empty() && out.exit_code.unwrap_or(0) != 0 {
165                Some(format!(
166                    "dig exited {}",
167                    out.exit_code
168                        .map(|c| c.to_string())
169                        .unwrap_or_else(|| "?".into())
170                ))
171            } else {
172                None
173            };
174            DnsAnswer {
175                perspective: perspective_label.to_string(),
176                query: query.name.clone(),
177                record_type: query.record_type,
178                answers,
179                error: err,
180                elapsed_ms: started.elapsed().as_millis() as u64,
181            }
182        }
183        Err(e) => DnsAnswer {
184            perspective: perspective_label.to_string(),
185            query: query.name.clone(),
186            record_type: query.record_type,
187            answers: vec![],
188            error: Some(e.to_string()),
189            elapsed_ms: started.elapsed().as_millis() as u64,
190        },
191    }
192}
193
/// Turn raw `dig +short` output into a list of answers.
///
/// Each line is trimmed; blank lines and lines starting with `;` (dig's
/// comment/diagnostic marker) are dropped. Order is preserved.
fn parse_dig_lines(out: &str) -> Vec<String> {
    let mut records = Vec::new();
    for raw in out.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with(';') {
            continue;
        }
        records.push(line.to_owned());
    }
    records
}
201
/// Sanitize a hostname so it can be placed on a `dig` command line.
///
/// This does not reject anything: it silently *strips* every character
/// other than ASCII letters, digits, `.`, `-` and `_`, which removes all
/// shell metacharacters and whitespace. Callers therefore look up the
/// filtered name, which may differ from the input (and may be empty).
///
/// Leading `-` characters are dropped as well. A hostname label cannot
/// start with a hyphen, and an argument that does would be parsed by `dig`
/// as a command-line option (e.g. `-f<file>`, `-x<addr>`) instead of a name.
fn shell_safe(name: &str) -> String {
    name.chars()
        .filter(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_'))
        // Applied after filtering so input like "; -x" cannot smuggle a
        // leading hyphen past the check.
        .skip_while(|c| *c == '-')
        .collect()
}
209
// `ToolsError` is imported at the top of this file, but the public
// resolvers above return `DnsAnswer` directly (errors live inside the
// answer struct so multi-host fan-out can collect partial results).
// This identity function exists only so the type is still referenced from
// this module — per the original note, for FFI parity. It returns its
// argument unchanged, and `dead_code` is allowed because nothing visible in
// this file calls it.
#[allow(dead_code)]
fn _ensure_error_unused(e: ToolsError) -> ToolsError {
    e
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// dig's `;;` diagnostic lines must not be reported as answers.
    #[test]
    fn parses_dig_short_output() {
        let dig_output = "1.2.3.4\n5.6.7.8\n;; Truncated, retrying in TCP mode.\n";
        assert_eq!(parse_dig_lines(dig_output), vec!["1.2.3.4", "5.6.7.8"]);
    }

    /// Shell metacharacters and whitespace are removed; valid names pass through.
    #[test]
    fn shell_safe_strips_metachars() {
        let cases = [
            ("example.com", "example.com"),
            ("ex; rm -rf /", "exrm-rf"),
            ("foo$(bad)", "foobad"),
        ];
        for (input, expected) in cases {
            assert_eq!(shell_safe(input), expected);
        }
    }
}