Skip to main content

sentinel_proxy/acme/dns/
provider.rs

1//! DNS provider trait for DNS-01 challenges
2//!
3//! Defines the interface that all DNS providers must implement.
4
5use async_trait::async_trait;
6use std::fmt::Debug;
7use thiserror::Error;
8
9/// Result type for DNS operations
10pub type DnsResult<T> = Result<T, DnsProviderError>;
11
/// Errors that can occur during DNS provider operations
///
/// Each variant's `#[error(...)]` attribute (from `thiserror`) defines the text
/// produced by its `Display` implementation.
#[derive(Debug, Error)]
pub enum DnsProviderError {
    /// Authentication failed with the DNS provider
    ///
    /// The payload is a free-form detail message appended to the Display text.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// Zone not found for the domain
    ///
    /// `domain` is the name for which no managing zone could be located.
    #[error("Zone not found for domain '{domain}'")]
    ZoneNotFound { domain: String },

    /// Record creation failed
    ///
    /// `record_name` is the challenge record name that was being created;
    /// `message` carries the underlying failure detail.
    #[error("Failed to create TXT record for '{record_name}': {message}")]
    RecordCreation {
        record_name: String,
        message: String,
    },

    /// Record deletion failed
    ///
    /// `record_id` is the provider record ID that could not be deleted
    /// (the value previously returned from `create_txt_record`).
    #[error("Failed to delete TXT record '{record_id}': {message}")]
    RecordDeletion { record_id: String, message: String },

    /// API request failed
    ///
    /// Generic transport/API failure not covered by a more specific variant.
    #[error("API request failed: {0}")]
    ApiRequest(String),

    /// Rate limited by provider
    ///
    /// `retry_after_secs` is a delay in seconds. NOTE(review): presumably taken
    /// from the provider's rate-limit response when available — confirm per
    /// implementation.
    #[error("Rate limited by DNS provider, retry after {retry_after_secs}s")]
    RateLimited { retry_after_secs: u64 },

    /// Request timeout
    ///
    /// `elapsed_secs` is the time, in seconds, spent before giving up.
    #[error("Request timed out after {elapsed_secs}s")]
    Timeout { elapsed_secs: u64 },

    /// Invalid configuration
    #[error("Invalid configuration: {0}")]
    Configuration(String),

    /// Credential loading failed
    #[error("Failed to load credentials: {0}")]
    Credentials(String),

    /// Domain not supported by this provider
    ///
    /// `domain` is the name the provider reported it cannot manage.
    #[error("Domain '{domain}' is not supported by this provider")]
    UnsupportedDomain { domain: String },
}
58
/// Trait for DNS providers that support DNS-01 challenges
///
/// Implementations must be thread-safe and support concurrent operations
/// (the trait requires `Send + Sync`). All fallible operations report failures
/// as [`DnsProviderError`] via [`DnsResult`].
#[async_trait]
pub trait DnsProvider: Send + Sync + Debug {
    /// Returns the provider name (e.g., "hetzner", "cloudflare")
    fn name(&self) -> &'static str;

    /// Create a TXT record for DNS-01 challenge
    ///
    /// # Arguments
    ///
    /// * `domain` - The full domain name (e.g., "example.com" or "sub.example.com")
    /// * `record_name` - The challenge record name (typically "_acme-challenge",
    ///   see [`ACME_CHALLENGE_RECORD`])
    /// * `record_value` - The challenge value (base64url-encoded digest)
    ///
    /// # Returns
    ///
    /// The record ID for later cleanup, or an error
    ///
    /// # Implementation Notes
    ///
    /// - The full record name should be `{record_name}.{domain}`
    /// - NOTE(review): for a wildcard name such as `*.example.com` that
    ///   composition would be wrong; presumably callers pass a name already
    ///   stripped via [`normalize_domain`] (as [`challenge_record_fqdn`] does) —
    ///   confirm which side is responsible.
    /// - Use a short TTL ([`CHALLENGE_TTL`], 60s recommended) for challenge records
    /// - If a record already exists, either update it or create a new one
    async fn create_txt_record(
        &self,
        domain: &str,
        record_name: &str,
        record_value: &str,
    ) -> DnsResult<String>;

