trakt_core/
utils.rs

1use std::{num::ParseIntError, str::FromStr};
2
3use http::{header::AsHeaderName, HeaderMap, StatusCode};
4use serde::Serialize;
5
6use crate::{
7    error::{ApiError, DeserializeError, FromHttpError, HeaderError, IntoHttpError},
8    AuthRequirement, Context, Metadata,
9};
10
/// `Pagination` struct is used to specify the page number and the maximum number of items to be shown per page.
///
/// Default values are `page = 1` and `limit = 10`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize)]
pub struct Pagination {
    /// Page number to request.
    pub page: usize,
    /// Maximum number of items to be shown per page.
    pub limit: usize,
}
19
20impl Default for Pagination {
21    fn default() -> Self {
22        Self::DEFAULT
23    }
24}
25
26impl Pagination {
27    const DEFAULT: Self = Self::new(1, 10);
28
29    #[inline]
30    #[must_use]
31    pub const fn new(page: usize, limit: usize) -> Self {
32        Self { page, limit }
33    }
34}
35
/// `PaginationResponse` struct is used to store the paginated response from the API.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct PaginationResponse<T> {
    /// Items contained in this page of the response.
    pub items: Vec<T>,
    /// Page this response represents (`X-Pagination-Page` header).
    pub current_page: usize,
    /// Page size used for this response (`X-Pagination-Limit` header).
    pub items_per_page: usize,
    /// Total number of pages available (`X-Pagination-Page-Count` header).
    pub total_pages: usize,
    /// Total number of items across all pages (`X-Pagination-Item-Count` header).
    pub total_items: usize,
}
45
46impl<T> PaginationResponse<T> {
47    /// Create a new `PaginationResponse` instance from items and Trakt.tv API response headers.
48    ///
49    /// # Errors
50    ///
51    /// Returns a `DeserializeError` if the headers are missing or if the header values are not valid.
52    pub fn from_headers(items: Vec<T>, map: &HeaderMap) -> Result<Self, DeserializeError> {
53        let current_page = parse_from_header(map, "X-Pagination-Page")?;
54        let items_per_page = parse_from_header(map, "X-Pagination-Limit")?;
55        let total_pages = parse_from_header(map, "X-Pagination-Page-Count")?;
56        let total_items = parse_from_header(map, "X-Pagination-Item-Count")?;
57
58        Ok(Self {
59            items,
60            current_page,
61            items_per_page,
62            total_pages,
63            total_items,
64        })
65    }
66
67    #[inline]
68    #[must_use]
69    pub const fn next_page(&self) -> Option<Pagination> {
70        if self.current_page < self.total_pages {
71            Some(Pagination::new(self.current_page + 1, self.items_per_page))
72        } else {
73            None
74        }
75    }
76}
77
78/// Helper function to parse a header value to an integer.
79///
80/// # Errors
81///
82/// Returns a `DeserializeError` if the header is missing, if the header value is not a valid
83/// string, or if the string value cannot be parsed to an integer.
84pub fn parse_from_header<T, K>(map: &HeaderMap, key: K) -> Result<T, DeserializeError>
85where
86    T: FromStr<Err = ParseIntError>,
87    K: AsHeaderName,
88{
89    map.get(key)
90        .ok_or(HeaderError::MissingHeader)?
91        .to_str()
92        .map_err(HeaderError::ToStrError)?
93        .parse()
94        .map_err(DeserializeError::ParseInt)
95}
96
97/// Helper function to handle the response body from the API.
98///
99/// Will check if the response has the expected status code and will try to deserialize the
100/// response body.
101///
102/// # Errors
103///
104/// Returns a `FromHttpError` if the response status code is not the expected one or if the body
105/// failed to be deserialized.
106pub fn handle_response_body<B, T>(
107    response: &http::Response<B>,
108    expected: StatusCode,
109) -> Result<T, FromHttpError>
110where
111    B: AsRef<[u8]>,
112    T: serde::de::DeserializeOwned,
113{
114    if response.status() == expected {
115        Ok(serde_json::from_slice(response.body().as_ref()).map_err(DeserializeError::Json)?)
116    } else {
117        Err(FromHttpError::Api(ApiError::from(response.status())))
118    }
119}
120
121/// Helper function to construct an HTTP request using the given context, metadata, and
122/// path/query/body values.
123///
124/// # Errors
125///
126/// Returns an `IntoHttpError` if the http request cannot be constructed.
127pub fn construct_req<B>(
128    ctx: &Context,
129    md: &Metadata,
130    path: &impl Serialize,
131    query: &impl Serialize,
132    body: B,
133) -> Result<http::Request<B>, IntoHttpError> {
134    let url = crate::construct_url(ctx.base_url, md.endpoint, path, query)?;
135
136    let request = http::Request::builder()
137        .method(&md.method)
138        .uri(url)
139        .header("Content-Type", "application/json")
140        .header("trakt-api-version", "2")
141        .header("trakt-api-key", ctx.client_id);
142    let request = match (md.auth, ctx.oauth_token) {
143        (AuthRequirement::None, _) | (AuthRequirement::Optional, None) => request,
144        (AuthRequirement::Optional | AuthRequirement::Required, Some(token)) => {
145            request.header("Authorization", format!("Bearer {token}"))
146        }
147        (AuthRequirement::Required, None) => {
148            return Err(IntoHttpError::MissingToken);
149        }
150    };
151    Ok(request.body(body)?)
152}
153
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::*;

    /// Covers all outcomes of `parse_from_header`: missing header, non-string value,
    /// non-numeric string, and a successful parse.
    #[test]
    fn test_parse_from_header() {
        let mut map = HeaderMap::new();
        // "B" holds a byte outside the range `HeaderValue::to_str` accepts.
        map.insert("B", HeaderValue::from_bytes(b"hello\xfa").unwrap());
        map.insert("C", HeaderValue::from_static("hello"));
        map.insert("D", HeaderValue::from_static("10"));

        assert!(matches!(
            parse_from_header::<u32, _>(&map, "A"),
            Err(DeserializeError::Header(HeaderError::MissingHeader))
        ));
        assert!(matches!(
            parse_from_header::<u32, _>(&map, "B"),
            Err(DeserializeError::Header(HeaderError::ToStrError(_)))
        ));
        assert!(matches!(
            parse_from_header::<u32, _>(&map, "C"),
            Err(DeserializeError::ParseInt(_))
        ));
        assert_eq!(parse_from_header::<u32, _>(&map, "D").unwrap(), 10);
    }

    /// A response with the expected status has its JSON body deserialized.
    #[test]
    fn test_handle_response_body_ok() {
        let response = http::Response::builder()
            .status(StatusCode::OK)
            .body(b"\"hello\"")
            .unwrap();
        assert_eq!(
            handle_response_body::<_, String>(&response, StatusCode::OK).unwrap(),
            "hello"
        );
    }

    /// A response with an unexpected status is mapped to the matching `ApiError`.
    #[test]
    fn test_handle_response_body_bad_request() {
        let response = http::Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(b"\"hello\"")
            .unwrap();
        assert!(matches!(
            handle_response_body::<_, String>(&response, StatusCode::OK),
            Err(FromHttpError::Api(ApiError::BadRequest))
        ));
    }

    /// A body that is not valid JSON for the target type yields a JSON deserialize error.
    #[test]
    fn test_handle_response_body_deserialize_error() {
        let response = http::Response::builder()
            .status(StatusCode::OK)
            .body(b"\"hello\xfa\"")
            .unwrap();
        assert!(matches!(
            handle_response_body::<_, String>(&response, StatusCode::OK),
            Err(FromHttpError::Deserialize(DeserializeError::Json(_)))
        ));
    }

    /// Calls `construct_req` for the fixed test endpoint (`GET /test`, body `"body"`) with the
    /// given auth requirement and OAuth token.
    fn build_req(
        auth: AuthRequirement,
        oauth_token: Option<&'static str>,
    ) -> Result<http::Request<&'static str>, IntoHttpError> {
        let ctx = Context {
            base_url: "https://api.trakt.tv",
            client_id: "client id",
            oauth_token,
        };
        let md = Metadata {
            endpoint: "/test",
            method: http::Method::GET,
            auth,
        };
        construct_req(&ctx, &md, &(), &(), "body")
    }

    /// Builds a request via `build_req`, checks the parts that never depend on authentication,
    /// then checks that the `Authorization` header equals `expected_authorization`
    /// (`None` meaning the header must be absent).
    #[track_caller]
    fn assert_construct_req(
        auth: AuthRequirement,
        oauth_token: Option<&'static str>,
        expected_authorization: Option<&str>,
    ) {
        let req = build_req(auth, oauth_token).unwrap();

        assert_eq!(req.method(), &http::Method::GET);
        assert_eq!(req.uri(), "https://api.trakt.tv/test");
        assert_eq!(
            req.headers().get("Content-Type").unwrap(),
            "application/json"
        );
        assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2");
        assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id");
        assert_eq!(
            req.headers()
                .get("Authorization")
                .map(|value| value.to_str().unwrap()),
            expected_authorization
        );
        assert_eq!(req.into_body(), "body");
    }

    /// Exercises every `(AuthRequirement, token)` combination handled by `construct_req`.
    #[test]
    fn test_construct_req() {
        // No auth required: the token is never sent, whether or not one is configured.
        assert_construct_req(AuthRequirement::None, None, None);
        assert_construct_req(AuthRequirement::None, Some("token"), None);

        // Auth required: the token is sent, and its absence is an error.
        assert_construct_req(AuthRequirement::Required, Some("token"), Some("Bearer token"));
        let err = build_req(AuthRequirement::Required, None).unwrap_err();
        assert!(matches!(err, IntoHttpError::MissingToken));

        // Auth optional: the token is sent only when available.
        assert_construct_req(AuthRequirement::Optional, None, None);
        assert_construct_req(AuthRequirement::Optional, Some("token"), Some("Bearer token"));
    }
}