Skip to main content

rust_ynab/ynab/
client.rs

1use crate::ynab::errors::{Error, ErrorResponse};
2use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
3use std::fmt;
4use std::num::NonZeroU32;
5use std::sync::Arc;
6use std::time::Duration;
7
/// Client is the YNAB API client. Use Client::new() to create one.
///
/// Optional behaviour (request timeout, base URL override, rate limiting) is
/// configured through the chainable `with_*` methods.
pub struct Client {
    /// Root URL that endpoint path segments are appended to for each request.
    base_url: reqwest::Url,
    /// Underlying HTTP client, built with the `Authorization` header preset.
    http_client: reqwest::Client,
    /// Optional rate limiter awaited before each request; `None` means requests
    /// are sent without throttling.
    limiter: Option<Arc<DefaultDirectRateLimiter>>,
    /// Personal Access Token, retained so the HTTP client can be rebuilt when
    /// the timeout is changed.
    api_key: String,
    /// Timeout most recently applied through `with_timeout`, if any.
    timeout: Option<Duration>,
}
16
17impl fmt::Debug for Client {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_struct("Client")
20            .field("base_url", &self.base_url)
21            .field("api_key", &"[redacted]")
22            .finish()
23    }
24}
25
26impl Client {
27    /// Creates a new client with the given Personal Access Token.
28    ///
29    /// # Examples
30    ///
31    /// ```no_run
32    /// use rust_ynab::Client;
33    ///
34    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
35    /// let client = Client::new(&std::env::var("YNAB_TOKEN")?)?;
36    /// # Ok(()) }
37    /// ```
38    pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
39        let api_key = api_key.into();
40        let http_client = Self::build_http_client(&api_key, None)?;
41        Ok(Self {
42            base_url: reqwest::Url::parse("https://api.ynab.com/v1").unwrap(),
43            http_client,
44            limiter: None,
45            api_key,
46            timeout: None,
47        })
48    }
49
50    fn build_http_client(
51        api_key: &str,
52        timeout: Option<Duration>,
53    ) -> Result<reqwest::Client, Error> {
54        let mut headers = reqwest::header::HeaderMap::new();
55        headers.insert(
56            reqwest::header::AUTHORIZATION,
57            format!("Bearer {}", api_key)
58                .parse()
59                .expect("api key must be valid ASCII"),
60        );
61        let mut builder = reqwest::Client::builder().default_headers(headers);
62        if let Some(t) = timeout {
63            builder = builder.timeout(t);
64        }
65        builder.build().map_err(Into::into)
66    }
67
68    /// Sets the request timeout. Returns `self` for chaining.
69    ///
70    /// # Examples
71    ///
72    /// ```no_run
73    /// use rust_ynab::Client;
74    /// use std::time::Duration;
75    ///
76    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
77    /// let client = Client::new(&std::env::var("YNAB_TOKEN")?)?
78    ///     .with_timeout(Duration::from_secs(30))?;
79    /// # Ok(()) }
80    /// ```
81    pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, Error> {
82        self.http_client = Self::build_http_client(&self.api_key, Some(timeout))?;
83        self.timeout = Some(timeout);
84        Ok(self)
85    }
86
87    /// Overrides the base URL. Primarily useful for testing.
88    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, Error> {
89        self.base_url = reqwest::Url::parse(base_url.as_ref())?;
90        Ok(self)
91    }
92
93    /// Configures a token bucket rate limiter on the client. Returns `self` for chaining.
94    ///
95    /// `requests_per_hour` is the total allowed requests per hour.
96    /// `burst_volume` optionally allows a number of requests to be made immediately
97    /// before throttling begins. The effective sustained rate becomes
98    /// `requests_per_hour - burst_volume` to account for burst consumption.
99    /// If `None`, no burst is allowed and the full rate is sustained evenly.
100    ///
101    /// # Examples
102    ///
103    /// ```no_run
104    /// use rust_ynab::Client;
105    /// use std::time::Duration;
106    ///
107    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
108    /// let client = Client::new(&std::env::var("YNAB_TOKEN")?)?
109    ///     .with_rate_limiter(200, Some(10))?  // 10 burst, then 190/hr
110    ///     .with_timeout(Duration::from_secs(30))?;
111    /// # Ok(()) }
112    /// ```
113    pub fn with_rate_limiter(
114        mut self,
115        requests_per_hour: usize,
116        burst_volume: Option<usize>,
117    ) -> Result<Self, Error> {
118        let requests = NonZeroU32::new(requests_per_hour as u32)
119            .ok_or_else(|| Error::InvalidRateLimit("requests_per_hour must be non-zero".into()))?;
120
121        let quota = match burst_volume {
122            None => Quota::per_hour(requests),
123            Some(burst) => {
124                let effective = (requests_per_hour as u32)
125                    .checked_sub(burst as u32)
126                    .ok_or_else(|| {
127                        Error::InvalidRateLimit(
128                            "requests_per_hour must be greater than burst_volume".into(),
129                        )
130                    })?;
131                let effective_rate = NonZeroU32::new(effective).ok_or_else(|| {
132                    Error::InvalidRateLimit(
133                        "requests_per_hour - burst_volume must be non-zero".into(),
134                    )
135                })?;
136                let burst = NonZeroU32::new(burst as u32).ok_or_else(|| {
137                    Error::InvalidRateLimit("burst_volume must be non-zero".into())
138                })?;
139                Quota::per_hour(effective_rate).allow_burst(burst)
140            }
141        };
142
143        self.limiter = Some(Arc::new(RateLimiter::direct(quota)));
144        Ok(self)
145    }
146
147    pub(crate) async fn get<T: serde::de::DeserializeOwned, Q: serde::ser::Serialize + ?Sized>(
148        &self,
149        endpoint: &str,
150        params: Option<&Q>,
151    ) -> Result<T, Error> {
152        if let Some(limiter) = &self.limiter {
153            limiter.until_ready().await;
154        }
155
156        let mut url = self.base_url.clone();
157        url.path_segments_mut()
158            .expect("base URL must be a valid base")
159            .extend(endpoint.split('/'));
160
161        let mut builder = self.http_client.get(url);
162        if let Some(p) = params {
163            builder = builder.query(p);
164        }
165        let res = builder.send().await?;
166        let status = res.status();
167
168        if !status.is_success() {
169            let err_body: ErrorResponse = res.json().await?;
170            return Err(Error::new_api_error(status, err_body.error));
171        }
172
173        res.json().await.map_err(Into::into)
174    }
175
176    pub(crate) async fn post<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
177        &self,
178        endpoint: &str,
179        body: B,
180    ) -> Result<T, Error> {
181        if let Some(limiter) = &self.limiter {
182            limiter.until_ready().await;
183        }
184        let mut url = self.base_url.clone();
185        url.path_segments_mut()
186            .expect("base URL must be a valid base")
187            .extend(endpoint.split('/'));
188
189        let res = self.http_client.post(url).json(&body).send().await?;
190        let status = res.status();
191
192        if !status.is_success() {
193            let err_body: ErrorResponse = res.json().await?;
194            return Err(Error::new_api_error(status, err_body.error));
195        }
196
197        res.json().await.map_err(Into::into)
198    }
199
200    pub(crate) async fn patch<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
201        &self,
202        endpoint: &str,
203        body: B,
204    ) -> Result<T, Error> {
205        if let Some(limiter) = &self.limiter {
206            limiter.until_ready().await;
207        }
208        let mut url = self.base_url.clone();
209        url.path_segments_mut()
210            .expect("base URL must be a valid base")
211            .extend(endpoint.split('/'));
212
213        let res = self.http_client.patch(url).json(&body).send().await?;
214        let status = res.status();
215
216        if !status.is_success() {
217            let err_body: ErrorResponse = res.json().await?;
218            return Err(Error::new_api_error(status, err_body.error));
219        }
220
221        res.json().await.map_err(Into::into)
222    }
223
224    pub(crate) async fn put<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
225        &self,
226        endpoint: &str,
227        body: B,
228    ) -> Result<T, Error> {
229        if let Some(limiter) = &self.limiter {
230            limiter.until_ready().await;
231        }
232        let mut url = self.base_url.clone();
233        url.path_segments_mut()
234            .expect("base URL must be a valid base")
235            .extend(endpoint.split('/'));
236
237        let res = self.http_client.put(url).json(&body).send().await?;
238        let status = res.status();
239
240        if !status.is_success() {
241            let err_body: ErrorResponse = res.json().await?;
242            return Err(Error::new_api_error(status, err_body.error));
243        }
244
245        res.json().await.map_err(Into::into)
246    }
247
248    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
249        &self,
250        endpoint: &str,
251    ) -> Result<T, Error> {
252        if let Some(limiter) = &self.limiter {
253            limiter.until_ready().await;
254        }
255        let mut url = self.base_url.clone();
256        url.path_segments_mut()
257            .expect("base URL must be a valid base")
258            .extend(endpoint.split('/'));
259
260        let res = self.http_client.delete(url).send().await?;
261        let status = res.status();
262
263        if !status.is_success() {
264            let err_body: ErrorResponse = res.json().await?;
265            return Err(Error::new_api_error(status, err_body.error));
266        }
267
268        res.json().await.map_err(Into::into)
269    }
270}