torbox_core_rs/
client.rs

1use std::marker::PhantomData;
2
3use crate::api::ApiResponse;
4use crate::body::ToMultipart;
5use crate::error::ApiError;
6use crate::traits::FromBytes;
7use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue};
8use reqwest::multipart::Form;
9use reqwest::{Client, Method};
10use serde::{Serialize, de::DeserializeOwned};
11
12pub trait EndpointSpec {
13    /// JSON body you send - Use `()` to not send anything.
14    type Req: serde::Serialize;
15    /// Typed payload you expect back on success
16    type Resp: serde::de::DeserializeOwned;
17
18    const PATH: &'static str;
19    const METHOD: Method = Method::POST;
20}
21
22pub struct Endpoint<'c, S: EndpointSpec> {
23    client: &'c TorboxClient,
24    //todo: fix
25    _path: Option<String>,
26    _marker: PhantomData<S>,
27}
28
29impl<'c, S: EndpointSpec> Endpoint<'c, S> {
30    pub fn new(client: &'c TorboxClient) -> Self {
31        Self {
32            client,
33            _path: None,
34            _marker: PhantomData,
35        }
36    }
37
38    pub fn new_with_url(client: &'c TorboxClient, full_url: impl Into<String>) -> Self {
39        Self {
40            client,
41            _path: Some(full_url.into()),
42            _marker: std::marker::PhantomData,
43        }
44    }
45
46    pub async fn call_no_body(&self, url_suffix: &str) -> Result<ApiResponse<S::Resp>, ApiError>
47    where
48        S::Resp:,
49        <S as EndpointSpec>::Resp: std::fmt::Debug,
50    {
51        self.client.request(S::METHOD, url_suffix).await
52    }
53
54    pub async fn call(&self, body: S::Req) -> Result<ApiResponse<S::Resp>, ApiError> {
55        self.client
56            .request_with_json(S::METHOD, S::PATH, body)
57            .await
58    }
59
60    pub async fn call_query(&self, query: S::Req) -> Result<ApiResponse<S::Resp>, ApiError>
61    where
62        S::Req: Serialize,
63    {
64        self.client
65            .request_with_query(S::METHOD, S::PATH, &query)
66            .await
67    }
68
69    pub async fn call_multipart(&self, body: S::Req) -> Result<ApiResponse<S::Resp>, ApiError>
70    where
71        S::Req: ToMultipart + Send + Sync,
72    {
73        let form = body.to_multipart().await;
74        self.client
75            .request_multipart(S::METHOD, S::PATH, form)
76            .await
77    }
78
79    pub async fn call_query_bytes(&self, query: S::Req) -> Result<Vec<u8>, ApiError>
80    where
81        S::Req: Serialize,
82    {
83        let url = format!("{}/{}", self.client.base_url, S::PATH);
84        let response = self
85            .client
86            .client
87            .request(S::METHOD, &url)
88            .headers(self.client.headers("application/json"))
89            .query(&query)
90            .send()
91            .await?;
92
93        Ok(response.bytes().await?.to_vec())
94    }
95
96    pub async fn call_query_raw<T>(&self, query: S::Req) -> Result<T, ApiError>
97    where
98        T: DeserializeOwned + FromBytes,
99        S::Req: Serialize,
100    {
101        let res = self
102            .client
103            .client
104            .request(S::METHOD, format!("{}/{}", self.client.base_url, S::PATH))
105            .headers(self.client.headers("application/json"))
106            .query(&query)
107            .send()
108            .await?;
109
110        self.client.parse_response::<T>(res).await
111    }
112}
113
114#[derive(Clone)]
115#[cfg_attr(feature = "specta", derive(specta::Type))]
116pub struct TorboxClient {
117    /// Client can be specta skipped because TorboxClient should NEVER be used in any frontend, type is only used to be able to derive the APIs built from it.
118    #[cfg_attr(feature = "specta", specta(skip))]
119    pub client: Client,
120    pub(crate) token: String,
121    pub base_url: String,
122}
123
124impl TorboxClient {
125    pub fn new(token: String) -> Self {
126        let client = Client::builder()
127            .redirect(reqwest::redirect::Policy::none())
128            .build()
129            .unwrap();
130        Self {
131            client,
132            token,
133            base_url: "https://api.torbox.app/v1".to_string(),
134        }
135    }
136
137    pub fn with_base_url(&self, new_base: impl Into<String>) -> Self {
138        let mut new = self.clone();
139        new.base_url = new_base.into();
140        new
141    }
142
143    pub fn token(&self) -> &str {
144        &self.token
145    }
146
147    async fn parse_response<T>(&self, res: reqwest::Response) -> Result<T, ApiError>
148    where
149        T: DeserializeOwned + FromBytes,
150    {
151        let content_type = res
152            .headers()
153            .get(reqwest::header::CONTENT_TYPE)
154            .and_then(|v| v.to_str().ok())
155            .unwrap_or("");
156
157        if content_type.starts_with("application/json") {
158            let text = res.text().await?;
159            serde_json::from_str::<T>(&text).map_err(ApiError::from)
160        } else {
161            // Handle binary responses
162            let bytes = res.bytes().await?.to_vec();
163            T::from_bytes(bytes)
164        }
165    }
166
167    fn headers(&self, _content_type: &'static str) -> HeaderMap {
168        let mut headers = HeaderMap::new();
169        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
170        headers.insert(
171            AUTHORIZATION,
172            HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap(),
173        );
174        headers
175    }
176
177    pub async fn request_multipart<T: DeserializeOwned>(
178        &self,
179        method: Method,
180        endpoint: &str,
181        form: Form,
182    ) -> Result<T, ApiError> {
183        let url = format!("{}/{}", self.base_url, endpoint);
184
185        let res = self
186            .client
187            .request(method, url)
188            .headers(self.headers("multipart/form-data"))
189            .multipart(form)
190            .send()
191            .await?;
192
193        let text = res.text().await?;
194
195        let parsed = serde_json::from_str::<T>(&text)?;
196        Ok(parsed)
197    }
198
199    pub async fn request<T: DeserializeOwned + FromBytes>(
200        &self,
201        method: Method,
202        endpoint: &str,
203    ) -> Result<T, ApiError> {
204        let res = self
205            .client
206            .request(method, format!("{}/{}", self.base_url, endpoint))
207            .headers(self.headers("application/json"))
208            .send()
209            .await?;
210
211        self.parse_response::<T>(res).await
212    }
213
214    pub async fn request_with_json<T: DeserializeOwned, B: Serialize>(
215        &self,
216        method: Method,
217        endpoint: &str,
218        body: B,
219    ) -> Result<T, ApiError> {
220        let res = self
221            .client
222            .request(method, format!("{}/{}", self.base_url, endpoint))
223            .headers(self.headers("application/json"))
224            .json(&body)
225            .send()
226            .await?;
227
228        let text = res.text().await?;
229
230        let parsed = serde_json::from_str::<T>(&text)?;
231        Ok(parsed)
232    }
233
234    pub async fn request_with_query<T: DeserializeOwned, Q: Serialize>(
235        &self,
236        method: Method,
237        endpoint: &str,
238        query: &Q,
239    ) -> Result<T, ApiError> {
240        let res = self
241            .client
242            .request(method, format!("{}/{}", self.base_url, endpoint))
243            .headers(self.headers("application/json"))
244            .query(query)
245            .send()
246            .await?;
247
248        let text = res.text().await?;
249        // eprintln!("Raw API response: {}", text);
250
251        let parsed = serde_json::from_str::<T>(&text)?;
252        Ok(parsed)
253    }
254}