Skip to main content

sentinel_proxy/acme/dns/
provider.rs

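//! DNS provider abstraction for ACME DNS-01 challenges.
//!
//! Defines the [`DnsProvider`] interface that every DNS backend must implement
//! to publish and remove challenge TXT records. It also provides the shared
//! [`DnsProviderError`] type and helpers for building challenge record names.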
4
5use async_trait::async_trait;
6use std::fmt::Debug;
7use thiserror::Error;
8
9/// Result type for DNS operations
10pub type DnsResult<T> = Result<T, DnsProviderError>;
11
12/// Errors that can occur during DNS provider operations
13#[derive(Debug, Error)]
14pub enum DnsProviderError {
15    /// Authentication failed with the DNS provider
16    #[error("Authentication failed: {0}")]
17    Authentication(String),
18
19    /// Zone not found for the domain
20    #[error("Zone not found for domain '{domain}'")]
21    ZoneNotFound { domain: String },
22
23    /// Record creation failed
24    #[error("Failed to create TXT record for '{record_name}': {message}")]
25    RecordCreation { record_name: String, message: String },
26
27    /// Record deletion failed
28    #[error("Failed to delete TXT record '{record_id}': {message}")]
29    RecordDeletion { record_id: String, message: String },
30
31    /// API request failed
32    #[error("API request failed: {0}")]
33    ApiRequest(String),
34
35    /// Rate limited by provider
36    #[error("Rate limited by DNS provider, retry after {retry_after_secs}s")]
37    RateLimited { retry_after_secs: u64 },
38
39    /// Request timeout
40    #[error("Request timed out after {elapsed_secs}s")]
41    Timeout { elapsed_secs: u64 },
42
43    /// Invalid configuration
44    #[error("Invalid configuration: {0}")]
45    Configuration(String),
46
47    /// Credential loading failed
48    #[error("Failed to load credentials: {0}")]
49    Credentials(String),
50
51    /// Domain not supported by this provider
52    #[error("Domain '{domain}' is not supported by this provider")]
53    UnsupportedDomain { domain: String },
54}
55
56/// Trait for DNS providers that support DNS-01 challenges
57///
58/// Implementations must be thread-safe and support concurrent operations.
59#[async_trait]
60pub trait DnsProvider: Send + Sync + Debug {
61    /// Returns the provider name (e.g., "hetzner", "cloudflare")
62    fn name(&self) -> &'static str;
63
64    /// Create a TXT record for DNS-01 challenge
65    ///
66    /// # Arguments
67    ///
68    /// * `domain` - The full domain name (e.g., "example.com" or "sub.example.com")
69    /// * `record_name` - The challenge record name (typically "_acme-challenge")
70    /// * `record_value` - The challenge value (base64url-encoded digest)
71    ///
72    /// # Returns
73    ///
74    /// The record ID for later cleanup, or an error
75    ///
76    /// # Implementation Notes
77    ///
78    /// - The full record name should be `{record_name}.{domain}`
79    /// - Use a short TTL (60s recommended) for challenge records
80    /// - If a record already exists, either update it or create a new one
81    async fn create_txt_record(
82        &self,
83        domain: &str,
84        record_name: &str,
85        record_value: &str,
86    ) -> DnsResult<String>;
87
88    /// Delete a TXT record after challenge validation
89    ///
90    /// # Arguments
91    ///
92    /// * `domain` - The domain the record belongs to
93    /// * `record_id` - The record ID returned from `create_txt_record`
94    ///
95    /// # Implementation Notes
96    ///
97    /// - Should not error if the record doesn't exist (idempotent)
98    /// - Called during cleanup, even if validation failed
99    async fn delete_txt_record(&self, domain: &str, record_id: &str) -> DnsResult<()>;
100
101    /// Check if the provider supports/manages the given domain
102    ///
103    /// # Arguments
104    ///
105    /// * `domain` - The domain to check
106    ///
107    /// # Returns
108    ///
109    /// `true` if the provider can manage DNS for this domain
110    ///
111    /// # Implementation Notes
112    ///
113    /// - Should check if a zone exists for this domain or its parent
114    /// - May cache zone information to reduce API calls
115    async fn supports_domain(&self, domain: &str) -> DnsResult<bool>;
116}
117
/// Leftmost label of every DNS-01 challenge record name, as in
/// `_acme-challenge.example.com`.
pub const ACME_CHALLENGE_RECORD: &str = "_acme-challenge";

/// TTL, in seconds, recommended for challenge TXT records (one minute).
pub const CHALLENGE_TTL: u32 = 60;
123
/// Strip a single leading wildcard label (`*.`) from a domain name.
///
/// A challenge for `*.example.com` is answered under `example.com`, so the
/// wildcard label is removed before record names are built. Every other input
/// is returned unchanged, including ordinary subdomains such as
/// `sub.example.com`. This function does **not** walk up to a parent domain
/// or zone; finding the zone that hosts the name is left to the provider.
///
/// Only one leading `*.` is removed: `*.*.example.com` becomes
/// `*.example.com`.
///
/// Returns a slice borrowed from `domain`; nothing is allocated.
#[must_use]
pub fn normalize_domain(domain: &str) -> &str {
    domain.strip_prefix("*.").unwrap_or(domain)
}
133
134/// Build the full ACME challenge record name
135///
136/// For `example.com`, returns `_acme-challenge.example.com`
137/// For `*.example.com`, returns `_acme-challenge.example.com`
138pub fn challenge_record_fqdn(domain: &str) -> String {
139    let normalized = normalize_domain(domain);
140    format!("{}.{}", ACME_CHALLENGE_RECORD, normalized)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("example.com"), "example.com");
        assert_eq!(normalize_domain("*.example.com"), "example.com");
        assert_eq!(normalize_domain("sub.example.com"), "sub.example.com");
        assert_eq!(normalize_domain("*.sub.example.com"), "sub.example.com");
    }

    #[test]
    fn test_challenge_record_fqdn() {
        assert_eq!(
            challenge_record_fqdn("example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("*.example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("sub.example.com"),
            "_acme-challenge.sub.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("*.sub.example.com"),
            "_acme-challenge.sub.example.com"
        );
    }

    #[test]
    fn test_dns_provider_error_display() {
        let err = DnsProviderError::Authentication("bad token".to_string());
        assert!(err.to_string().contains("Authentication failed"));

        let err = DnsProviderError::ZoneNotFound {
            domain: "test.com".to_string(),
        };
        assert!(err.to_string().contains("test.com"));

        let err = DnsProviderError::RecordCreation {
            record_name: "_acme-challenge".to_string(),
            message: "API error".to_string(),
        };
        assert!(err.to_string().contains("_acme-challenge"));

        let err = DnsProviderError::RateLimited {
            retry_after_secs: 60,
        };
        assert!(err.to_string().contains("60"));

        let err = DnsProviderError::Timeout { elapsed_secs: 30 };
        assert!(err.to_string().contains("30"));
    }

    #[test]
    fn test_dns_provider_error_display_remaining_variants() {
        let err = DnsProviderError::RecordDeletion {
            record_id: "rec-42".to_string(),
            message: "gone".to_string(),
        };
        assert_eq!(err.to_string(), "Failed to delete TXT record 'rec-42': gone");

        let err = DnsProviderError::ApiRequest("boom".to_string());
        assert_eq!(err.to_string(), "API request failed: boom");

        let err = DnsProviderError::Configuration("missing zone".to_string());
        assert_eq!(err.to_string(), "Invalid configuration: missing zone");

        let err = DnsProviderError::Credentials("no file".to_string());
        assert_eq!(err.to_string(), "Failed to load credentials: no file");

        let err = DnsProviderError::UnsupportedDomain {
            domain: "foo.org".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Domain 'foo.org' is not supported by this provider"
        );
    }

    /// In-memory DNS provider used to exercise the `DnsProvider` contract.
    #[derive(Debug)]
    pub struct MockDnsProvider {
        /// Records created: (domain, record_name) -> (record_id, value)
        pub records: Mutex<HashMap<(String, String), (String, String)>>,
        /// Domains (and their subdomains) this mock claims to manage
        pub supported_domains: Vec<String>,
        /// Counter for generating unique record IDs
        pub record_counter: AtomicU64,
        /// When set, `create_txt_record` always fails
        pub fail_on_create: bool,
        /// When set, `delete_txt_record` always fails
        pub fail_on_delete: bool,
    }

    impl MockDnsProvider {
        pub fn new(supported_domains: Vec<String>) -> Self {
            Self {
                records: Mutex::new(HashMap::new()),
                supported_domains,
                record_counter: AtomicU64::new(1),
                fail_on_create: false,
                fail_on_delete: false,
            }
        }

        pub fn with_failure_on_create(mut self) -> Self {
            self.fail_on_create = true;
            self
        }

        pub fn with_failure_on_delete(mut self) -> Self {
            self.fail_on_delete = true;
            self
        }

        /// Look up the stored `(record_id, value)` for a (domain, record_name) pair.
        pub fn get_record(&self, domain: &str, record_name: &str) -> Option<(String, String)> {
            self.records
                .lock()
                .get(&(domain.to_string(), record_name.to_string()))
                .cloned()
        }

        pub fn record_count(&self) -> usize {
            self.records.lock().len()
        }
    }

    #[async_trait]
    impl DnsProvider for MockDnsProvider {
        fn name(&self) -> &'static str {
            "mock"
        }

        async fn create_txt_record(
            &self,
            domain: &str,
            record_name: &str,
            record_value: &str,
        ) -> DnsResult<String> {
            if self.fail_on_create {
                return Err(DnsProviderError::RecordCreation {
                    record_name: record_name.to_string(),
                    message: "Mock failure".to_string(),
                });
            }

            let record_id = format!(
                "record-{}",
                self.record_counter.fetch_add(1, Ordering::SeqCst)
            );
            // Same (domain, name) key overwrites: "update existing" per the trait docs.
            self.records.lock().insert(
                (domain.to_string(), record_name.to_string()),
                (record_id.clone(), record_value.to_string()),
            );
            Ok(record_id)
        }

        // `_domain` is unused: records are located purely by their unique ID.
        async fn delete_txt_record(&self, _domain: &str, record_id: &str) -> DnsResult<()> {
            if self.fail_on_delete {
                return Err(DnsProviderError::RecordDeletion {
                    record_id: record_id.to_string(),
                    message: "Mock failure".to_string(),
                });
            }

            // `retain` is a no-op for unknown IDs, which keeps deletion idempotent.
            let mut records = self.records.lock();
            records.retain(|_, (id, _)| id != record_id);
            Ok(())
        }

        async fn supports_domain(&self, domain: &str) -> DnsResult<bool> {
            let normalized = normalize_domain(domain);
            Ok(self.supported_domains.iter().any(|d| {
                normalized == *d || normalized.ends_with(&format!(".{}", d))
            }))
        }
    }

    #[tokio::test]
    async fn test_mock_provider_create_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();

        assert!(record_id.starts_with("record-"));
        assert_eq!(provider.record_count(), 1);

        let (stored_id, stored_value) = provider
            .get_record("example.com", "_acme-challenge")
            .unwrap();
        assert_eq!(stored_id, record_id);
        assert_eq!(stored_value, "test-value");
    }

    #[tokio::test]
    async fn test_mock_provider_delete_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);

        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_delete_only_targets_matching_id() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let first = provider
            .create_txt_record("example.com", "_acme-challenge", "a")
            .await
            .unwrap();
        let second = provider
            .create_txt_record("sub.example.com", "_acme-challenge", "b")
            .await
            .unwrap();
        assert_ne!(first, second);

        provider
            .delete_txt_record("example.com", &first)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);

        let (remaining_id, remaining_value) = provider
            .get_record("sub.example.com", "_acme-challenge")
            .unwrap();
        assert_eq!(remaining_id, second);
        assert_eq!(remaining_value, "b");
    }

    #[tokio::test]
    async fn test_mock_provider_delete_is_idempotent() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();

        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        // The trait contract requires repeat or unknown deletions to succeed.
        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        provider
            .delete_txt_record("example.com", "record-999")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_supports_domain() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);

        assert!(provider.supports_domain("example.com").await.unwrap());
        assert!(provider.supports_domain("sub.example.com").await.unwrap());
        assert!(provider.supports_domain("*.example.com").await.unwrap());
        assert!(!provider.supports_domain("other.com").await.unwrap());
        // Suffix match must respect label boundaries.
        assert!(!provider.supports_domain("notexample.com").await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_create() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_create();

        let result = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await;

        assert!(matches!(
            result,
            Err(DnsProviderError::RecordCreation { .. })
        ));
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_delete() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_delete();

        let result = provider.delete_txt_record("example.com", "record-1").await;

        assert!(matches!(
            result,
            Err(DnsProviderError::RecordDeletion { .. })
        ));
    }
}