1use futures_util::StreamExt;
16use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
17use serde::de::DeserializeOwned;
18use serde::Serialize;
19use tokio::io::AsyncWriteExt;
20
21use crate::error::{Error, Result};
22use crate::token::Token;
23
24#[derive(Clone)]
29pub struct Client {
30 inner: reqwest::Client,
31 base_url: String,
32}
33
34impl Client {
35 pub fn new(base_url: impl Into<String>, token: Token) -> Result<Self> {
38 let mut headers = HeaderMap::new();
39 let value = format!("Bearer {}", token.as_str());
40 let header = HeaderValue::from_str(&value)
41 .map_err(|_| Error::BadRequest("token contained invalid bytes".into()))?;
42 headers.insert(AUTHORIZATION, header);
43
44 let inner = reqwest::Client::builder()
45 .default_headers(headers)
46 .user_agent(concat!(
47 "wavekat-platform-client/",
48 env!("CARGO_PKG_VERSION")
49 ))
50 .build()?;
51 Ok(Self {
52 inner,
53 base_url: base_url.into().trim_end_matches('/').to_string(),
54 })
55 }
56
57 pub fn base_url(&self) -> &str {
61 &self.base_url
62 }
63
64 fn url(&self, path: &str) -> String {
65 format!("{}{}", self.base_url, path)
66 }
67
68 pub async fn get_json<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
70 let url = self.url(path);
71 let resp = self.inner.get(&url).send().await?;
72 decode(url, resp).await
73 }
74
75 pub async fn get_json_query<T: DeserializeOwned, Q: Serialize + ?Sized>(
78 &self,
79 path: &str,
80 query: &Q,
81 ) -> Result<T> {
82 let url = self.url(path);
83 let resp = self.inner.get(&url).query(query).send().await?;
84 decode(url, resp).await
85 }
86
87 pub async fn post_json<T: DeserializeOwned, B: Serialize + ?Sized>(
90 &self,
91 path: &str,
92 body: &B,
93 ) -> Result<T> {
94 let url = self.url(path);
95 let resp = self.inner.post(&url).json(body).send().await?;
96 decode(url, resp).await
97 }
98
99 pub async fn post_empty(&self, path: &str) -> Result<()> {
101 let url = self.url(path);
102 let resp = self.inner.post(&url).send().await?;
103 ensure_success(url, resp).await
104 }
105
106 pub async fn post_empty_returning_json<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
110 let url = self.url(path);
111 let resp = self.inner.post(&url).send().await?;
112 decode(url, resp).await
113 }
114
115 pub async fn delete(&self, path: &str) -> Result<()> {
117 let url = self.url(path);
118 let resp = self.inner.delete(&url).send().await?;
119 ensure_success(url, resp).await
120 }
121
122 pub async fn put_proxy_bytes(&self, path: &str, body: Vec<u8>) -> Result<()> {
126 self.put_raw_bytes(path, "application/octet-stream", body)
127 .await
128 }
129
130 pub async fn put_raw_bytes(&self, path: &str, content_type: &str, body: Vec<u8>) -> Result<()> {
135 let url = self.url(path);
136 let resp = self
137 .inner
138 .put(&url)
139 .header(reqwest::header::CONTENT_TYPE, content_type)
140 .body(body)
141 .send()
142 .await?;
143 ensure_success(url, resp).await
144 }
145
146 pub async fn put_presigned_bytes(presigned_url: &str, body: Vec<u8>) -> Result<()> {
151 let resp = reqwest::Client::new()
152 .put(presigned_url)
153 .body(body)
154 .send()
155 .await?;
156 ensure_success(presigned_url.to_string(), resp).await
157 }
158
159 pub async fn get_public_json<T: DeserializeOwned>(
171 base_url: &str,
172 path: &str,
173 query: &[(&str, &str)],
174 ) -> Result<T> {
175 let base = base_url.trim_end_matches('/');
176 let url = format!("{}{}", base, path);
177 let mut req = reqwest::Client::new().get(&url);
178 if !query.is_empty() {
179 req = req.query(query);
180 }
181 let resp = req.send().await?;
182 decode(url, resp).await
183 }
184
185 pub async fn get_stream_to<W: AsyncWriteExt + Unpin>(
189 &self,
190 path: &str,
191 sink: &mut W,
192 ) -> Result<u64> {
193 let url = self.url(path);
194 let resp = self.inner.get(&url).send().await?;
195 let status = resp.status();
196 if !status.is_success() {
197 let body = resp.text().await.unwrap_or_default();
198 return Err(http_error(status.as_u16(), url, body));
199 }
200 let mut stream = resp.bytes_stream();
201 let mut written: u64 = 0;
202 while let Some(chunk) = stream.next().await {
203 let bytes = chunk?;
204 sink.write_all(&bytes).await?;
205 written += bytes.len() as u64;
206 }
207 sink.flush().await?;
208 Ok(written)
209 }
210}
211
212async fn decode<T: DeserializeOwned>(url: String, resp: reqwest::Response) -> Result<T> {
213 let status = resp.status();
214 let text = resp.text().await?;
215 if !status.is_success() {
216 return Err(http_error(status.as_u16(), url, text));
217 }
218 serde_json::from_str(&text).map_err(|source| Error::Decode { url, source })
219}
220
221async fn ensure_success(url: String, resp: reqwest::Response) -> Result<()> {
222 let status = resp.status();
223 if status.is_success() {
224 return Ok(());
225 }
226 let body = resp.text().await.unwrap_or_default();
227 Err(http_error(status.as_u16(), url, body))
228}
229
230fn http_error(status: u16, url: String, body: String) -> Error {
235 let body = truncate(&body, 500).to_string();
236 if status == 401 {
237 Error::Unauthorized { url, body }
238 } else {
239 Error::Http { status, url, body }
240 }
241}
242
/// Returns the longest prefix of `s` that is at most `n` bytes long and ends
/// on a UTF-8 character boundary. Returns `s` unchanged when it already fits.
fn truncate(s: &str, n: usize) -> &str {
    if s.len() <= n {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=n)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// URL used across the error-shaping tests.
    const ME: &str = "https://platform.wavekat.com/api/me";

    #[test]
    fn http_error_format_matches_cli_shape() {
        let e = Error::Http {
            status: 500,
            url: ME.into(),
            body: "boom".into(),
        };
        let s = e.to_string();
        assert!(s.contains("500"), "{s}");
        assert!(s.contains(ME), "{s}");
        assert!(s.contains("boom"), "{s}");
    }

    #[test]
    fn http_error_splits_401_into_unauthorized() {
        let e = http_error(401, ME.into(), "{\"error\":\"unauthenticated\"}".into());
        assert!(
            matches!(e, Error::Unauthorized { .. }),
            "expected Unauthorized, got {e:?}"
        );
        // The rendered message must still mention the status code and URL.
        let s = e.to_string();
        assert!(s.contains("401"), "{s}");
        assert!(s.contains(ME), "{s}");
    }

    #[test]
    fn http_error_keeps_non_401_in_http_variant() {
        let e = http_error(500, ME.into(), "boom".into());
        assert!(
            matches!(e, Error::Http { status: 500, .. }),
            "expected Http {{ status: 500 }}, got {e:?}"
        );
    }

    #[test]
    fn http_error_caps_body_at_500_bytes() {
        match http_error(502, ME.into(), "x".repeat(2_000)) {
            Error::Http { status, body, .. } => {
                assert_eq!(status, 502);
                assert_eq!(body.len(), 500);
            }
            other => panic!("expected Http, got {other:?}"),
        }
    }

    #[test]
    fn truncate_respects_char_boundaries() {
        // 498 ASCII bytes followed by 'é' (2 bytes at 498..500): cutting at
        // 499 would split 'é', so truncate must back off to exactly 498 bytes.
        let s = "a".repeat(498) + "é";
        let t = truncate(&s, 499);
        assert_eq!(t, "a".repeat(498));
        assert!(s.starts_with(t));
    }

    #[test]
    fn truncate_is_noop_within_limit() {
        assert_eq!(truncate("boom", 500), "boom");
        assert_eq!(truncate("boom", 4), "boom");
        assert_eq!(truncate("", 0), "");
    }
}