Skip to main content

rustauth_core/api/
body.rs

1//! Request body parsing helpers for framework-neutral auth endpoints.
2
3use http::header;
4use serde::de::DeserializeOwned;
5use serde_json::{Map, Value};
6
7use super::ApiRequest;
8use crate::error::RustAuthError;
9
10/// Parse a request body as JSON or `application/x-www-form-urlencoded`.
11pub fn parse_request_body<T>(request: &ApiRequest) -> Result<T, RustAuthError>
12where
13    T: DeserializeOwned,
14{
15    match request_content_type(request) {
16        Some("application/json") => parse_json_body(request.body()),
17        Some("application/x-www-form-urlencoded") => parse_form_body(request.body()),
18        Some(content_type) => Err(RustAuthError::UnsupportedContentType {
19            content_type: content_type.to_owned(),
20        }),
21        None => Err(RustAuthError::MissingContentType),
22    }
23}
24
25fn parse_json_body<T>(body: &[u8]) -> Result<T, RustAuthError>
26where
27    T: DeserializeOwned,
28{
29    serde_json::from_slice(body).map_err(|error| RustAuthError::InvalidRequestBody {
30        encoding: "JSON",
31        message: error.to_string(),
32    })
33}
34
35fn parse_form_body<T>(body: &[u8]) -> Result<T, RustAuthError>
36where
37    T: DeserializeOwned,
38{
39    let body = std::str::from_utf8(body).map_err(|error| RustAuthError::InvalidRequestBody {
40        encoding: "form",
41        message: error.to_string(),
42    })?;
43    let mut map = Map::new();
44
45    if !body.is_empty() {
46        for pair in body.split('&') {
47            let (name, value) = pair.split_once('=').unwrap_or((pair, ""));
48            let name =
49                decode_form_component(name).map_err(|error| RustAuthError::InvalidRequestBody {
50                    encoding: "form",
51                    message: error.to_owned(),
52                })?;
53            let value = decode_form_component(value).map_err(|error| {
54                RustAuthError::InvalidRequestBody {
55                    encoding: "form",
56                    message: error.to_owned(),
57                }
58            })?;
59            insert_form_value(&mut map, name, form_value(value));
60        }
61    }
62
63    serde_json::from_value(Value::Object(map)).map_err(|error| RustAuthError::InvalidRequestBody {
64        encoding: "form",
65        message: error.to_string(),
66    })
67}
68
69fn request_content_type(request: &ApiRequest) -> Option<&str> {
70    let content_type = request.headers().get(header::CONTENT_TYPE)?.to_str().ok()?;
71    let media_type = content_type
72        .split(';')
73        .next()
74        .unwrap_or(content_type)
75        .trim();
76    media_type
77        .eq_ignore_ascii_case("application/json")
78        .then_some("application/json")
79        .or_else(|| {
80            media_type
81                .eq_ignore_ascii_case("application/x-www-form-urlencoded")
82                .then_some("application/x-www-form-urlencoded")
83        })
84        .or(Some(media_type))
85}
86
87fn form_value(value: String) -> Value {
88    match value.as_str() {
89        "true" => Value::Bool(true),
90        "false" => Value::Bool(false),
91        _ => Value::String(value),
92    }
93}
94
95fn insert_form_value(map: &mut Map<String, Value>, name: String, value: Value) {
96    match map.get_mut(&name) {
97        Some(Value::Array(values)) => values.push(value),
98        Some(existing) => {
99            let previous = std::mem::replace(existing, Value::Null);
100            *existing = Value::Array(vec![previous, value]);
101        }
102        None => {
103            map.insert(name, value);
104        }
105    }
106}
107
108fn decode_form_component(value: &str) -> Result<String, &'static str> {
109    let mut decoded = Vec::with_capacity(value.len());
110    let bytes = value.as_bytes();
111    let mut index = 0;
112
113    while index < bytes.len() {
114        match bytes[index] {
115            b'+' => {
116                decoded.push(b' ');
117                index += 1;
118            }
119            b'%' => {
120                if index + 2 >= bytes.len() {
121                    return Err("incomplete percent escape");
122                }
123                let high = hex_value(bytes[index + 1]).ok_or("invalid percent escape")?;
124                let low = hex_value(bytes[index + 2]).ok_or("invalid percent escape")?;
125                decoded.push((high << 4) | low);
126                index += 3;
127            }
128            byte => {
129                decoded.push(byte);
130                index += 1;
131            }
132        }
133    }
134
135    String::from_utf8(decoded).map_err(|_| "decoded form value is not valid UTF-8")
136}
137
/// Return the numeric value (0–15) of an ASCII hexadecimal digit, accepting
/// both lowercase and uppercase letters, or `None` for any other byte.
fn hex_value(byte: u8) -> Option<u8> {
    // A base-16 digit is always below 16, so narrowing to `u8` cannot truncate.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}