1use crate::ynab::errors::{Error, ErrorResponse};
2use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
3use std::fmt;
4use std::num::NonZeroU32;
5use std::sync::Arc;
6use std::time::Duration;
7
8pub struct Client {
10 base_url: reqwest::Url,
11 http_client: reqwest::Client,
12 limiter: Option<Arc<DefaultDirectRateLimiter>>,
13 api_key: String,
14 timeout: Option<Duration>,
15}
16
17impl fmt::Debug for Client {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 f.debug_struct("Client")
20 .field("base_url", &self.base_url)
21 .field("api_key", &"[redacted]")
22 .finish()
23 }
24}
25
26impl Client {
27 pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
39 let api_key = api_key.into();
40 let http_client = Self::build_http_client(&api_key, None)?;
41 Ok(Self {
42 base_url: reqwest::Url::parse("https://api.ynab.com/v1").unwrap(),
43 http_client,
44 limiter: None,
45 api_key,
46 timeout: None,
47 })
48 }
49
50 fn build_http_client(
51 api_key: &str,
52 timeout: Option<Duration>,
53 ) -> Result<reqwest::Client, Error> {
54 let mut headers = reqwest::header::HeaderMap::new();
55 headers.insert(
56 reqwest::header::AUTHORIZATION,
57 format!("Bearer {}", api_key)
58 .parse()
59 .expect("api key must be valid ASCII"),
60 );
61 let mut builder = reqwest::Client::builder().default_headers(headers);
62 if let Some(t) = timeout {
63 builder = builder.timeout(t);
64 }
65 builder.build().map_err(Into::into)
66 }
67
68 pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, Error> {
82 self.http_client = Self::build_http_client(&self.api_key, Some(timeout))?;
83 self.timeout = Some(timeout);
84 Ok(self)
85 }
86
87 pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, Error> {
89 self.base_url = reqwest::Url::parse(base_url.as_ref())?;
90 Ok(self)
91 }
92
93 pub fn with_rate_limiter(
114 mut self,
115 requests_per_hour: usize,
116 burst_volume: Option<usize>,
117 ) -> Result<Self, Error> {
118 let requests = NonZeroU32::new(requests_per_hour as u32)
119 .ok_or_else(|| Error::InvalidRateLimit("requests_per_hour must be non-zero".into()))?;
120
121 let quota = match burst_volume {
122 None => Quota::per_hour(requests),
123 Some(burst) => {
124 let effective = (requests_per_hour as u32)
125 .checked_sub(burst as u32)
126 .ok_or_else(|| {
127 Error::InvalidRateLimit(
128 "requests_per_hour must be greater than burst_volume".into(),
129 )
130 })?;
131 let effective_rate = NonZeroU32::new(effective).ok_or_else(|| {
132 Error::InvalidRateLimit(
133 "requests_per_hour - burst_volume must be non-zero".into(),
134 )
135 })?;
136 let burst = NonZeroU32::new(burst as u32).ok_or_else(|| {
137 Error::InvalidRateLimit("burst_volume must be non-zero".into())
138 })?;
139 Quota::per_hour(effective_rate).allow_burst(burst)
140 }
141 };
142
143 self.limiter = Some(Arc::new(RateLimiter::direct(quota)));
144 Ok(self)
145 }
146
147 pub(crate) async fn get<T: serde::de::DeserializeOwned, Q: serde::ser::Serialize + ?Sized>(
148 &self,
149 endpoint: &str,
150 params: Option<&Q>,
151 ) -> Result<T, Error> {
152 if let Some(limiter) = &self.limiter {
153 limiter.until_ready().await;
154 }
155
156 let mut url = self.base_url.clone();
157 url.path_segments_mut()
158 .expect("base URL must be a valid base")
159 .extend(endpoint.split('/'));
160
161 let mut builder = self.http_client.get(url);
162 if let Some(p) = params {
163 builder = builder.query(p);
164 }
165 let res = builder.send().await?;
166 let status = res.status();
167
168 if !status.is_success() {
169 let err_body: ErrorResponse = res.json().await?;
170 return Err(Error::new_api_error(status, err_body.error));
171 }
172
173 res.json().await.map_err(Into::into)
174 }
175
176 pub(crate) async fn post<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
177 &self,
178 endpoint: &str,
179 body: B,
180 ) -> Result<T, Error> {
181 if let Some(limiter) = &self.limiter {
182 limiter.until_ready().await;
183 }
184 let mut url = self.base_url.clone();
185 url.path_segments_mut()
186 .expect("base URL must be a valid base")
187 .extend(endpoint.split('/'));
188
189 let res = self.http_client.post(url).json(&body).send().await?;
190 let status = res.status();
191
192 if !status.is_success() {
193 let err_body: ErrorResponse = res.json().await?;
194 return Err(Error::new_api_error(status, err_body.error));
195 }
196
197 res.json().await.map_err(Into::into)
198 }
199
200 pub(crate) async fn patch<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
201 &self,
202 endpoint: &str,
203 body: B,
204 ) -> Result<T, Error> {
205 if let Some(limiter) = &self.limiter {
206 limiter.until_ready().await;
207 }
208 let mut url = self.base_url.clone();
209 url.path_segments_mut()
210 .expect("base URL must be a valid base")
211 .extend(endpoint.split('/'));
212
213 let res = self.http_client.patch(url).json(&body).send().await?;
214 let status = res.status();
215
216 if !status.is_success() {
217 let err_body: ErrorResponse = res.json().await?;
218 return Err(Error::new_api_error(status, err_body.error));
219 }
220
221 res.json().await.map_err(Into::into)
222 }
223
224 pub(crate) async fn put<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
225 &self,
226 endpoint: &str,
227 body: B,
228 ) -> Result<T, Error> {
229 if let Some(limiter) = &self.limiter {
230 limiter.until_ready().await;
231 }
232 let mut url = self.base_url.clone();
233 url.path_segments_mut()
234 .expect("base URL must be a valid base")
235 .extend(endpoint.split('/'));
236
237 let res = self.http_client.put(url).json(&body).send().await?;
238 let status = res.status();
239
240 if !status.is_success() {
241 let err_body: ErrorResponse = res.json().await?;
242 return Err(Error::new_api_error(status, err_body.error));
243 }
244
245 res.json().await.map_err(Into::into)
246 }
247
248 pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
249 &self,
250 endpoint: &str,
251 ) -> Result<T, Error> {
252 if let Some(limiter) = &self.limiter {
253 limiter.until_ready().await;
254 }
255 let mut url = self.base_url.clone();
256 url.path_segments_mut()
257 .expect("base URL must be a valid base")
258 .extend(endpoint.split('/'));
259
260 let res = self.http_client.delete(url).send().await?;
261 let status = res.status();
262
263 if !status.is_success() {
264 let err_body: ErrorResponse = res.json().await?;
265 return Err(Error::new_api_error(status, err_body.error));
266 }
267
268 res.json().await.map_err(Into::into)
269 }
270}