Skip to main content

stoat_core/
transform.rs

//! Request transformation logic.
//!
//! Pure functions for transforming proxy requests according to the configured
//! translation rules:
//!
//! - Header stripping (case-insensitive)
//! - Header setting with `{access_token}` template resolution
//! - Query parameter appending
//! - Upstream URL construction
10
11use std::collections::HashMap;
12
13use url::Url;
14
/// Substitute the caller's access token into a header-value template.
///
/// Every occurrence of the literal placeholder `{access_token}` is replaced
/// with `access_token`; all other text, including braces that do not form that
/// exact placeholder, is copied through unchanged. No other template variables
/// are recognised.
#[must_use]
// `{access_token}` is a template placeholder, not a `format!` argument.
#[allow(clippy::literal_string_with_formatting_args)]
pub fn resolve_template(template: &str, access_token: &str) -> String {
    const PLACEHOLDER: &str = "{access_token}";

    // Splitting on the placeholder and re-joining with the token gives the same
    // non-overlapping, left-to-right substitution as `str::replace`.
    template
        .split(PLACEHOLDER)
        .collect::<Vec<&str>>()
        .join(access_token)
}
23
/// Report whether `header_name` appears in `strip_headers`.
///
/// The comparison ignores ASCII case only. An empty `strip_headers` list never
/// matches.
#[must_use]
pub fn should_strip_header(header_name: &str, strip_headers: &[String]) -> bool {
    for candidate in strip_headers {
        if header_name.eq_ignore_ascii_case(candidate) {
            return true;
        }
    }
    false
}
31
32/// Resolve all configured set-headers, replacing `{access_token}` in values.
33///
34/// Returns a list of (header-name, resolved-value) pairs.
35#[must_use]
36#[allow(clippy::implicit_hasher)]
37pub fn resolve_set_headers(
38    set_headers: &HashMap<String, String>,
39    access_token: &str,
40) -> Vec<(String, String)> {
41    set_headers
42        .iter()
43        .map(|(name, template)| (name.clone(), resolve_template(template, access_token)))
44        .collect()
45}
46
47/// Build the upstream URL from the base URL and request path/query.
48///
49/// The request path is appended to the base URL path. The incoming query
50/// string is preserved, and any configured extra query parameters are
51/// appended with proper percent-encoding.
52#[must_use]
53#[allow(clippy::implicit_hasher)]
54pub fn build_upstream_url(
55    base_url: &Url,
56    request_path: &str,
57    request_query: Option<&str>,
58    extra_query_params: Option<&HashMap<String, String>>,
59) -> Url {
60    let mut url = base_url.clone();
61
62    // Combine the base URL path with the request path.
63    let base_path = url.path().trim_end_matches('/');
64    let req_path = request_path.trim_start_matches('/');
65    let combined = if req_path.is_empty() {
66        base_path.to_owned()
67    } else {
68        format!("{base_path}/{req_path}")
69    };
70    url.set_path(if combined.is_empty() { "/" } else { &combined });
71
72    // Start with the incoming request's raw query string.
73    url.set_query(request_query.filter(|q| !q.is_empty()));
74
75    // Append extra query parameters from the translation config.
76    if let Some(params) = extra_query_params.filter(|p| !p.is_empty()) {
77        let mut pairs = url.query_pairs_mut();
78        for (key, value) in params {
79            pairs.append_pair(key, value);
80        }
81    }
82
83    url
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    // --- resolve_template ---

    #[test]
    fn resolve_template_replaces_access_token() {
        assert_eq!(
            resolve_template("Bearer {access_token}", "tok123"),
            "Bearer tok123",
        );
    }

    #[test]
    fn resolve_template_no_variable() {
        assert_eq!(resolve_template("static-value", "tok"), "static-value");
    }

    #[test]
    fn resolve_template_multiple_occurrences() {
        assert_eq!(
            resolve_template("{access_token}:{access_token}", "abc"),
            "abc:abc",
        );
    }

    #[test]
    fn resolve_template_empty_token() {
        assert_eq!(resolve_template("Bearer {access_token}", ""), "Bearer ");
    }

    #[test]
    fn resolve_template_unclosed_placeholder_untouched() {
        // Only the exact `{access_token}` placeholder is substituted.
        assert_eq!(
            resolve_template("Bearer {access_token", "tok"),
            "Bearer {access_token",
        );
    }

    // --- should_strip_header ---

    #[test]
    fn strip_header_case_insensitive() {
        let strip = vec!["X-Api-Key".to_owned()];
        assert!(should_strip_header("x-api-key", &strip));
        assert!(should_strip_header("X-API-KEY", &strip));
        assert!(should_strip_header("X-Api-Key", &strip));
    }

    #[test]
    fn strip_header_no_match() {
        let strip = vec!["X-Api-Key".to_owned()];
        assert!(!should_strip_header("Authorization", &strip));
    }

    #[test]
    fn strip_header_empty_list() {
        assert!(!should_strip_header("x-api-key", &[]));
    }

    #[test]
    fn strip_header_multiple_entries() {
        let strip = vec!["X-Api-Key".to_owned(), "X-Custom".to_owned()];
        assert!(should_strip_header("x-api-key", &strip));
        assert!(should_strip_header("x-custom", &strip));
        assert!(!should_strip_header("authorization", &strip));
    }

    // --- resolve_set_headers ---

    #[test]
    fn resolve_set_headers_applies_template() {
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_owned(),
            "Bearer {access_token}".to_owned(),
        );
        headers.insert("X-Custom".to_owned(), "static".to_owned());

        let resolved = resolve_set_headers(&headers, "my-token");
        let resolved_map: HashMap<_, _> = resolved.into_iter().collect();

        assert_eq!(resolved_map["Authorization"], "Bearer my-token");
        assert_eq!(resolved_map["X-Custom"], "static");
    }

    #[test]
    fn resolve_set_headers_empty() {
        let headers = HashMap::new();
        let resolved = resolve_set_headers(&headers, "tok");
        assert!(resolved.is_empty());
    }

    // --- build_upstream_url ---

    #[test]
    fn upstream_url_simple_path() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/v1/chat", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
    }

    #[test]
    fn upstream_url_base_with_path() {
        let base = Url::parse("https://api.example.com/api").unwrap();
        let url = build_upstream_url(&base, "/v1/chat", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/api/v1/chat");
    }

    #[test]
    fn upstream_url_preserves_query() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/search", Some("q=hello+world"), None);
        assert_eq!(url.as_str(), "https://api.example.com/search?q=hello+world");
    }

    #[test]
    fn upstream_url_preserves_encoded_query() {
        // The incoming query is copied verbatim, not re-encoded.
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/search", Some("q=a%20b"), None);
        assert_eq!(url.as_str(), "https://api.example.com/search?q=a%20b");
    }

    #[test]
    fn upstream_url_appends_extra_params() {
        let base = Url::parse("https://api.example.com").unwrap();
        let mut extra = HashMap::new();
        extra.insert("beta".to_owned(), "true".to_owned());
        let url = build_upstream_url(&base, "/v1/chat", None, Some(&extra));
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat?beta=true");
    }

    #[test]
    fn upstream_url_merges_query_and_extra() {
        let base = Url::parse("https://api.example.com").unwrap();
        let mut extra = HashMap::new();
        extra.insert("beta".to_owned(), "true".to_owned());
        let url = build_upstream_url(&base, "/v1/chat", Some("model=gpt4"), Some(&extra));
        // With a single extra parameter the result is fully determined: the
        // incoming query comes first, then the appended pair after `&`.
        assert_eq!(
            url.as_str(),
            "https://api.example.com/v1/chat?model=gpt4&beta=true",
        );
    }

    #[test]
    fn upstream_url_root_path() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/");
    }

    #[test]
    fn upstream_url_root_request_maps_to_base_path() {
        let base = Url::parse("https://api.example.com/api").unwrap();
        let url = build_upstream_url(&base, "/", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/api");
    }

    #[test]
    fn upstream_url_empty_path() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/");
    }

    #[test]
    fn upstream_url_trailing_slash_base() {
        let base = Url::parse("https://api.example.com/api/").unwrap();
        let url = build_upstream_url(&base, "/v1/chat", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/api/v1/chat");
    }

    #[test]
    fn upstream_url_collapses_leading_slashes() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "//v1/chat", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
    }

    #[test]
    fn upstream_url_preserves_trailing_slash() {
        let base = Url::parse("https://api.example.com/api").unwrap();
        let url = build_upstream_url(&base, "/v1/models/", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/api/v1/models/");
    }

    #[test]
    fn upstream_url_empty_query_ignored() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/v1/chat", Some(""), None);
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
    }

    #[test]
    fn upstream_url_empty_extra_params_ignored() {
        let base = Url::parse("https://api.example.com").unwrap();
        let extra = HashMap::new();
        let url = build_upstream_url(&base, "/v1/chat", None, Some(&extra));
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat");
    }

    #[test]
    fn upstream_url_deep_path() {
        let base = Url::parse("https://api.example.com/v1").unwrap();
        let url = build_upstream_url(&base, "/a/b/c/d", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/v1/a/b/c/d");
    }

    #[test]
    fn upstream_url_preserves_encoded_path() {
        let base = Url::parse("https://api.example.com").unwrap();
        let url = build_upstream_url(&base, "/path%20with%20spaces", None, None);
        assert_eq!(url.as_str(), "https://api.example.com/path%20with%20spaces");
    }

    #[test]
    fn upstream_url_extra_params_encoded() {
        let base = Url::parse("https://api.example.com").unwrap();
        let mut extra = HashMap::new();
        extra.insert("name".to_owned(), "hello world".to_owned());
        let url = build_upstream_url(&base, "/v1/chat", None, Some(&extra));
        assert_eq!(
            url.as_str(),
            "https://api.example.com/v1/chat?name=hello+world",
        );
    }

    #[test]
    fn upstream_url_extra_params_reserved_chars_encoded() {
        // `&` and `=` in a value must not split or terminate the pair.
        let base = Url::parse("https://api.example.com").unwrap();
        let mut extra = HashMap::new();
        extra.insert("q".to_owned(), "a&b=c".to_owned());
        let url = build_upstream_url(&base, "/v1/chat", None, Some(&extra));
        assert_eq!(url.as_str(), "https://api.example.com/v1/chat?q=a%26b%3Dc");
    }
}