1use crate::ynab::errors::{Error, ErrorResponse};
2use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
3use std::fmt;
4use std::num::NonZeroU32;
5use std::sync::Arc;
6use std::time::Duration;
7
8pub struct Client {
10 base_url: reqwest::Url,
11 http_client: reqwest::Client,
12 limiter: Option<Arc<DefaultDirectRateLimiter>>,
13 api_key: String,
14 timeout: Option<Duration>,
15}
16
17impl fmt::Debug for Client {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 f.debug_struct("Client")
20 .field("base_url", &self.base_url)
21 .field("api_key", &"[redacted]")
22 .finish()
23 }
24}
25
26impl Client {
27 pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
29 let api_key = api_key.into();
30 let http_client = Self::build_http_client(&api_key, None)?;
31 Ok(Self {
32 base_url: reqwest::Url::parse("https://api.ynab.com/v1").unwrap(),
33 http_client,
34 limiter: None,
35 api_key,
36 timeout: None,
37 })
38 }
39
40 fn build_http_client(
41 api_key: &str,
42 timeout: Option<Duration>,
43 ) -> Result<reqwest::Client, Error> {
44 let mut headers = reqwest::header::HeaderMap::new();
45 headers.insert(
46 reqwest::header::AUTHORIZATION,
47 format!("Bearer {}", api_key)
48 .parse()
49 .expect("api key must be valid ASCII"),
50 );
51 let mut builder = reqwest::Client::builder().default_headers(headers);
52 if let Some(t) = timeout {
53 builder = builder.timeout(t);
54 }
55 builder.build().map_err(Into::into)
56 }
57
58 pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, Error> {
60 self.http_client = Self::build_http_client(&self.api_key, Some(timeout))?;
61 self.timeout = Some(timeout);
62 Ok(self)
63 }
64
65 pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, Error> {
67 self.base_url = reqwest::Url::parse(base_url.as_ref())?;
68 Ok(self)
69 }
70
71 pub fn with_rate_limiter(
83 mut self,
84 requests_per_hour: usize,
85 burst_volume: Option<usize>,
86 ) -> Result<Self, Error> {
87 let requests = NonZeroU32::new(requests_per_hour as u32)
88 .ok_or_else(|| Error::InvalidRateLimit("requests_per_hour must be non-zero".into()))?;
89
90 let quota = match burst_volume {
91 None => Quota::per_hour(requests),
92 Some(burst) => {
93 let effective = (requests_per_hour as u32)
94 .checked_sub(burst as u32)
95 .ok_or_else(|| {
96 Error::InvalidRateLimit(
97 "requests_per_hour must be greater than burst_volume".into(),
98 )
99 })?;
100 let effective_rate = NonZeroU32::new(effective).ok_or_else(|| {
101 Error::InvalidRateLimit(
102 "requests_per_hour - burst_volume must be non-zero".into(),
103 )
104 })?;
105 let burst = NonZeroU32::new(burst as u32).ok_or_else(|| {
106 Error::InvalidRateLimit("burst_volume must be non-zero".into())
107 })?;
108 Quota::per_hour(effective_rate).allow_burst(burst)
109 }
110 };
111
112 self.limiter = Some(Arc::new(RateLimiter::direct(quota)));
113 Ok(self)
114 }
115
116 pub(crate) async fn get<T: serde::de::DeserializeOwned, Q: serde::ser::Serialize + ?Sized>(
117 &self,
118 endpoint: &str,
119 params: Option<&Q>,
120 ) -> Result<T, Error> {
121 if let Some(limiter) = &self.limiter {
122 limiter.until_ready().await;
123 }
124
125 let mut url = self.base_url.clone();
126 url.path_segments_mut()
127 .expect("base URL must be a valid base")
128 .extend(endpoint.split('/'));
129
130 let mut builder = self.http_client.get(url);
131 if let Some(p) = params {
132 builder = builder.query(p);
133 }
134 let res = builder.send().await?;
135 let status = res.status();
136
137 if !status.is_success() {
138 let err_body: ErrorResponse = res.json().await?;
139 return Err(Error::new_api_error(status, err_body.error));
140 }
141
142 res.json().await.map_err(Into::into)
143 }
144
145 pub(crate) async fn post<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
146 &self,
147 endpoint: &str,
148 body: B,
149 ) -> Result<T, Error> {
150 if let Some(limiter) = &self.limiter {
151 limiter.until_ready().await;
152 }
153 let mut url = self.base_url.clone();
154 url.path_segments_mut()
155 .expect("base URL must be a valid base")
156 .extend(endpoint.split('/'));
157
158 let res = self.http_client.post(url).json(&body).send().await?;
159 let status = res.status();
160
161 if !status.is_success() {
162 let err_body: ErrorResponse = res.json().await?;
163 return Err(Error::new_api_error(status, err_body.error));
164 }
165
166 res.json().await.map_err(Into::into)
167 }
168
169 pub(crate) async fn patch<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
170 &self,
171 endpoint: &str,
172 body: B,
173 ) -> Result<T, Error> {
174 if let Some(limiter) = &self.limiter {
175 limiter.until_ready().await;
176 }
177 let mut url = self.base_url.clone();
178 url.path_segments_mut()
179 .expect("base URL must be a valid base")
180 .extend(endpoint.split('/'));
181
182 let res = self.http_client.patch(url).json(&body).send().await?;
183 let status = res.status();
184
185 if !status.is_success() {
186 let err_body: ErrorResponse = res.json().await?;
187 return Err(Error::new_api_error(status, err_body.error));
188 }
189
190 res.json().await.map_err(Into::into)
191 }
192
193 pub(crate) async fn put<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
194 &self,
195 endpoint: &str,
196 body: B,
197 ) -> Result<T, Error> {
198 if let Some(limiter) = &self.limiter {
199 limiter.until_ready().await;
200 }
201 let mut url = self.base_url.clone();
202 url.path_segments_mut()
203 .expect("base URL must be a valid base")
204 .extend(endpoint.split('/'));
205
206 let res = self.http_client.put(url).json(&body).send().await?;
207 let status = res.status();
208
209 if !status.is_success() {
210 let err_body: ErrorResponse = res.json().await?;
211 return Err(Error::new_api_error(status, err_body.error));
212 }
213
214 res.json().await.map_err(Into::into)
215 }
216
217 pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
218 &self,
219 endpoint: &str,
220 ) -> Result<T, Error> {
221 if let Some(limiter) = &self.limiter {
222 limiter.until_ready().await;
223 }
224 let mut url = self.base_url.clone();
225 url.path_segments_mut()
226 .expect("base URL must be a valid base")
227 .extend(endpoint.split('/'));
228
229 let res = self.http_client.delete(url).send().await?;
230 let status = res.status();
231
232 if !status.is_success() {
233 let err_body: ErrorResponse = res.json().await?;
234 return Err(Error::new_api_error(status, err_body.error));
235 }
236
237 res.json().await.map_err(Into::into)
238 }
239}