Skip to main content

structured_proxy/transcode/
request.rs

1//! Request message construction for REST→gRPC transcoding.
2//!
3//! Assembles the gRPC request JSON from three `google.api.http` sources, in
4//! precedence order: path parameters (highest), the request body, then query
5//! parameters (lowest, fill only). Path and query values arrive as strings, so
6//! they are coerced to each field's proto type before prost-reflect decodes the
7//! message.
8
9use prost_reflect::{FieldDescriptor, Kind, MessageDescriptor};
10use serde_json::{Map, Value};
11use std::collections::HashMap;
12
/// Describes which part of the gRPC request message the HTTP body populates.
///
/// This mirrors the `body` attribute of a `google.api.http` rule.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BodyMapping {
    /// The body is ignored; the message is built from path and query values
    /// only (the usual shape for GET/DELETE).
    None,
    /// The whole body is the message itself (`body: "*"`).
    Root,
    /// The body becomes the value of one named message field
    /// (`body: "field"`).
    Field(String),
}

impl BodyMapping {
    /// Interpret the raw `body` string of a `google.api.http` rule.
    ///
    /// - an empty string (attribute absent) yields [`BodyMapping::None`]
    /// - `"*"` yields [`BodyMapping::Root`]
    /// - every other string is taken verbatim as a field name and yields
    ///   [`BodyMapping::Field`]
    pub fn parse(raw: &str) -> Self {
        if raw.is_empty() {
            return Self::None;
        }
        if raw == "*" {
            return Self::Root;
        }
        Self::Field(raw.to_owned())
    }
}
37
38/// Build the request-message JSON from the body mapping, path params, and query.
39///
40/// Path-bound fields win over the body, and the body wins over query parameters
41/// (query only fills fields not already set). Unknown query keys are dropped
42/// rather than rejected, matching common transcoder behavior.
43///
44/// # Errors
45/// Returns an error string if `body` maps to the message root but the parsed
46/// body is not a JSON object.
47pub fn build_request_json(
48    input: &MessageDescriptor,
49    body_mapping: &BodyMapping,
50    body_json: Value,
51    path_params: &HashMap<String, String>,
52    query: &[(String, String)],
53) -> Result<Value, String> {
54    let mut root = match body_mapping {
55        BodyMapping::None => Value::Object(Map::new()),
56        BodyMapping::Root => match body_json {
57            Value::Object(_) => body_json,
58            Value::Null => Value::Object(Map::new()),
59            _ => return Err("request body must be a JSON object".to_string()),
60        },
61        BodyMapping::Field(field) => {
62            let mut m = Map::new();
63            m.insert(field.clone(), body_json);
64            Value::Object(m)
65        }
66    };
67
68    // Path params win over everything (the router already matched them).
69    for (key, raw) in path_params {
70        set_field(&mut root, input, key, true, |field| {
71            coerce(&field.kind(), raw)
72        });
73    }
74
75    // Query params: group repeated keys, fill only fields not already present.
76    for (key, values) in group_query(query) {
77        set_field(&mut root, input, &key, false, |field| {
78            if field.is_list() {
79                Value::Array(values.iter().map(|v| coerce(&field.kind(), v)).collect())
80            } else {
81                // A non-repeated field bound multiple times takes the last value.
82                coerce(&field.kind(), values.last().expect("group is non-empty"))
83            }
84        });
85    }
86
87    Ok(root)
88}
89
90/// Parse a raw query string into ordered key/value pairs.
91///
92/// `None` and the empty string yield no pairs. A non-empty string must be valid
93/// `application/x-www-form-urlencoded`.
94///
95/// # Errors
96/// Returns an error string when the query cannot be parsed, so the caller can
97/// reject the request rather than silently dropping every query-bound field.
98pub fn parse_query(raw: Option<&str>) -> Result<Vec<(String, String)>, String> {
99    match raw {
100        None | Some("") => Ok(Vec::new()),
101        Some(q) => serde_urlencoded::from_str(q).map_err(|e| format!("invalid query string: {e}")),
102    }
103}
104
105/// Extract a (possibly dotted) subfield of the response JSON for `response_body`.
106///
107/// Returns `None` when any path segment is missing, letting the caller
108/// distinguish a misconfigured path from a field that is legitimately null.
109pub fn extract_response_body(value: &Value, path: &str) -> Option<Value> {
110    let mut cur = value;
111    for seg in path.split('.') {
112        cur = cur.get(seg)?;
113    }
114    Some(cur.clone())
115}
116
/// Group query pairs by key so repeated keys (`?tag=a&tag=b`) collect into a
/// single entry for repeated-field binding.
///
/// Ordering is stable: groups appear in the order their key is first seen,
/// and the values inside a group keep their order of appearance. Every
/// returned group holds at least one value.
///
/// Keys are located through a hash index rather than a linear scan of the
/// groups built so far. The query string is caller-controlled, and a scan per
/// pair would make a request with many distinct keys cost quadratic time.
fn group_query(query: &[(String, String)]) -> Vec<(String, Vec<String>)> {
    let mut grouped: Vec<(String, Vec<String>)> = Vec::new();
    // Maps each key to the position of its group inside `grouped`.
    let mut slot_of: HashMap<&str, usize> = HashMap::with_capacity(query.len());

    for (key, value) in query {
        let slot = *slot_of.entry(key.as_str()).or_insert_with(|| {
            grouped.push((key.clone(), Vec::new()));
            grouped.len() - 1
        });
        grouped[slot].1.push(value.clone());
    }
    grouped
}
130
131/// Resolve a (possibly dotted) field path against the message descriptor and
132/// JSON tree, creating intermediate objects, then set the leaf via `make`.
133///
134/// `overwrite = false` leaves an already-present leaf untouched (query fill).
135/// Unknown fields or non-message intermediates are silently skipped.
136fn set_field<F>(root: &mut Value, input: &MessageDescriptor, dotted: &str, overwrite: bool, make: F)
137where
138    F: FnOnce(&FieldDescriptor) -> Value,
139{
140    let segments: Vec<&str> = dotted.split('.').collect();
141    let mut desc = input.clone();
142    let mut cur = root;
143
144    for seg in &segments[..segments.len() - 1] {
145        let Some(field) = desc.get_field_by_name(seg) else {
146            return;
147        };
148        let Kind::Message(message) = field.kind() else {
149            return;
150        };
151        desc = message;
152        let Some(obj) = cur.as_object_mut() else {
153            return;
154        };
155        cur = obj
156            .entry((*seg).to_string())
157            .or_insert_with(|| Value::Object(Map::new()));
158    }
159
160    let leaf = segments[segments.len() - 1];
161    let Some(field) = desc.get_field_by_name(leaf) else {
162        return;
163    };
164    let Some(obj) = cur.as_object_mut() else {
165        return;
166    };
167    if !overwrite && obj.contains_key(leaf) {
168        return;
169    }
170    obj.insert(leaf.to_string(), make(&field));
171}
172
173/// Coerce a path/query string to a JSON value matching the field's proto type,
174/// so prost-reflect's proto3-JSON decoder accepts it.
175///
176/// 32-bit integers and floats become JSON numbers; 64-bit integers stay strings
177/// (their canonical proto3-JSON form). Booleans parse to JSON booleans. Anything
178/// that fails to parse falls back to the raw string.
179fn coerce(kind: &Kind, raw: &str) -> Value {
180    match kind {
181        Kind::Bool => raw
182            .parse::<bool>()
183            .map(Value::Bool)
184            .unwrap_or_else(|_| Value::String(raw.to_string())),
185        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => raw
186            .parse::<i32>()
187            .map(|n| Value::Number(n.into()))
188            .unwrap_or_else(|_| Value::String(raw.to_string())),
189        Kind::Uint32 | Kind::Fixed32 => raw
190            .parse::<u32>()
191            .map(|n| Value::Number(n.into()))
192            .unwrap_or_else(|_| Value::String(raw.to_string())),
193        Kind::Double | Kind::Float => raw
194            .parse::<f64>()
195            .ok()
196            .and_then(serde_json::Number::from_f64)
197            .map(Value::Number)
198            .unwrap_or_else(|| Value::String(raw.to_string())),
199        // 64-bit ints (canonical proto3 JSON is a string), strings, bytes, enums
200        // (name), and anything else pass through as a string.
201        _ => Value::String(raw.to_string()),
202    }
203}
204
#[cfg(test)]
mod tests {
    //! Unit tests for request assembly. Descriptors are built in memory from
    //! `FileDescriptorSet` protos, so no generated code or `.proto` files are
    //! needed.

