1use std::{num::ParseIntError, str::FromStr};
2
3use http::{header::AsHeaderName, HeaderMap, StatusCode};
4use serde::Serialize;
5
6use crate::{
7 error::{ApiError, DeserializeError, FromHttpError, HeaderError, IntoHttpError},
8 AuthRequirement, Context, Metadata,
9};
10
11#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize)]
15pub struct Pagination {
16 pub page: usize,
17 pub limit: usize,
18}
19
20impl Default for Pagination {
21 fn default() -> Self {
22 Self::DEFAULT
23 }
24}
25
26impl Pagination {
27 const DEFAULT: Self = Self::new(1, 10);
28
29 #[inline]
30 #[must_use]
31 pub const fn new(page: usize, limit: usize) -> Self {
32 Self { page, limit }
33 }
34}
35
/// One page of results plus the pagination metadata reported alongside it.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct PaginationResponse<T> {
    /// The items contained in this page.
    pub items: Vec<T>,
    /// Page number of `items` (read from `X-Pagination-Page` by `from_headers`).
    pub current_page: usize,
    /// Page size (read from `X-Pagination-Limit` by `from_headers`).
    pub items_per_page: usize,
    /// Number of pages available (read from `X-Pagination-Page-Count` by `from_headers`).
    pub total_pages: usize,
    /// Number of items available (read from `X-Pagination-Item-Count` by `from_headers`).
    pub total_items: usize,
}
45
46impl<T> PaginationResponse<T> {
47 pub fn from_headers(items: Vec<T>, map: &HeaderMap) -> Result<Self, DeserializeError> {
53 let current_page = parse_from_header(map, "X-Pagination-Page")?;
54 let items_per_page = parse_from_header(map, "X-Pagination-Limit")?;
55 let total_pages = parse_from_header(map, "X-Pagination-Page-Count")?;
56 let total_items = parse_from_header(map, "X-Pagination-Item-Count")?;
57
58 Ok(Self {
59 items,
60 current_page,
61 items_per_page,
62 total_pages,
63 total_items,
64 })
65 }
66
67 #[inline]
68 #[must_use]
69 pub const fn next_page(&self) -> Option<Pagination> {
70 if self.current_page < self.total_pages {
71 Some(Pagination::new(self.current_page + 1, self.items_per_page))
72 } else {
73 None
74 }
75 }
76}
77
78pub fn parse_from_header<T, K>(map: &HeaderMap, key: K) -> Result<T, DeserializeError>
85where
86 T: FromStr<Err = ParseIntError>,
87 K: AsHeaderName,
88{
89 map.get(key)
90 .ok_or(HeaderError::MissingHeader)?
91 .to_str()
92 .map_err(HeaderError::ToStrError)?
93 .parse()
94 .map_err(DeserializeError::ParseInt)
95}
96
97pub fn handle_response_body<B, T>(
107 response: &http::Response<B>,
108 expected: StatusCode,
109) -> Result<T, FromHttpError>
110where
111 B: AsRef<[u8]>,
112 T: serde::de::DeserializeOwned,
113{
114 if response.status() == expected {
115 Ok(serde_json::from_slice(response.body().as_ref()).map_err(DeserializeError::Json)?)
116 } else {
117 Err(FromHttpError::Api(ApiError::from(response.status())))
118 }
119}
120
121pub fn construct_req<B>(
128 ctx: &Context,
129 md: &Metadata,
130 path: &impl Serialize,
131 query: &impl Serialize,
132 body: B,
133) -> Result<http::Request<B>, IntoHttpError> {
134 let url = crate::construct_url(ctx.base_url, md.endpoint, path, query)?;
135
136 let request = http::Request::builder()
137 .method(&md.method)
138 .uri(url)
139 .header("Content-Type", "application/json")
140 .header("trakt-api-version", "2")
141 .header("trakt-api-key", ctx.client_id);
142 let request = match (md.auth, ctx.oauth_token) {
143 (AuthRequirement::None, _) | (AuthRequirement::Optional, None) => request,
144 (AuthRequirement::Optional | AuthRequirement::Required, Some(token)) => {
145 request.header("Authorization", format!("Bearer {token}"))
146 }
147 (AuthRequirement::Required, None) => {
148 return Err(IntoHttpError::MissingToken);
149 }
150 };
151 Ok(request.body(body)?)
152}
153
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::*;

    #[test]
    fn test_parse_from_header() {
        let mut map = HeaderMap::new();
        map.insert("B", HeaderValue::from_bytes(b"hello\xfa").unwrap());
        map.insert("C", HeaderValue::from_static("hello"));
        map.insert("D", HeaderValue::from_static("10"));

        assert!(matches!(
            parse_from_header::<u32, _>(&map, "A"),
            Err(DeserializeError::Header(HeaderError::MissingHeader))
        ));
        assert!(matches!(
            parse_from_header::<u32, _>(&map, "B"),
            Err(DeserializeError::Header(HeaderError::ToStrError(_)))
        ));
        assert!(matches!(
            parse_from_header::<u32, _>(&map, "C"),
            Err(DeserializeError::ParseInt(_))
        ));
        assert_eq!(parse_from_header::<u32, _>(&map, "D").unwrap(), 10);
    }

    #[test]
    fn test_handle_response_body_ok() {
        let response = http::Response::builder()
            .status(StatusCode::OK)
            .body(b"\"hello\"")
            .unwrap();
        assert_eq!(
            handle_response_body::<_, String>(&response, StatusCode::OK).unwrap(),
            "hello"
        );
    }

    #[test]
    fn test_handle_response_body_bad_request() {
        let response = http::Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(b"\"hello\"")
            .unwrap();
        assert!(matches!(
            handle_response_body::<_, String>(&response, StatusCode::OK),
            Err(FromHttpError::Api(ApiError::BadRequest))
        ));
    }

    #[test]
    fn test_handle_response_body_deserialize_error() {
        let response = http::Response::builder()
            .status(StatusCode::OK)
            .body(b"\"hello\xfa\"")
            .unwrap();
        assert!(matches!(
            handle_response_body::<_, String>(&response, StatusCode::OK),
            Err(FromHttpError::Deserialize(DeserializeError::Json(_)))
        ));
    }

    /// Calls `construct_req` for a fixed `GET /test` endpoint with the given
    /// auth requirement and optional OAuth token.
    fn build_request(
        auth: AuthRequirement,
        oauth_token: Option<&'static str>,
    ) -> Result<http::Request<&'static str>, IntoHttpError> {
        let ctx = Context {
            base_url: "https://api.trakt.tv",
            client_id: "client id",
            oauth_token,
        };
        let md = Metadata {
            endpoint: "/test",
            method: http::Method::GET,
            auth,
        };
        construct_req(&ctx, &md, &(), &(), "body")
    }

    /// Checks the parts every request built by `build_request` shares, plus the
    /// expected `Authorization` header value (`None` = header must be absent).
    fn assert_request(req: http::Request<&str>, authorization: Option<&str>) {
        assert_eq!(req.method(), &http::Method::GET);
        assert_eq!(req.uri(), "https://api.trakt.tv/test");
        assert_eq!(
            req.headers().get("Content-Type").unwrap(),
            "application/json"
        );
        assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2");
        assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id");
        match authorization {
            Some(expected) => {
                assert_eq!(req.headers().get("Authorization").unwrap(), expected);
            }
            None => assert!(req.headers().get("Authorization").is_none()),
        }
        assert_eq!(req.into_body(), "body");
    }

    #[test]
    fn test_construct_req_no_auth_without_token() {
        let req = build_request(AuthRequirement::None, None).unwrap();
        assert_request(req, None);
    }

    #[test]
    fn test_construct_req_no_auth_ignores_token() {
        // Endpoints that need no auth must not receive the user's token.
        let req = build_request(AuthRequirement::None, Some("token")).unwrap();
        assert_request(req, None);
    }

    #[test]
    fn test_construct_req_required_with_token() {
        let req = build_request(AuthRequirement::Required, Some("token")).unwrap();
        assert_request(req, Some("Bearer token"));
    }

    #[test]
    fn test_construct_req_required_without_token() {
        let err = build_request(AuthRequirement::Required, None).unwrap_err();
        assert!(matches!(err, IntoHttpError::MissingToken));
    }

    #[test]
    fn test_construct_req_optional_without_token() {
        let req = build_request(AuthRequirement::Optional, None).unwrap();
        assert_request(req, None);
    }

    #[test]
    fn test_construct_req_optional_with_token() {
        let req = build_request(AuthRequirement::Optional, Some("token")).unwrap();
        assert_request(req, Some("Bearer token"));
    }
}