    /// Delete a TXT record after challenge validation
    ///
    /// # Arguments
    ///
    /// * `domain` - The domain the record belongs to
    /// * `record_id` - The record ID returned from `create_txt_record`
    ///
    /// # Implementation Notes
    ///
    /// - Should not error if the record doesn't exist (idempotent)
    /// - Called during cleanup, even if validation failed
    async fn delete_txt_record(&self, domain: &str, record_id: &str) -> DnsResult<()>;

    /// Check if the provider supports/manages the given domain
    ///
    /// # Arguments
    ///
    /// * `domain` - The domain to check
    ///
    /// # Returns
    ///
    /// `true` if the provider can manage DNS for this domain. NOTE(review): an
    /// `Err` presumably means support could not be determined (e.g. an API
    /// failure) rather than "not supported" — confirm against implementations.
    ///
    /// # Implementation Notes
    ///
    /// - Should check if a zone exists for this domain or its parent
    /// - May cache zone information to reduce API calls
    async fn supports_domain(&self, domain: &str) -> DnsResult<bool>;
}
120
/// ACME challenge record name prefix
///
/// [`challenge_record_fqdn`] prepends this label to the (wildcard-stripped)
/// domain, e.g. `example.com` → `_acme-challenge.example.com`.
pub const ACME_CHALLENGE_RECORD: &str = "_acme-challenge";

