1use std::marker::PhantomData;
2
3use crate::api::ApiResponse;
4use crate::body::ToMultipart;
5use crate::error::ApiError;
6use crate::traits::FromBytes;
7use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue};
8use reqwest::multipart::Form;
9use reqwest::{Client, Method};
10use serde::{Serialize, de::DeserializeOwned};
11
12pub trait EndpointSpec {
13 type Req: serde::Serialize;
15 type Resp: serde::de::DeserializeOwned;
17
18 const PATH: &'static str;
19 const METHOD: Method = Method::POST;
20}
21
22pub struct Endpoint<'c, S: EndpointSpec> {
23 client: &'c TorboxClient,
24 _path: Option<String>,
26 _marker: PhantomData<S>,
27}
28
29impl<'c, S: EndpointSpec> Endpoint<'c, S> {
30 pub fn new(client: &'c TorboxClient) -> Self {
31 Self {
32 client,
33 _path: None,
34 _marker: PhantomData,
35 }
36 }
37
38 pub fn new_with_url(client: &'c TorboxClient, full_url: impl Into<String>) -> Self {
39 Self {
40 client,
41 _path: Some(full_url.into()),
42 _marker: std::marker::PhantomData,
43 }
44 }
45
46 pub async fn call_no_body(&self, url_suffix: &str) -> Result<ApiResponse<S::Resp>, ApiError>
47 where
48 S::Resp:,
49 <S as EndpointSpec>::Resp: std::fmt::Debug,
50 {
51 self.client.request(S::METHOD, url_suffix).await
52 }
53
54 pub async fn call(&self, body: S::Req) -> Result<ApiResponse<S::Resp>, ApiError> {
55 self.client
56 .request_with_json(S::METHOD, S::PATH, body)
57 .await
58 }
59
60 pub async fn call_query(&self, query: S::Req) -> Result<ApiResponse<S::Resp>, ApiError>
61 where
62 S::Req: Serialize,
63 {
64 self.client
65 .request_with_query(S::METHOD, S::PATH, &query)
66 .await
67 }
68
69 pub async fn call_multipart(&self, body: S::Req) -> Result<ApiResponse<S::Resp>, ApiError>
70 where
71 S::Req: ToMultipart + Send + Sync,
72 {
73 let form = body.to_multipart().await;
74 self.client
75 .request_multipart(S::METHOD, S::PATH, form)
76 .await
77 }
78
79 pub async fn call_query_bytes(&self, query: S::Req) -> Result<Vec<u8>, ApiError>
80 where
81 S::Req: Serialize,
82 {
83 let url = format!("{}/{}", self.client.base_url, S::PATH);
84 let response = self
85 .client
86 .client
87 .request(S::METHOD, &url)
88 .headers(self.client.headers("application/json"))
89 .query(&query)
90 .send()
91 .await?;
92
93 Ok(response.bytes().await?.to_vec())
94 }
95
96 pub async fn call_query_raw<T>(&self, query: S::Req) -> Result<T, ApiError>
97 where
98 T: DeserializeOwned + FromBytes,
99 S::Req: Serialize,
100 {
101 let res = self
102 .client
103 .client
104 .request(S::METHOD, format!("{}/{}", self.client.base_url, S::PATH))
105 .headers(self.client.headers("application/json"))
106 .query(&query)
107 .send()
108 .await?;
109
110 self.client.parse_response::<T>(res).await
111 }
112}
113
114#[derive(Clone)]
115#[cfg_attr(feature = "specta", derive(specta::Type))]
116pub struct TorboxClient {
117 #[cfg_attr(feature = "specta", specta(skip))]
119 pub client: Client,
120 pub(crate) token: String,
121 pub base_url: String,
122}
123
124impl TorboxClient {
125 pub fn new(token: String) -> Self {
126 let client = Client::builder()
127 .redirect(reqwest::redirect::Policy::none())
128 .build()
129 .unwrap();
130 Self {
131 client,
132 token,
133 base_url: "https://api.torbox.app/v1".to_string(),
134 }
135 }
136
137 pub fn with_base_url(&self, new_base: impl Into<String>) -> Self {
138 let mut new = self.clone();
139 new.base_url = new_base.into();
140 new
141 }
142
143 pub fn token(&self) -> &str {
144 &self.token
145 }
146
147 async fn parse_response<T>(&self, res: reqwest::Response) -> Result<T, ApiError>
148 where
149 T: DeserializeOwned + FromBytes,
150 {
151 let content_type = res
152 .headers()
153 .get(reqwest::header::CONTENT_TYPE)
154 .and_then(|v| v.to_str().ok())
155 .unwrap_or("");
156
157 if content_type.starts_with("application/json") {
158 let text = res.text().await?;
159 serde_json::from_str::<T>(&text).map_err(ApiError::from)
160 } else {
161 let bytes = res.bytes().await?.to_vec();
163 T::from_bytes(bytes)
164 }
165 }
166
167 fn headers(&self, _content_type: &'static str) -> HeaderMap {
168 let mut headers = HeaderMap::new();
169 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
170 headers.insert(
171 AUTHORIZATION,
172 HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap(),
173 );
174 headers
175 }
176
177 pub async fn request_multipart<T: DeserializeOwned>(
178 &self,
179 method: Method,
180 endpoint: &str,
181 form: Form,
182 ) -> Result<T, ApiError> {
183 let url = format!("{}/{}", self.base_url, endpoint);
184
185 let res = self
186 .client
187 .request(method, url)
188 .headers(self.headers("multipart/form-data"))
189 .multipart(form)
190 .send()
191 .await?;
192
193 let text = res.text().await?;
194
195 let parsed = serde_json::from_str::<T>(&text)?;
196 Ok(parsed)
197 }
198
199 pub async fn request<T: DeserializeOwned + FromBytes>(
200 &self,
201 method: Method,
202 endpoint: &str,
203 ) -> Result<T, ApiError> {
204 let res = self
205 .client
206 .request(method, format!("{}/{}", self.base_url, endpoint))
207 .headers(self.headers("application/json"))
208 .send()
209 .await?;
210
211 self.parse_response::<T>(res).await
212 }
213
214 pub async fn request_with_json<T: DeserializeOwned, B: Serialize>(
215 &self,
216 method: Method,
217 endpoint: &str,
218 body: B,
219 ) -> Result<T, ApiError> {
220 let res = self
221 .client
222 .request(method, format!("{}/{}", self.base_url, endpoint))
223 .headers(self.headers("application/json"))
224 .json(&body)
225 .send()
226 .await?;
227
228 let text = res.text().await?;
229
230 let parsed = serde_json::from_str::<T>(&text)?;
231 Ok(parsed)
232 }
233
234 pub async fn request_with_query<T: DeserializeOwned, Q: Serialize>(
235 &self,
236 method: Method,
237 endpoint: &str,
238 query: &Q,
239 ) -> Result<T, ApiError> {
240 let res = self
241 .client
242 .request(method, format!("{}/{}", self.base_url, endpoint))
243 .headers(self.headers("application/json"))
244 .query(query)
245 .send()
246 .await?;
247
248 let text = res.text().await?;
249 let parsed = serde_json::from_str::<T>(&text)?;
252 Ok(parsed)
253 }
254}