    use super::*;
    use prost_reflect::prost::Message;
    use prost_reflect::prost_types::{
        field_descriptor_proto::{Label, Type},
        DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
    };
    use prost_reflect::DescriptorPool;

    /// Shorthand for one field of a hand-written descriptor. `type_name` is
    /// only needed for message-typed fields and must be fully qualified.
    fn field(
        name: &str,
        num: i32,
        ty: Type,
        label: Label,
        type_name: Option<&str>,
    ) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.to_string()),
            number: Some(num),
            label: Some(label as i32),
            r#type: Some(ty as i32),
            type_name: type_name.map(|s| s.to_string()),
            ..Default::default()
        }
    }

    /// Build a small descriptor pool with a typed message for coercion tests.
    ///
    /// `test.TestMsg` has a string, an int32, a bool, a repeated string, an
    /// int64 and a nested message (`Nested { city: string }`).
    fn test_msg() -> MessageDescriptor {
        let nested = DescriptorProto {
            name: Some("Nested".to_string()),
            field: vec![field("city", 1, Type::String, Label::Optional, None)],
            ..Default::default()
        };
        let msg = DescriptorProto {
            name: Some("TestMsg".to_string()),
            field: vec![
                field("name", 1, Type::String, Label::Optional, None),
                field("age", 2, Type::Int32, Label::Optional, None),
                field("active", 3, Type::Bool, Label::Optional, None),
                field("tags", 4, Type::String, Label::Repeated, None),
                field("count", 5, Type::Int64, Label::Optional, None),
                field(
                    "nested",
                    6,
                    Type::Message,
                    Label::Optional,
                    Some(".test.TestMsg.Nested"),
                ),
            ],
            nested_type: vec![nested],
            ..Default::default()
        };
        let file = FileDescriptorProto {
            name: Some("test.proto".to_string()),
            package: Some("test".to_string()),
            message_type: vec![msg],
            syntax: Some("proto3".to_string()),
            ..Default::default()
        };
        let fds = FileDescriptorSet { file: vec![file] };
        // Round-trip through the wire encoding, which is what the pool decodes.
        let pool = DescriptorPool::decode(fds.encode_to_vec().as_slice()).unwrap();
        pool.get_message_by_name("test.TestMsg").unwrap()
    }

    /// Path params fixture: owned map from borrowed pairs.
    fn pp(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// Query fixture: owned, ordered pairs from borrowed pairs.
    fn qq(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn coerce_unsigned_32_rejects_out_of_range() {
        // u32 fields must not accept negatives or values above u32::MAX; those
        // fall back to a raw string so prost-reflect rejects them precisely.
        assert_eq!(coerce(&Kind::Uint32, "-1"), Value::String("-1".into()));
        assert_eq!(
            coerce(&Kind::Uint32, "4294967296"),
            Value::String("4294967296".into())
        );
        assert_eq!(coerce(&Kind::Uint32, "42"), Value::Number(42.into()));
        assert_eq!(coerce(&Kind::Fixed32, "-1"), Value::String("-1".into()));
        // Signed 32-bit still accepts negatives.
        assert_eq!(coerce(&Kind::Int32, "-5"), Value::Number((-5).into()));
        // And rejects values outside i32 range.
        assert_eq!(
            coerce(&Kind::Int32, "2147483648"),
            Value::String("2147483648".into())
        );
    }

    #[test]
    fn body_mapping_parse() {
        assert_eq!(BodyMapping::parse(""), BodyMapping::None);
        assert_eq!(BodyMapping::parse("*"), BodyMapping::Root);
        assert_eq!(
            BodyMapping::parse("resource"),
            BodyMapping::Field("resource".into())
        );
    }

    /// All three sources contribute distinct fields to one message.
    #[test]
    fn body_root_merges_path_and_query() {
        let m = test_msg();
        let body = serde_json::json!({ "name": "alice" });
        let out = build_request_json(
            &m,
            &BodyMapping::Root,
            body,
            &pp(&[("age", "30")]),
            &qq(&[("active", "true")]),
        )
        .unwrap();
        assert_eq!(out["name"], "alice");
        assert_eq!(out["age"], 30); // Int32 coerced to a JSON number
        assert_eq!(out["active"], true); // Bool coerced
    }

    /// `body: "nested"` wraps the body under that field; query still fills
    /// sibling fields at the root.
    #[test]
    fn body_field_nests_body_under_named_field() {
        let m = test_msg();
        let body = serde_json::json!({ "city": "berlin" });
        let out = build_request_json(
            &m,
            &BodyMapping::Field("nested".into()),
            body,
            &pp(&[]),
            &qq(&[("name", "bob")]),
        )
        .unwrap();
        assert_eq!(out["nested"]["city"], "berlin");
        assert_eq!(out["name"], "bob");
    }

    #[test]
    fn query_repeated_field_becomes_array() {
        let m = test_msg();
        let out = build_request_json(
            &m,
            &BodyMapping::None,
            Value::Null,
            &pp(&[]),
            &qq(&[("tags", "a"), ("tags", "b")]),
        )
        .unwrap();
        assert_eq!(out["tags"], serde_json::json!(["a", "b"]));
    }

    #[test]
    fn query_dotted_path_sets_nested_field() {
        let m = test_msg();
        let out = build_request_json(
            &m,
            &BodyMapping::None,
            Value::Null,
            &pp(&[]),
            &qq(&[("nested.city", "paris")]),
        )
        .unwrap();
        assert_eq!(out["nested"]["city"], "paris");
    }

    /// Precedence check: query is fill-only against both body and path.
    #[test]
    fn query_does_not_override_body_or_path() {
        let m = test_msg();
        let body = serde_json::json!({ "name": "from_body" });
        let out = build_request_json(
            &m,
            &BodyMapping::Root,
            body,
            &pp(&[("age", "7")]),
            &qq(&[("name", "from_query"), ("age", "99")]),
        )
        .unwrap();
        assert_eq!(out["name"], "from_body"); // body wins over query
        assert_eq!(out["age"], 7); // path wins over query
    }

    #[test]
    fn int64_field_stays_string() {
        let m = test_msg();
        // 2^53 + 1: not exactly representable as an f64.
        let out = build_request_json(
            &m,
            &BodyMapping::None,
            Value::Null,
            &pp(&[]),
            &qq(&[("count", "9007199254740993")]),
        )
        .unwrap();
        // 64-bit ints serialize as JSON strings in canonical proto3 JSON.
        assert_eq!(out["count"], "9007199254740993");
    }

    #[test]
    fn unknown_query_field_is_dropped() {
        let m = test_msg();
        let out = build_request_json(
            &m,
            &BodyMapping::None,
            Value::Null,
            &pp(&[]),
            &qq(&[("does_not_exist", "x")]),
        )
        .unwrap();
        assert_eq!(out.get("does_not_exist"), None);
    }

    #[test]
    fn root_body_must_be_object() {
        let m = test_msg();
        let err = build_request_json(
            &m,
            &BodyMapping::Root,
            serde_json::json!("a string"),
            &pp(&[]),
            &qq(&[]),
        );
        assert!(err.is_err());
    }

    #[test]
    fn extract_response_body_walks_dotted_path() {
        let v = serde_json::json!({ "result": { "token": "abc" } });
        assert_eq!(
            extract_response_body(&v, "result.token"),
            Some(serde_json::json!("abc"))
        );
        assert_eq!(
            extract_response_body(&v, "result"),
            Some(serde_json::json!({ "token": "abc" }))
        );
        // A missing path is None (caller can warn), distinct from a null field.
        assert_eq!(extract_response_body(&v, "missing"), None);
    }

    #[test]
    fn parse_query_handles_empty_and_pairs() {
        assert_eq!(parse_query(None).unwrap(), Vec::<(String, String)>::new());
        assert_eq!(
            parse_query(Some("")).unwrap(),
            Vec::<(String, String)>::new()
        );
        assert_eq!(
            parse_query(Some("a=1&b=2")).unwrap(),
            vec![("a".into(), "1".into()), ("b".into(), "2".into())]
        );
    }
}