1#![allow(unused)]
3
4mod types;
5
6use std::{fmt::Display, sync::Arc};
7use async_lock::Mutex;
8use cfg_if::cfg_if;
9use hyper::Uri;
10use serde::{de::DeserializeOwned, Serialize};
11use tracing::{error, info, warn};
12
13
14use crate::{dnsimple::types::{Accounts, CreateRecord, GetRecord, Records, UpdateRecord}, errors::{Error, Result}, http, Config, DnsProvider, RecordType};
15
16
/// Base URL of the production DNSimple v2 API; `DnSimple::new` targets this endpoint.
const API_BASE: &str = "https://api.dnsimple.com/v2";
18
/// Credentials for the DNSimple API: an access token that is sent as a bearer token.
pub struct Auth {
    key: String,
}

impl Auth {
    /// Creates credentials from an API access token.
    ///
    /// `key` is a private field, so without this constructor an `Auth` could only be
    /// built from inside this module.
    pub fn new(key: impl Into<String>) -> Self {
        Self { key: key.into() }
    }

    /// Returns the credential header value, formatted as `Bearer <key>`.
    fn get_header(&self) -> String {
        format!("Bearer {}", self.key)
    }
}
28
/// DNSimple API client; implements [`DnsProvider`] further down in this module.
struct DnSimple {
    /// Provider configuration: `domain` is used as the zone name in request paths,
    /// and `dry_run` makes the write operations log instead of sending requests.
    config: Config,
    /// API base URL (e.g. [`API_BASE`] or the sandbox URL); request paths are appended to it.
    endpoint: &'static str,
    /// Credentials; `get_header` produces the `Bearer …` value passed to the `http` helpers.
    auth: Auth,
    /// Account ID: either supplied at construction or fetched from upstream and cached
    /// on first use by `get_id`.
    acc_id: Mutex<Option<u32>>,
}
35
36impl DnSimple {
37 pub fn new(config: Config, auth: Auth, acc: Option<u32>) -> Self {
38 Self::new_with_endpoint(config, auth, acc, API_BASE)
39 }
40
41 fn new_with_endpoint(config: Config, auth: Auth, acc: Option<u32>, endpoint: &'static str) -> Self {
42 let acc_id = Mutex::new(acc);
43 DnSimple {
44 config,
45 endpoint,
46 auth,
47 acc_id,
48 }
49 }
50
51 async fn get_upstream_id(&self) -> Result<u32> {
52 info!("Fetching account ID from upstream");
53 let endpoint = format!("{}/accounts", self.endpoint);
54 let uri = endpoint.parse()
55 .map_err(|e| Error::UrlError(format!("Error: {endpoint} -> {e}")))?;
56
57 let accounts_p = http::get::<Accounts>(uri, Some(self.auth.get_header())).await?;
58
59 match accounts_p {
60 Some(accounts) if accounts.accounts.len() == 1 => {
61 Ok(accounts.accounts[0].id)
62 }
63 Some(accounts) if accounts.accounts.len() > 1 => {
64 Err(Error::ApiError("More than one account returned; you must specify the account ID to use".to_string()))
65 }
66 _ => {
68 Err(Error::ApiError("No accounts returned from upstream".to_string()))
69 }
70 }
71 }
72
73 async fn get_id(&self) -> Result<u32> {
74 let mut id_p = self.acc_id.lock().await;
78
79 if let Some(id) = *id_p {
80 return Ok(id);
81 }
82
83 let id = self.get_upstream_id().await?;
84 *id_p = Some(id);
85
86 Ok(id)
87 }
88
89 async fn get_upstream_record<T>(&self, rtype: RecordType, host: &str) -> Result<Option<GetRecord<T>>>
90 where
91 T: DeserializeOwned
92 {
93 let acc_id = self.get_id().await?;
94
95 let url = format!("{}/{acc_id}/zones/{}/records?name={host}&type={rtype}", self.endpoint, self.config.domain)
96 .parse()
97 .map_err(|e| Error::UrlError(format!("Error: {e}")))?;
98
99 let auth = self.auth.get_header();
100 let mut recs: Records<T> = match http::get(url, Some(auth)).await? {
101 Some(rec) => rec,
102 None => return Ok(None)
103 };
104
105 let nr = recs.records.len();
108 if nr > 1 {
109 error!("Returned number of IPs is {}, should be 1", nr);
110 return Err(Error::UnexpectedRecord(format!("Returned number of IPs is {nr}, should be 1")));
111 } else if nr == 0 {
112 warn!("No IP returned for {host}, continuing");
113 return Ok(None);
114 }
115
116
117 Ok(Some(recs.records.remove(0)))
118 }
119}
120
121
122impl DnsProvider for DnSimple {
123
124 async fn get_record<T>(&self, rtype: RecordType, host: &str) -> Result<Option<T> >
125 where
126 T: DeserializeOwned
127 {
128 let rec: GetRecord<T> = match self.get_upstream_record(rtype, host).await? {
129 Some(recs) => recs,
130 None => return Ok(None)
131 };
132
133
134 Ok(Some(rec.content))
135 }
136
137 async fn create_record<T>(&self, rtype: RecordType, host: &str, record: &T) -> Result<()>
138 where
139 T: Display + Sync,
140 {
141 let acc_id = self.get_id().await?;
142
143 let url = format!("{}/{acc_id}/zones/{}/records", self.endpoint, self.config.domain)
144 .parse()
145 .map_err(|e| Error::UrlError(format!("Error: {e}")))?;
146 let auth = self.auth.get_header();
147
148 let rec = CreateRecord {
149 name: host.to_string(),
150 rtype,
151 content: record.to_string(),
152 ttl: 300,
153 };
154
155 if self.config.dry_run {
156 info!("DRY-RUN: Would have sent {rec:?} to {url}");
157 return Ok(())
158 }
159 http::post::<CreateRecord>(url, &rec, Some(auth)).await?;
160
161 Ok(())
162 }
163
164 async fn update_record<T>(&self, rtype: RecordType, host: &str, urec: &T) -> Result<()>
165 where
166 T: DeserializeOwned + Display + Sync + Send,
167 {
168 let rec: GetRecord<T> = match self.get_upstream_record(rtype, host).await? {
169 Some(rec) => rec,
170 None => {
171 warn!("DELETE: Record {host} doesn't exist");
172 return Ok(());
173 }
174 };
175
176 let acc_id = self.get_id().await?;
177 let rid = rec.id;
178
179 let update = UpdateRecord {
180 content: urec.to_string(),
181 };
182
183 let url = format!("{}/{acc_id}/zones/{}/records/{rid}", self.endpoint, self.config.domain)
184 .parse()
185 .map_err(|e| Error::UrlError(format!("Error: {e}")))?;
186 if self.config.dry_run {
187 info!("DRY-RUN: Would have sent PATCH to {url}");
188 return Ok(())
189 }
190
191 let auth = self.auth.get_header();
192 http::patch(url, &update, Some(auth)).await?;
193
194 Ok(())
195 }
196
197 async fn delete_record(&self, rtype: RecordType, host: &str) -> Result<()> {
198 let rec: GetRecord<String> = match self.get_upstream_record(rtype, host).await? {
199 Some(rec) => rec,
200 None => {
201 warn!("DELETE: Record {host} doesn't exist");
202 return Ok(());
203 }
204 };
205
206 let acc_id = self.get_id().await?;
207 let rid = rec.id;
208
209 let url = format!("{}/{acc_id}/zones/{}/records/{rid}", self.endpoint, self.config.domain)
210 .parse()
211 .map_err(|e| Error::UrlError(format!("Error: {e}")))?;
212 if self.config.dry_run {
213 info!("DRY-RUN: Would have sent DELETE to {url}");
214 return Ok(())
215 }
216
217 let auth = self.auth.get_header();
218 http::delete(url, Some(auth)).await?;
219
220 Ok(())
221 }
222}
223
224
225
#[cfg(test)]
mod tests {
    use crate::strip_quotes;

