Skip to main content

satay_runtime/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::fmt;
4use std::str::FromStr;
5
6use http::header::{self, CONTENT_TYPE, HeaderName, HeaderValue};
7#[cfg(feature = "json")]
8use serde::de;
9use time::Month;
10use time::format_description::well_known::Rfc3339;
11pub use time::{Date, OffsetDateTime, PrimitiveDateTime, Time};
12
13use tracing::{debug, instrument};
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct RequestParts<B> {
17    pub method: http::Method,
18    pub uri: String,
19    pub headers: http::HeaderMap,
20    pub body: B,
21}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct ResponseParts<B> {
25    pub status: http::StatusCode,
26    pub headers: http::HeaderMap,
27    pub body: B,
28}
29
30#[derive(Debug, thiserror::Error)]
31pub enum Error {
32    #[error("failed to build HTTP message: {0}")]
33    Http(#[from] http::Error),
34
35    #[error("invalid HTTP header value: {0}")]
36    InvalidHeaderValue(#[from] header::InvalidHeaderValue),
37
38    #[error("invalid HTTP header name: {0}")]
39    InvalidHeaderName(#[from] header::InvalidHeaderName),
40
41    #[error("missing required field `{0}`")]
42    MissingRequired(&'static str),
43
44    #[error("{0}")]
45    InvalidResponse(&'static str),
46
47    #[cfg(feature = "json")]
48    #[error("JSON error: {0}")]
49    Json(#[from] serde_json::Error),
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
53pub enum ParseRangeError {
54    #[error("range contains more than one `-` separator")]
55    TooManySeparators,
56
57    #[error("invalid range minimum `{value}`: {message}")]
58    InvalidMinimum { value: String, message: String },
59
60    #[error("invalid range maximum `{value}`: {message}")]
61    InvalidMaximum { value: String, message: String },
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
65pub enum ParseTimeError {
66    #[error("time must be exactly four ASCII digits in HHMM format")]
67    InvalidFormat,
68
69    #[error("time is outside valid HHMM range")]
70    ComponentRange,
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
74pub enum ParseDateError {
75    #[error("date must be in YYYY-MM-DD format")]
76    InvalidFormat,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
80pub enum ParseNaiveDateTimeError {
81    #[error("datetime must be in YYYY-MM-DDTHH:mm:ss format")]
82    InvalidFormat,
83}
84
85pub trait Action {
86    type Response;
87
88    fn request(self) -> Result<http::Request<Vec<u8>>, Error>;
89    fn decode<B: AsRef<[u8]>>(response: ResponseParts<B>) -> Result<Self::Response, Error>;
90}
91
92#[instrument(skip_all, fields(method = %method, uri = %uri))]
93pub fn into_request<B>(
94    RequestParts {
95        method,
96        uri,
97        headers,
98        body,
99    }: RequestParts<B>,
100) -> Result<http::Request<B>, Error> {
101    debug!("building HTTP request");
102    let mut request = http::Request::builder()
103        .method(method)
104        .uri(uri)
105        .body(body)?;
106    *request.headers_mut() = headers;
107    Ok(request)
108}
109
110#[instrument(skip_all, fields(method = %method, uri = %uri))]
111pub fn into_empty_request(
112    RequestParts {
113        method,
114        uri,
115        headers,
116        body: _,
117    }: RequestParts<()>,
118) -> Result<http::Request<Vec<u8>>, Error> {
119    debug!("building empty HTTP request");
120    let mut request = http::Request::builder()
121        .method(method)
122        .uri(uri)
123        .body(vec![])?;
124    *request.headers_mut() = headers;
125    Ok(request)
126}
127
/// Serializes `parts.body` as JSON and builds the request around it.
///
/// `Content-Type: application/json` is added only when the supplied headers
/// do not already set a content type, so callers may override it.
///
/// # Errors
///
/// [`Error::Json`] if serialization fails; [`Error::Http`] if the `http`
/// builder rejects the method or URI.
#[cfg(feature = "json")]
#[instrument(skip_all, fields(method = %method, uri = %uri))]
pub fn into_json_request<T>(
    RequestParts {
        method,
        uri,
        headers,
        body,
    }: RequestParts<T>,
) -> Result<http::Request<Vec<u8>>, Error>
where
    T: serde::Serialize,
{
    debug!("building JSON HTTP request");
    let encoded = serde_json::to_vec(&body)?;
    let builder = http::Request::builder().method(method).uri(uri);
    let mut built = builder.body(encoded)?;
    *built.headers_mut() = headers;
    built
        .headers_mut()
        .entry(CONTENT_TYPE)
        .or_insert_with(|| http::HeaderValue::from_static("application/json"));
    Ok(built)
}
156
/// Builds a JSON request when a body is present; otherwise builds an
/// empty-bodied request with no JSON content type added.
///
/// # Errors
///
/// Same as [`into_json_request`] or [`into_empty_request`], depending on
/// which path is taken.
#[cfg(feature = "json")]
#[instrument(skip_all, fields(method = %method, uri = %uri))]
pub fn into_optional_json_request<T>(
    RequestParts {
        method,
        uri,
        headers,
        body,
    }: RequestParts<Option<T>>,
) -> Result<http::Request<Vec<u8>>, Error>
where
    T: serde::Serialize,
{
    if let Some(payload) = body {
        return into_json_request(RequestParts {
            method,
            uri,
            headers,
            body: payload,
        });
    }
    into_empty_request(RequestParts {
        method,
        uri,
        headers,
        body: (),
    })
}
185
/// Deserializes a JSON response body into `T`.
///
/// # Errors
///
/// Returns [`Error::Json`] when `body` is not valid JSON for `T`.
#[cfg(feature = "json")]
#[instrument(skip_all)]
pub fn from_json_slice<T>(body: &[u8]) -> Result<T, Error>
where
    T: de::DeserializeOwned,
{
    debug!("deserializing JSON response");
    serde_json::from_slice::<T>(body).map_err(Error::from)
}
195
196pub fn append_path_segment(out: &mut String, value: &str) {
197    append_percent_encoded(out, value.as_bytes());
198}
199
200pub fn append_query_pair(out: &mut String, first: &mut bool, key: &str, value: &str) {
201    if *first {
202        out.push('?');
203        *first = false;
204    } else {
205        out.push('&');
206    }
207    append_percent_encoded(out, key.as_bytes());
208    out.push('=');
209    append_percent_encoded(out, value.as_bytes());
210}
211
212pub fn format_offset_datetime(value: &OffsetDateTime) -> String {
213    value.format(&Rfc3339).unwrap_or_else(|_| value.to_string())
214}
215
216pub fn format_date(value: &Date) -> String {
217    format!(
218        "{:04}-{:02}-{:02}",
219        value.year(),
220        u8::from(value.month()),
221        value.day()
222    )
223}
224
225pub fn parse_date(value: &str) -> Result<Date, ParseDateError> {
226    let value = value.trim().as_bytes();
227    if value.len() != 10 || value[4] != b'-' || value[7] != b'-' {
228        return Err(ParseDateError::InvalidFormat);
229    }
230
231    for (index, byte) in value.iter().enumerate() {
232        if matches!(index, 4 | 7) {
233            if *byte != b'-' {
234                return Err(ParseDateError::InvalidFormat);
235            }
236        } else if !byte.is_ascii_digit() {
237            return Err(ParseDateError::InvalidFormat);
238        }
239    }
240
241    let year = parse_date_year(&value[0..4])?;
242    let month = parse_date_u8(&value[5..7])?;
243    let day = parse_date_u8(&value[8..10])?;
244    let month = Month::try_from(month).map_err(|_| ParseDateError::InvalidFormat)?;
245    Date::from_calendar_date(year, month, day).map_err(|_| ParseDateError::InvalidFormat)
246}
247
248fn parse_date_year(bytes: &[u8]) -> Result<i32, ParseDateError> {
249    let mut value = 0i32;
250    for byte in bytes {
251        value = value
252            .checked_mul(10)
253            .and_then(|value| value.checked_add(i32::from(*byte - b'0')))
254            .ok_or(ParseDateError::InvalidFormat)?;
255    }
256    Ok(value)
257}
258
259fn parse_date_u8(bytes: &[u8]) -> Result<u8, ParseDateError> {
260    let mut value = 0u16;
261    for byte in bytes {
262        value = value
263            .checked_mul(10)
264            .and_then(|value| value.checked_add(u16::from(*byte - b'0')))
265            .ok_or(ParseDateError::InvalidFormat)?;
266    }
267    u8::try_from(value).map_err(|_| ParseDateError::InvalidFormat)
268}
269
270pub fn format_naive_datetime(value: &PrimitiveDateTime) -> String {
271    format!(
272        "{}T{:02}:{:02}:{:02}",
273        format_date(&value.date()),
274        value.hour(),
275        value.minute(),
276        value.second()
277    )
278}
279
280pub fn parse_naive_datetime(value: &str) -> Result<PrimitiveDateTime, ParseNaiveDateTimeError> {
281    let value = value.trim();
282    let bytes = value.as_bytes();
283    if bytes.len() != 19
284        || bytes[4] != b'-'
285        || bytes[7] != b'-'
286        || bytes[10] != b'T'
287        || bytes[13] != b':'
288        || bytes[16] != b':'
289    {
290        return Err(ParseNaiveDateTimeError::InvalidFormat);
291    }
292
293    for (index, byte) in bytes.iter().enumerate() {
294        if matches!(index, 4 | 7 | 10 | 13 | 16) {
295            continue;
296        }
297        if !byte.is_ascii_digit() {
298            return Err(ParseNaiveDateTimeError::InvalidFormat);
299        }
300    }
301
302    let date = parse_date(&value[0..10]).map_err(|_| ParseNaiveDateTimeError::InvalidFormat)?;
303    let hour = parse_date_u8(&bytes[11..13]).map_err(|_| ParseNaiveDateTimeError::InvalidFormat)?;
304    let minute =
305        parse_date_u8(&bytes[14..16]).map_err(|_| ParseNaiveDateTimeError::InvalidFormat)?;
306    let second =
307        parse_date_u8(&bytes[17..19]).map_err(|_| ParseNaiveDateTimeError::InvalidFormat)?;
308    let time =
309        Time::from_hms(hour, minute, second).map_err(|_| ParseNaiveDateTimeError::InvalidFormat)?;
310    Ok(PrimitiveDateTime::new(date, time))
311}
312
313pub fn parse_time(value: &str) -> Result<Time, ParseTimeError> {
314    let value = value.trim();
315    let bytes = value.as_bytes();
316    if bytes.len() != 4 || !bytes.iter().all(u8::is_ascii_digit) {
317        return Err(ParseTimeError::InvalidFormat);
318    }
319
320    let hour = (bytes[0] - b'0') * 10 + (bytes[1] - b'0');
321    let minute = (bytes[2] - b'0') * 10 + (bytes[3] - b'0');
322    Time::from_hms(hour, minute, 0).map_err(|_| ParseTimeError::ComponentRange)
323}
324
325pub fn format_time(value: &Time) -> String {
326    format!("{:02}{:02}", value.hour(), value.minute())
327}
328
/// Encodes a boolean as `"1"` (true) or `"0"` (false).
pub fn format_bool(value: &bool) -> &'static str {
    match *value {
        true => "1",
        false => "0",
    }
}
332
333pub fn parse_range<T>(value: &str) -> Result<(Option<T>, Option<T>), ParseRangeError>
334where
335    T: FromStr,
336    T::Err: fmt::Display,
337{
338    let value = value.trim();
339    if value.is_empty() {
340        return Ok((None, None));
341    }
342
343    let (min, max) = match value.split_once('-') {
344        Some((min, max)) => {
345            if max.contains('-') {
346                return Err(ParseRangeError::TooManySeparators);
347            }
348            (min, max)
349        }
350        None => (value, value),
351    };
352
353    Ok((parse_range_min(min)?, parse_range_max(max)?))
354}
355
/// Formats a range as `min-max`, leaving an absent bound empty; two absent
/// bounds give an empty string.
pub fn format_range<T>(min: &Option<T>, max: &Option<T>) -> String
where
    T: fmt::Display,
{
    if min.is_none() && max.is_none() {
        return String::new();
    }
    let lower = min.as_ref().map(ToString::to_string).unwrap_or_default();
    let upper = max.as_ref().map(ToString::to_string).unwrap_or_default();
    format!("{lower}-{upper}")
}
367
368fn parse_range_min<T>(value: &str) -> Result<Option<T>, ParseRangeError>
369where
370    T: FromStr,
371    T::Err: fmt::Display,
372{
373    parse_range_bound(value, |value, message| ParseRangeError::InvalidMinimum {
374        value,
375        message,
376    })
377}
378
379fn parse_range_max<T>(value: &str) -> Result<Option<T>, ParseRangeError>
380where
381    T: FromStr,
382    T::Err: fmt::Display,
383{
384    parse_range_bound(value, |value, message| ParseRangeError::InvalidMaximum {
385        value,
386        message,
387    })
388}
389
390fn parse_range_bound<T>(
391    value: &str,
392    invalid: impl FnOnce(String, String) -> ParseRangeError,
393) -> Result<Option<T>, ParseRangeError>
394where
395    T: FromStr,
396    T::Err: fmt::Display,
397{
398    let value = value.trim();
399    if value.is_empty() {
400        return Ok(None);
401    }
402
403    value
404        .parse::<T>()
405        .map(Some)
406        .map_err(|err| invalid(value.to_owned(), err.to_string()))
407}
408
409#[cfg(feature = "serde")]
410pub mod serde_string {
411    use std::fmt;
412    use std::str::FromStr;
413
414    use serde::Deserialize;
415    use serde::de::Error as DeError;
416    use time::format_description::well_known::Rfc3339;
417
418    use crate::{Date, OffsetDateTime, PrimitiveDateTime, Time};
419
420    macro_rules! string_from_str_module {
421        ($module:ident, $ty:ty) => {
422            pub mod $module {
423                pub fn serialize<S>(value: &$ty, serializer: S) -> Result<S::Ok, S::Error>
424                where
425                    S: serde::Serializer,
426                {
427                    super::serialize_display(value, serializer)
428                }
429
430                pub fn deserialize<'de, D>(deserializer: D) -> Result<$ty, D::Error>
431                where
432                    D: serde::Deserializer<'de>,
433                {
434                    super::deserialize_from_str(deserializer)
435                }
436
437                pub mod option {
438                    pub fn serialize<S>(
439                        value: &Option<$ty>,
440                        serializer: S,
441                    ) -> Result<S::Ok, S::Error>
442                    where
443                        S: serde::Serializer,
444                    {
445                        super::super::serialize_option_display(value, serializer)
446                    }
447
448                    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<$ty>, D::Error>
449                    where
450                        D: serde::Deserializer<'de>,
451                    {
452                        super::super::deserialize_option_from_str(deserializer)
453                    }
454                }
455            }
456        };
457    }
458
459    macro_rules! string_float_module {
460        ($module:ident, $ty:ty) => {
461            pub mod $module {
462                use serde::Deserialize;
463                use serde::de::Error as DeError;
464
465                pub fn serialize<S>(value: &$ty, serializer: S) -> Result<S::Ok, S::Error>
466                where
467                    S: serde::Serializer,
468                {
469                    super::serialize_display(value, serializer)
470                }
471
472                pub fn deserialize<'de, D>(deserializer: D) -> Result<$ty, D::Error>
473                where
474                    D: serde::Deserializer<'de>,
475                {
476                    let value = <String as Deserialize>::deserialize(deserializer)?;
477                    fast_float2::parse::<$ty, _>(&value).map_err(DeError::custom)
478                }
479
480                pub mod option {
481                    use serde::Deserialize;
482                    use serde::de::Error as DeError;
483
484                    pub fn serialize<S>(
485                        value: &Option<$ty>,
486                        serializer: S,
487                    ) -> Result<S::Ok, S::Error>
488                    where
489                        S: serde::Serializer,
490                    {
491                        super::super::serialize_option_display(value, serializer)
492                    }
493
494                    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<$ty>, D::Error>
495                    where
496                        D: serde::Deserializer<'de>,
497                    {
498                        let value = <Option<String> as Deserialize>::deserialize(deserializer)?;
499                        value
500                            .map(|value| {
501                                fast_float2::parse::<$ty, _>(&value).map_err(DeError::custom)
502                            })
503                            .transpose()
504                    }
505                }
506            }
507        };
508    }
509
510    string_from_str_module!(as_u8, u8);
511    string_from_str_module!(as_u16, u16);
512    string_from_str_module!(as_u32, u32);
513    string_from_str_module!(as_u64, u64);
514    string_from_str_module!(as_i8, i8);
515    string_from_str_module!(as_i16, i16);
516    string_from_str_module!(as_i32, i32);
517    string_from_str_module!(as_i64, i64);
518    string_float_module!(as_f32, f32);
519    string_float_module!(as_f64, f64);
520
521    pub mod as_bool {
522        use std::fmt;
523
524        use serde::de::{Error as DeError, Visitor};
525
526        pub fn serialize<S>(value: &bool, serializer: S) -> Result<S::Ok, S::Error>
527        where
528            S: serde::Serializer,
529        {
530            serializer.serialize_str(crate::format_bool(value))
531        }
532
533        pub fn deserialize<'de, D>(deserializer: D) -> Result<bool, D::Error>
534        where
535            D: serde::Deserializer<'de>,
536        {
537            deserializer.deserialize_any(BoolVisitor)
538        }
539
540        struct BoolVisitor;
541
542        impl Visitor<'_> for BoolVisitor {
543            type Value = bool;
544
545            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
546                formatter.write_str("a boolean string or numeric boolean")
547            }
548
549            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> {
550                Ok(value)
551            }
552
553            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
554            where
555                E: DeError,
556            {
557                super::deserialize_bool(value).map_err(DeError::custom)
558            }
559
560            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
561            where
562                E: DeError,
563            {
564                self.visit_str(&value)
565            }
566
567            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
568            where
569                E: DeError,
570            {
571                match value {
572                    0 => Ok(false),
573                    1 => Ok(true),
574                    _ => Err(DeError::custom("invalid boolean number")),
575                }
576            }
577
578            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
579            where
580                E: DeError,
581            {
582                match value {
583                    0 => Ok(false),
584                    1 => Ok(true),
585                    _ => Err(DeError::custom("invalid boolean number")),
586                }
587            }
588        }
589
590        pub mod option {
591            use std::fmt;
592
593            use serde::de::Visitor;
594
595            pub fn serialize<S>(value: &Option<bool>, serializer: S) -> Result<S::Ok, S::Error>
596            where
597                S: serde::Serializer,
598            {
599                match value {
600                    Some(value) => super::serialize(value, serializer),
601                    None => serializer.serialize_none(),
602                }
603            }
604
605            pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
606            where
607                D: serde::Deserializer<'de>,
608            {
609                deserializer.deserialize_option(BoolOptionVisitor)
610            }
611
612            struct BoolOptionVisitor;
613
614            impl<'de> Visitor<'de> for BoolOptionVisitor {
615                type Value = Option<bool>;
616
617                fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
618                    formatter.write_str("an optional boolean string or numeric boolean")
619                }
620
621                fn visit_none<E>(self) -> Result<Self::Value, E> {
622                    Ok(None)
623                }
624
625                fn visit_unit<E>(self) -> Result<Self::Value, E> {
626                    Ok(None)
627                }
628
629                fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
630                where
631                    D: serde::Deserializer<'de>,
632                {
633                    super::deserialize(deserializer).map(Some)
634                }
635            }
636        }
637    }
638
639    pub mod as_date {
640        use serde::Deserialize;
641        use serde::de::Error as DeError;
642
643        use super::*;
644
645        pub fn serialize<S>(value: &Date, serializer: S) -> Result<S::Ok, S::Error>
646        where
647            S: serde::Serializer,
648        {
649            serializer.serialize_str(&crate::format_date(value))
650        }
651
652        pub fn deserialize<'de, D>(deserializer: D) -> Result<Date, D::Error>
653        where
654            D: serde::Deserializer<'de>,
655        {
656            let value = <String as Deserialize>::deserialize(deserializer)?;
657            crate::parse_date(&value).map_err(DeError::custom)
658        }
659
660        pub mod option {
661            use serde::Deserialize;
662            use serde::de::Error as DeError;
663
664            use super::*;
665
666            pub fn serialize<S>(value: &Option<Date>, serializer: S) -> Result<S::Ok, S::Error>
667            where
668                S: serde::Serializer,
669            {
670                match value {
671                    Some(value) => super::serialize(value, serializer),
672                    None => serializer.serialize_none(),
673                }
674            }
675
676            pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Date>, D::Error>
677            where
678                D: serde::Deserializer<'de>,
679            {
680                let value = <Option<String> as Deserialize>::deserialize(deserializer)?;
681                value
682                    .map(|value| crate::parse_date(&value).map_err(DeError::custom))
683                    .transpose()
684            }
685        }
686    }
687
688    pub mod as_naive_datetime {
689        use serde::Deserialize;
690        use serde::de::Error as DeError;
691
692        use super::*;
693
694        pub fn serialize<S>(value: &PrimitiveDateTime, serializer: S) -> Result<S::Ok, S::Error>
695        where
696            S: serde::Serializer,
697        {
698            serializer.serialize_str(&crate::format_naive_datetime(value))
699        }
700
701        pub fn deserialize<'de, D>(deserializer: D) -> Result<PrimitiveDateTime, D::Error>
702        where
703            D: serde::Deserializer<'de>,
704        {
705            let value = <String as Deserialize>::deserialize(deserializer)?;
706            crate::parse_naive_datetime(&value).map_err(DeError::custom)
707        }
708
709        pub mod option {
710            use serde::Deserialize;
711            use serde::de::Error as DeError;
712
713            use super::*;
714
715            pub fn serialize<S>(
716                value: &Option<PrimitiveDateTime>,
717                serializer: S,
718            ) -> Result<S::Ok, S::Error>
719            where
720                S: serde::Serializer,
721            {
722                match value {
723                    Some(value) => super::serialize(value, serializer),
724                    None => serializer.serialize_none(),
725                }
726            }
727
728            pub fn deserialize<'de, D>(
729                deserializer: D,
730            ) -> Result<Option<PrimitiveDateTime>, D::Error>
731            where
732                D: serde::Deserializer<'de>,
733            {
734                let value = <Option<String> as Deserialize>::deserialize(deserializer)?;
735                value
736                    .map(|value| crate::parse_naive_datetime(&value).map_err(DeError::custom))
737                    .transpose()
738            }
739        }
740    }
741
742    pub mod as_offset_datetime {
743        use serde::Deserialize;
744        use serde::de::Error as DeError;
745        use serde::ser::Error as SerError;
746
747        use super::*;
748
749        pub fn serialize<S>(value: &OffsetDateTime, serializer: S) -> Result<S::Ok, S::Error>
750        where
751            S: serde::Serializer,
752        {
753            let value = value.format(&Rfc3339).map_err(SerError::custom)?;
754            serializer.serialize_str(&value)
755        }
756
757        pub fn deserialize<'de, D>(deserializer: D) -> Result<OffsetDateTime, D::Error>
758        where
759            D: serde::Deserializer<'de>,
760        {
761            let value = <String as Deserialize>::deserialize(deserializer)?;
762            OffsetDateTime::parse(&value, &Rfc3339).map_err(DeError::custom)
763        }
764
765        pub mod option {
766            use serde::Deserialize;
767            use serde::de::Error as DeError;
768
769            use super::*;
770
771            pub fn serialize<S>(
772                value: &Option<OffsetDateTime>,
773                serializer: S,
774            ) -> Result<S::Ok, S::Error>
775            where
776                S: serde::Serializer,
777            {
778                match value {
779                    Some(value) => super::serialize(value, serializer),
780                    None => serializer.serialize_none(),
781                }
782            }
783
784            pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<OffsetDateTime>, D::Error>
785            where
786                D: serde::Deserializer<'de>,
787            {
788                let value = <Option<String> as Deserialize>::deserialize(deserializer)?;
789                value
790                    .map(|value| OffsetDateTime::parse(&value, &Rfc3339).map_err(DeError::custom))
791                    .transpose()
792            }
793        }
794    }
795
796    pub mod as_time {
797        use serde::Deserialize;
798        use serde::de::Error as DeError;
799
800        use super::*;
801
802        pub fn serialize<S>(value: &Time, serializer: S) -> Result<S::Ok, S::Error>
803        where
804            S: serde::Serializer,
805        {
806            serializer.serialize_str(&crate::format_time(value))
807        }
808
809        pub fn deserialize<'de, D>(deserializer: D) -> Result<Time, D::Error>
810        where
811            D: serde::Deserializer<'de>,
812        {
813            let value = <String as Deserialize>::deserialize(deserializer)?;
814            crate::parse_time(&value).map_err(DeError::custom)
815        }
816
817        pub mod option {
818            use serde::Deserialize;
819            use serde::de::Error as DeError;
820
821            use super::*;
822
823            pub fn serialize<S>(value: &Option<Time>, serializer: S) -> Result<S::Ok, S::Error>
824            where
825                S: serde::Serializer,
826            {
827                match value {
828                    Some(value) => super::serialize(value, serializer),
829                    None => serializer.serialize_none(),
830                }
831            }
832
833            pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Time>, D::Error>
834            where
835                D: serde::Deserializer<'de>,
836            {
837                let value = <Option<String> as Deserialize>::deserialize(deserializer)?;
838                let Some(value) = value else {
839                    return Ok(None);
840                };
841                let value = value.trim();
842                if value.is_empty() {
843                    return Ok(None);
844                }
845                crate::parse_time(value).map(Some).map_err(DeError::custom)
846            }
847        }
848    }
849
850    fn serialize_display<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
851    where
852        T: fmt::Display,
853        S: serde::Serializer,
854    {
855        serializer.serialize_str(&value.to_string())
856    }
857
858    fn serialize_option_display<T, S>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
859    where
860        T: fmt::Display,
861        S: serde::Serializer,
862    {
863        match value {
864            Some(value) => serialize_display(value, serializer),
865            None => serializer.serialize_none(),
866        }
867    }
868
869    fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
870    where
871        T: FromStr,
872        T::Err: fmt::Display,
873        D: serde::Deserializer<'de>,
874    {
875        let value = String::deserialize(deserializer)?;
876        value.parse::<T>().map_err(DeError::custom)
877    }
878
879    fn deserialize_option_from_str<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
880    where
881        T: FromStr,
882        T::Err: fmt::Display,
883        D: serde::Deserializer<'de>,
884    {
885        let value = Option::<String>::deserialize(deserializer)?;
886        value
887            .map(|value| value.parse::<T>().map_err(DeError::custom))
888            .transpose()
889    }
890
891    fn deserialize_bool(value: &str) -> Result<bool, &'static str> {
892        match value {
893            "1" => Ok(true),
894            "0" => Ok(false),
895            value if value.eq_ignore_ascii_case("true") => Ok(true),
896            value if value.eq_ignore_ascii_case("false") => Ok(false),
897            _ => Err("invalid boolean string"),
898        }
899    }
900}
901
/// Integer-encoded counterparts of the `serde_string` helpers.
#[cfg(feature = "serde")]
pub mod serde_integer {
    /// Serializes `bool` as the integer `1` or `0`. Deserializes with the same
    /// lenient rules as `crate::serde_string::as_bool`.
    pub mod as_bool {
        pub fn serialize<S>(value: &bool, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let encoded = u8::from(*value);
            serializer.serialize_u8(encoded)
        }

        pub fn deserialize<'de, D>(deserializer: D) -> Result<bool, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            crate::serde_string::as_bool::deserialize(deserializer)
        }

        /// `Option<bool>` variant; `None` uses the serializer's none value.
        pub mod option {
            pub fn serialize<S>(value: &Option<bool>, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let Some(flag) = value else {
                    return serializer.serialize_none();
                };
                super::serialize(flag, serializer)
            }

            pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                crate::serde_string::as_bool::option::deserialize(deserializer)
            }
        }
    }
}
943
/// `Option<T>` field helper that turns a value which cannot be decoded as `T`
/// into `None`.
///
/// The input is first buffered as a [`serde_json::Value`]. Errors while reading
/// that raw value are still propagated; only the failure to convert it into
/// `T` is swallowed. Serialization writes the inner value, or the serializer's
/// none value for `None`.
#[cfg(feature = "json")]
pub mod treat_error_as_none {
    use serde::de::DeserializeOwned;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    pub fn serialize<S, T>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: Serialize,
    {
        if let Some(inner) = value {
            inner.serialize(serializer)
        } else {
            serializer.serialize_none()
        }
    }

    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
    where
        D: Deserializer<'de>,
        T: DeserializeOwned,
    {
        let buffered = serde_json::Value::deserialize(deserializer)?;
        Ok(T::deserialize(buffered).ok())
    }
}
972
973pub fn insert_header(
974    headers: &mut http::HeaderMap,
975    name: &'static str,
976    value: &str,
977) -> Result<(), Error> {
978    headers.insert(
979        HeaderName::from_bytes(name.as_bytes())?,
980        HeaderValue::from_str(value)?,
981    );
982    Ok(())
983}
984
985pub fn has_json_content_type(headers: &http::HeaderMap) -> bool {
986    headers
987        .get(CONTENT_TYPE)
988        .and_then(|value| value.to_str().ok())
989        .is_some_and(is_json_media_type)
990}
991
992fn is_json_media_type(value: &str) -> bool {
993    let media_type = value.split(';').next().unwrap_or(value).trim();
994    if media_type.eq_ignore_ascii_case("application/json") {
995        return true;
996    }
997
998    let Some((_, subtype)) = media_type.rsplit_once('/') else {
999        return false;
1000    };
1001    ends_with_ignore_ascii_case(subtype, "+json")
1002}
1003
/// Reports whether `value` ends with `suffix`, comparing ASCII letters
/// case-insensitively.
///
/// The comparison works on raw bytes, so a non-ASCII `value` can never cause
/// a char-boundary panic.
fn ends_with_ignore_ascii_case(value: &str, suffix: &str) -> bool {
    let (haystack, needle) = (value.as_bytes(), suffix.as_bytes());
    // `checked_sub` yields `None` exactly when the suffix is longer than the value.
    match haystack.len().checked_sub(needle.len()) {
        Some(start) => haystack[start..].eq_ignore_ascii_case(needle),
        None => false,
    }
}
1009
1010fn append_percent_encoded(out: &mut String, bytes: &[u8]) {
1011    const HEX: &[u8; 16] = b"0123456789ABCDEF";
1012
1013    for &byte in bytes {
1014        if is_unreserved(byte) {
1015            out.push(byte as char);
1016        } else {
1017            out.push('%');
1018            out.push(HEX[(byte >> 4) as usize] as char);
1019            out.push(HEX[(byte & 0x0f) as usize] as char);
1020        }
1021    }
1022}
1023
/// Returns `true` for the RFC 3986 "unreserved" characters
/// (`ALPHA / DIGIT / "-" / "." / "_" / "~"`), which never need
/// percent-encoding.
const fn is_unreserved(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~')
}
1030
// Unit tests for the runtime helpers: URL encoding, range/date/time string
// codecs, serde adapters, and HTTP header helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encodes_path_segments() {
        let mut out = String::new();
        append_path_segment(&mut out, "a/b c");
        assert_eq!(out, "a%2Fb%20c");
    }

    #[test]
    fn percent_encodes_non_ascii_bytes() {
        // Each byte of a multi-byte UTF-8 sequence is escaped individually,
        // while unreserved ASCII such as `~` passes through untouched.
        let mut out = String::new();
        append_percent_encoded(&mut out, "a~b é".as_bytes());
        assert_eq!(out, "a~b%20%C3%A9");
    }

    #[test]
    fn appends_query_pairs() {
        let mut out = String::from("/pets");
        let mut first = true;
        append_query_pair(&mut out, &mut first, "tag name", "small/dog");
        append_query_pair(&mut out, &mut first, "limit", "10");
        assert_eq!(out, "/pets?tag%20name=small%2Fdog&limit=10");
    }

    #[test]
    fn parses_range_strings() {
        assert_eq!(parse_range::<u8>("14-17").unwrap(), (Some(14), Some(17)));
        assert_eq!(parse_range::<u8>("14-").unwrap(), (Some(14), None));
        assert_eq!(parse_range::<u8>("-17").unwrap(), (None, Some(17)));
        assert_eq!(parse_range::<u8>("").unwrap(), (None, None));
        assert!(matches!(
            parse_range::<u8>("14-17-20"),
            Err(ParseRangeError::TooManySeparators)
        ));
    }

    #[test]
    fn formats_range_strings() {
        assert_eq!(format_range(&Some(14), &Some(17)), "14-17");
        assert_eq!(format_range(&Some(14), &None::<u8>), "14-");
        assert_eq!(format_range(&None::<u8>, &Some(17)), "-17");
        assert_eq!(format_range(&None::<u8>, &None::<u8>), "");
    }

    #[test]
    fn parses_and_formats_time_strings() {
        let time = parse_time("0620").unwrap();
        assert_eq!(time.hour(), 6);
        assert_eq!(time.minute(), 20);
        assert_eq!(format_time(&time), "0620");
        assert_eq!(parse_time("6:20"), Err(ParseTimeError::InvalidFormat));
        assert_eq!(parse_time("2400"), Err(ParseTimeError::ComponentRange));
    }

    #[test]
    fn parses_and_formats_date_strings() {
        let date = parse_date("2024-07-16").unwrap();
        assert_eq!(
            date,
            Date::from_calendar_date(2024, Month::July, 16).unwrap()
        );
        assert_eq!(format_date(&date), "2024-07-16");
        assert_eq!(parse_date("2024-7-16"), Err(ParseDateError::InvalidFormat));
        assert_eq!(parse_date("not-a-date"), Err(ParseDateError::InvalidFormat));
    }

    #[test]
    fn parses_and_formats_naive_datetime_strings() {
        let datetime = parse_naive_datetime("2024-07-16T23:59:00").unwrap();
        assert_eq!(datetime.hour(), 23);
        assert_eq!(datetime.minute(), 59);
        assert_eq!(datetime.second(), 0);
        assert_eq!(format_naive_datetime(&datetime), "2024-07-16T23:59:00");
        assert_eq!(
            parse_naive_datetime("2024-07-16T23:59"),
            Err(ParseNaiveDateTimeError::InvalidFormat)
        );
        assert_eq!(
            parse_naive_datetime("2024-07-16 23:59:00"),
            Err(ParseNaiveDateTimeError::InvalidFormat)
        );
    }

    #[cfg(all(feature = "serde", feature = "json"))]
    #[test]
    fn serde_string_naive_datetime_round_trips() {
        #[derive(serde::Deserialize, serde::Serialize)]
        struct Value {
            #[serde(with = "crate::serde_string::as_naive_datetime")]
            at: PrimitiveDateTime,
        }

        let at = parse_naive_datetime("2024-07-16T23:59:00").unwrap();
        let encoded = serde_json::to_value(Value { at }).unwrap();
        assert_eq!(encoded, serde_json::json!({ "at": "2024-07-16T23:59:00" }));

        let decoded = serde_json::from_value::<Value>(encoded).unwrap();
        assert_eq!(decoded.at, at);
    }

    #[cfg(all(feature = "serde", feature = "json"))]
    #[test]
    fn serde_string_date_round_trips() {
        #[derive(serde::Deserialize, serde::Serialize)]
        struct Value {
            #[serde(with = "crate::serde_string::as_date")]
            day: Date,
        }

        let date = Date::from_calendar_date(2024, Month::July, 16).unwrap();
        let encoded = serde_json::to_value(Value { day: date }).unwrap();
        assert_eq!(encoded, serde_json::json!({ "day": "2024-07-16" }));

        let decoded = serde_json::from_value::<Value>(encoded).unwrap();
        assert_eq!(decoded.day, date);
    }

    #[test]
    fn recognizes_json_content_types() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            http::HeaderValue::from_static("application/problem+json; charset=utf-8"),
        );
        assert!(has_json_content_type(&headers));
    }

    #[test]
    fn recognizes_json_content_types_case_insensitively() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            http::HeaderValue::from_static("Application/JSON"),
        );
        assert!(has_json_content_type(&headers));
    }

    #[test]
    fn rejects_non_json_content_types() {
        let mut headers = http::HeaderMap::new();
        // A missing header is never JSON.
        assert!(!has_json_content_type(&headers));

        headers.insert(CONTENT_TYPE, http::HeaderValue::from_static("text/plain"));
        assert!(!has_json_content_type(&headers));

        // A prefix match on `json` is not enough; only an exact match or a
        // `+json` suffix counts.
        headers.insert(
            CONTENT_TYPE,
            http::HeaderValue::from_static("application/jsonp"),
        );
        assert!(!has_json_content_type(&headers));
    }

    #[test]
    fn insert_header_validates_name_and_value() {
        let mut headers = http::HeaderMap::new();
        insert_header(&mut headers, "x-request-id", "abc").unwrap();
        assert_eq!(headers.get("x-request-id").unwrap(), "abc");

        assert!(matches!(
            insert_header(&mut headers, "bad header", "v"),
            Err(Error::InvalidHeaderName(_))
        ));
        assert!(matches!(
            insert_header(&mut headers, "x-ok", "line\nbreak"),
            Err(Error::InvalidHeaderValue(_))
        ));
    }

    #[test]
    fn response_parts_holds_status_headers_body() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            CONTENT_TYPE,
            http::HeaderValue::from_static("application/json"),
        );
        let body = br#"{"ok":true}"#.to_vec();
        let parts = ResponseParts {
            status: http::StatusCode::OK,
            headers,
            body,
        };
        assert_eq!(parts.status, http::StatusCode::OK);
        assert_eq!(parts.headers.get(CONTENT_TYPE).unwrap(), "application/json");
        assert_eq!(parts.body, br#"{"ok":true}"#);
    }

    #[cfg(all(feature = "serde", feature = "json"))]
    #[test]
    fn serde_string_bool_accepts_string_and_numeric_values() {
        #[derive(serde::Deserialize, serde::Serialize)]
        struct Value {
            #[serde(with = "crate::serde_string::as_bool")]
            monitored: bool,
        }

        let numeric = serde_json::from_str::<Value>(r#"{"monitored":0}"#).unwrap();
        assert!(!numeric.monitored);

        let string = serde_json::from_str::<Value>(r#"{"monitored":"1"}"#).unwrap();
        assert!(string.monitored);

        let encoded = serde_json::to_value(Value { monitored: false }).unwrap();
        assert_eq!(encoded, serde_json::json!({ "monitored": "0" }));
    }

    #[cfg(all(feature = "serde", feature = "json"))]
    #[test]
    fn serde_integer_bool_accepts_numeric_values() {
        #[derive(serde::Deserialize, serde::Serialize)]
        struct Value {
            #[serde(with = "crate::serde_integer::as_bool")]
            monitored: bool,
        }

        let numeric = serde_json::from_str::<Value>(r#"{"monitored":0}"#).unwrap();
        assert!(!numeric.monitored);

        let encoded = serde_json::to_value(Value { monitored: true }).unwrap();
        assert_eq!(encoded, serde_json::json!({ "monitored": 1 }));
    }

    #[cfg(all(feature = "serde", feature = "json"))]
    #[test]
    fn treat_error_as_none_discards_unparseable_values() {
        #[derive(serde::Deserialize, serde::Serialize)]
        struct Value {
            #[serde(with = "crate::treat_error_as_none")]
            count: Option<u32>,
        }

        // A value of the wrong shape is swallowed instead of failing the document.
        let garbled = serde_json::from_str::<Value>(r#"{"count":"oops"}"#).unwrap();
        assert_eq!(garbled.count, None);

        let valid = serde_json::from_str::<Value>(r#"{"count":7}"#).unwrap();
        assert_eq!(valid.count, Some(7));

        let encoded = serde_json::to_value(Value { count: Some(7) }).unwrap();
        assert_eq!(encoded, serde_json::json!({ "count": 7 }));
    }
}