/// Recommended TTL for challenge records (60 seconds)
///
/// Corresponds to the "60s recommended" guidance on
/// `DnsProvider::create_txt_record`; whether it is applied is up to each provider.
pub const CHALLENGE_TTL: u32 = 60;
126
/// Strip a leading wildcard label (`*.`) from a domain name.
///
/// A DNS-01 challenge for a wildcard certificate is published under the base
/// domain, so `*.example.com` becomes `example.com`. Every other input,
/// including ordinary subdomains such as `sub.example.com`, is returned
/// unchanged. This function does **not** compute a parent domain.
///
/// Only a single leading `*.` is removed (`*.*.example.com` → `*.example.com`),
/// and the result borrows from `domain`, so no allocation happens.
///
/// No zone lookup is performed here; discovering the managing zone is left to
/// the DNS provider.
#[must_use]
pub fn normalize_domain(domain: &str) -> &str {
    domain.strip_prefix("*.").unwrap_or(domain)
}
136
137/// Build the full ACME challenge record name
138///
139/// For `example.com`, returns `_acme-challenge.example.com`
140/// For `*.example.com`, returns `_acme-challenge.example.com`
141pub fn challenge_record_fqdn(domain: &str) -> String {
142    let normalized = normalize_domain(domain);
143    format!("{}.{}", ACME_CHALLENGE_RECORD, normalized)
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("example.com"), "example.com");
        assert_eq!(normalize_domain("*.example.com"), "example.com");
        assert_eq!(normalize_domain("sub.example.com"), "sub.example.com");
        assert_eq!(normalize_domain("*.sub.example.com"), "sub.example.com");
    }

    #[test]
    fn test_challenge_record_fqdn() {
        assert_eq!(
            challenge_record_fqdn("example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("*.example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("sub.example.com"),
            "_acme-challenge.sub.example.com"
        );
    }

    #[test]
    fn test_dns_provider_error_display() {
        let err = DnsProviderError::Authentication("bad token".to_string());
        assert!(err.to_string().contains("Authentication failed"));

        let err = DnsProviderError::ZoneNotFound {
            domain: "test.com".to_string(),
        };
        assert!(err.to_string().contains("test.com"));

        let err = DnsProviderError::RecordCreation {
            record_name: "_acme-challenge".to_string(),
            message: "API error".to_string(),
        };
        assert!(err.to_string().contains("_acme-challenge"));

        let err = DnsProviderError::RecordDeletion {
            record_id: "record-1".to_string(),
            message: "API error".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Failed to delete TXT record 'record-1': API error"
        );

        let err = DnsProviderError::RateLimited {
            retry_after_secs: 60,
        };
        assert!(err.to_string().contains("60"));

        let err = DnsProviderError::Timeout { elapsed_secs: 30 };
        assert!(err.to_string().contains("30"));

        let err = DnsProviderError::UnsupportedDomain {
            domain: "other.org".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Domain 'other.org' is not supported by this provider"
        );
    }

    /// In-memory [`DnsProvider`] for testing callers without network access.
    #[derive(Debug)]
    pub struct MockDnsProvider {
        /// Records created: (domain, record_name) -> (record_id, value)
        pub records: Mutex<HashMap<(String, String), (String, String)>>,
        /// Supported domains (a domain also covers its subdomains)
        pub supported_domains: Vec<String>,
        /// Counter for generating unique record IDs
        pub record_counter: AtomicU64,
        /// Whether to fail on create
        pub fail_on_create: bool,
        /// Whether to fail on delete
        pub fail_on_delete: bool,
    }

    impl MockDnsProvider {
        pub fn new(supported_domains: Vec<String>) -> Self {
            Self {
                records: Mutex::new(HashMap::new()),
                supported_domains,
                record_counter: AtomicU64::new(1),
                fail_on_create: false,
                fail_on_delete: false,
            }
        }

        pub fn with_failure_on_create(mut self) -> Self {
            self.fail_on_create = true;
            self
        }

        pub fn with_failure_on_delete(mut self) -> Self {
            self.fail_on_delete = true;
            self
        }

        pub fn get_record(&self, domain: &str, record_name: &str) -> Option<(String, String)> {
            self.records
                .lock()
                .get(&(domain.to_string(), record_name.to_string()))
                .cloned()
        }

        pub fn record_count(&self) -> usize {
            self.records.lock().len()
        }
    }

    #[async_trait]
    impl DnsProvider for MockDnsProvider {
        fn name(&self) -> &'static str {
            "mock"
        }

        async fn create_txt_record(
            &self,
            domain: &str,
            record_name: &str,
            record_value: &str,
        ) -> DnsResult<String> {
            if self.fail_on_create {
                return Err(DnsProviderError::RecordCreation {
                    record_name: record_name.to_string(),
                    message: "Mock failure".to_string(),
                });
            }

            let record_id = format!(
                "record-{}",
                self.record_counter.fetch_add(1, Ordering::SeqCst)
            );
            // An existing (domain, record_name) entry is overwritten, i.e. "update it".
            self.records.lock().insert(
                (domain.to_string(), record_name.to_string()),
                (record_id.clone(), record_value.to_string()),
            );
            Ok(record_id)
        }

        async fn delete_txt_record(&self, _domain: &str, record_id: &str) -> DnsResult<()> {
            if self.fail_on_delete {
                return Err(DnsProviderError::RecordDeletion {
                    record_id: record_id.to_string(),
                    message: "Mock failure".to_string(),
                });
            }

            // IDs come from a monotonic counter and are unique, so matching on the ID
            // alone is enough (the domain is unused). An unknown ID is a no-op, which
            // keeps deletion idempotent as the trait requires.
            self.records.lock().retain(|_, (id, _)| id != record_id);
            Ok(())
        }

        async fn supports_domain(&self, domain: &str) -> DnsResult<bool> {
            let normalized = normalize_domain(domain);
            Ok(self
                .supported_domains
                .iter()
                .any(|d| normalized == *d || normalized.ends_with(&format!(".{}", d))))
        }
    }

    #[tokio::test]
    async fn test_mock_provider_create_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();

        assert!(record_id.starts_with("record-"));
        assert_eq!(provider.record_count(), 1);

        let (stored_id, stored_value) = provider
            .get_record("example.com", "_acme-challenge")
            .unwrap();
        assert_eq!(stored_id, record_id);
        assert_eq!(stored_value, "test-value");
    }

    #[tokio::test]
    async fn test_mock_provider_delete_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);

        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_delete_is_idempotent() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        // Deleting an ID that was never created must succeed and change nothing.
        provider
            .delete_txt_record("example.com", "record-does-not-exist")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);

        // Deleting the same record twice must also succeed.
        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_supports_domain() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        assert!(provider.supports_domain("example.com").await.unwrap());
        assert!(provider.supports_domain("sub.example.com").await.unwrap());
        assert!(provider.supports_domain("*.example.com").await.unwrap());
        assert!(!provider.supports_domain("other.com").await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_create() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_create();

        let result = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await;

        assert!(matches!(
            result,
            Err(DnsProviderError::RecordCreation { .. })
        ));
        // A failed create must not leave a record behind.
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_delete() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_delete();

        let result = provider.delete_txt_record("example.com", "record-1").await;

        assert!(matches!(
            result,
            Err(DnsProviderError::RecordDeletion { .. })
        ));
    }
}