    use super::*;
    use std::{env, net::Ipv4Addr};
    use random_string::charsets::ALPHANUMERIC;
    use tracing_test::traced_test;

    /// DNSimple sandbox API; these tests never target the production endpoint.
    const TEST_API: &str = "https://api.sandbox.dnsimple.com/v2";

    /// Builds a sandbox client from `DNSIMPLE_TOKEN` and `DNSIMPLE_TEST_DOMAIN`.
    ///
    /// # Panics
    /// If either environment variable is unset or not valid Unicode.
    fn get_client() -> DnSimple {
        let auth = Auth {
            key: env::var("DNSIMPLE_TOKEN").expect("DNSIMPLE_TOKEN must be set for DnSimple API tests"),
        };
        let config = Config {
            domain: env::var("DNSIMPLE_TEST_DOMAIN")
                .expect("DNSIMPLE_TEST_DOMAIN must be set for DnSimple API tests"),
            dry_run: false,
        };
        DnSimple::new_with_endpoint(config, auth, None, TEST_API)
    }

    async fn test_id_fetch() -> Result<()> {
        let client = get_client();

        let id = client.get_upstream_id().await?;
        // NOTE(review): 2602 is presumably the sandbox account behind the test token — confirm.
        assert_eq!(2602, id);

        Ok(())
    }

    async fn test_create_update_delete_ipv4() -> Result<()> {
        let client = get_client();

        let host = random_string::generate(16, ALPHANUMERIC);

        let ip: Ipv4Addr = "1.1.1.1".parse()?;
        client.create_record(RecordType::A, &host, &ip).await?;
        let cur = client.get_record(RecordType::A, &host).await?;
        assert_eq!(Some(ip), cur);

        let ip: Ipv4Addr = "2.2.2.2".parse()?;
        client.update_record(RecordType::A, &host, &ip).await?;
        let cur = client.get_record(RecordType::A, &host).await?;
        assert_eq!(Some(ip), cur);

        client.delete_record(RecordType::A, &host).await?;
        let del: Option<Ipv4Addr> = client.get_record(RecordType::A, &host).await?;
        assert!(del.is_none());

        Ok(())
    }

    async fn test_create_update_delete_txt() -> Result<()> {
        let client = get_client();

        let host = random_string::generate(16, ALPHANUMERIC);

        let txt = "a text reference".to_string();
        client.create_record(RecordType::TXT, &host, &txt).await?;
        let cur: Option<String> = client.get_record(RecordType::TXT, &host).await?;
        assert_eq!(txt, strip_quotes(&cur.unwrap()));

        let txt = "another text reference".to_string();
        client.update_record(RecordType::TXT, &host, &txt).await?;
        let cur: Option<String> = client.get_record(RecordType::TXT, &host).await?;
        assert_eq!(txt, strip_quotes(&cur.unwrap()));

        client.delete_record(RecordType::TXT, &host).await?;
        let del: Option<String> = client.get_record(RecordType::TXT, &host).await?;
        assert!(del.is_none());

        Ok(())
    }

    async fn test_create_update_delete_txt_default() -> Result<()> {
        let client = get_client();

        let host = random_string::generate(16, ALPHANUMERIC);

        let txt = "a text reference".to_string();
        client.create_txt_record(&host, &txt).await?;
        let cur = client.get_txt_record(&host).await?;
        assert_eq!(txt, strip_quotes(&cur.unwrap()));

        let txt = "another text reference".to_string();
        client.update_txt_record(&host, &txt).await?;
        let cur = client.get_txt_record(&host).await?;
        assert_eq!(txt, strip_quotes(&cur.unwrap()));

        client.delete_txt_record(&host).await?;
        let del = client.get_txt_record(&host).await?;
        assert!(del.is_none());

        Ok(())
    }

    #[cfg(feature = "smol")]
    mod smol_tests {
        use super::*;
        use macro_rules_attribute::apply;
        use smol_macros::test;

        #[apply(test!)]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn smol_id_fetch() -> Result<()> {
            test_id_fetch().await?;
            Ok(())
        }

        #[apply(test!)]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn smol_create_update_v4() -> Result<()> {
            test_create_update_delete_ipv4().await?;
            Ok(())
        }

        #[apply(test!)]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn smol_create_update_txt() -> Result<()> {
            test_create_update_delete_txt().await?;
            Ok(())
        }

        #[apply(test!)]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn smol_create_update_default() -> Result<()> {
            test_create_update_delete_txt_default().await?;
            Ok(())
        }
    }

    // Mirrors `smol_tests` so both runtimes exercise the same scenarios.
    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use super::*;

        #[tokio::test]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn tokio_id_fetch() -> Result<()> {
            test_id_fetch().await
        }

        #[tokio::test]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn tokio_create_update() -> Result<()> {
            test_create_update_delete_ipv4().await
        }

        #[tokio::test]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn tokio_create_update_txt() -> Result<()> {
            test_create_update_delete_txt().await
        }

        #[tokio::test]
        #[traced_test]
        #[cfg_attr(not(feature = "test_dnsimple"), ignore = "DnSimple API test")]
        async fn tokio_create_update_default() -> Result<()> {
            test_create_update_delete_txt_default().await
        }
    }
}
399