Skip to main content

sgr_agent/openapi/
caller.rs

//! HTTP caller — execute API requests from endpoint definitions.
2
3use super::spec::{Endpoint, ParamLocation};
4use std::collections::HashMap;
5
/// Auth configuration for API calls.
///
/// `Debug` is implemented by hand so that secrets (tokens, passwords, custom header
/// values) are redacted; logging an `ApiAuth` with `{:?}` never leaks credentials.
#[derive(Clone)]
pub enum ApiAuth {
    /// No authentication is applied.
    None,
    /// Bearer token, sent as `Authorization: Bearer <token>`.
    Bearer(String),
    /// HTTP Basic credentials in `"user:pass"` form.
    Basic(String),
    /// A custom header: header name + header value.
    Header(String, String),
}

impl std::fmt::Debug for ApiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            ApiAuth::None => f.write_str("None"),
            ApiAuth::Bearer(_) => f.debug_tuple("Bearer").field(&REDACTED).finish(),
            ApiAuth::Basic(_) => f.debug_tuple("Basic").field(&REDACTED).finish(),
            // The header name is not secret and is useful when debugging; the value is.
            ApiAuth::Header(name, _) => f
                .debug_tuple("Header")
                .field(name)
                .field(&REDACTED)
                .finish(),
        }
    }
}
14
15/// Build the full URL from an endpoint, base URL, and parameter values.
16///
17/// Path params are substituted in the URL template.
18/// Query params are appended as `?key=value&...`.
19pub fn build_url(
20    base_url: &str,
21    endpoint: &Endpoint,
22    params: &HashMap<String, String>,
23) -> Result<String, String> {
24    // Check required params
25    for p in &endpoint.params {
26        if p.required && !params.contains_key(&p.name) {
27            return Err(format!("Missing required parameter: {}", p.name));
28        }
29    }
30
31    // Substitute path params
32    let mut path = endpoint.path.clone();
33    for p in &endpoint.params {
34        if p.location == ParamLocation::Path {
35            if let Some(value) = params.get(&p.name) {
36                let token = format!("{{{}}}", p.name);
37                path = path.replace(&token, value);
38            }
39        }
40    }
41
42    let base = base_url.trim_end_matches('/');
43    let mut url = format!("{}{}", base, path);
44
45    // Append query params
46    let query_parts: Vec<String> = endpoint
47        .params
48        .iter()
49        .filter(|p| p.location == ParamLocation::Query)
50        .filter_map(|p| {
51            params
52                .get(&p.name)
53                .map(|v| format!("{}={}", p.name, urlencod(v)))
54        })
55        .collect();
56
57    if !query_parts.is_empty() {
58        url.push('?');
59        url.push_str(&query_parts.join("&"));
60    }
61
62    Ok(url)
63}
64
/// Percent-encode `s` for use inside a URL component, per RFC 3986.
///
/// Only the unreserved characters (`A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`) pass through
/// unchanged; every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn urlencod(s: &str) -> String {
    use std::fmt::Write;

    s.bytes().fold(
        String::with_capacity(s.len() + s.len() / 2),
        |mut out, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~');
            if unreserved {
                out.push(char::from(byte));
            } else {
                // Formatting into a `String` cannot fail.
                let _ = write!(out, "%{:02X}", byte);
            }
            out
        },
    )
}
81
82/// Execute an API call. Returns the response body as string.
83pub async fn call_api(
84    base_url: &str,
85    endpoint: &Endpoint,
86    params: &HashMap<String, String>,
87    body: Option<&serde_json::Value>,
88    auth: &ApiAuth,
89) -> Result<String, String> {
90    let url = build_url(base_url, endpoint, params)?;
91
92    let client = reqwest::Client::builder()
93        .user_agent("rust-code/1.0")
94        .build()
95        .map_err(|e| format!("Failed to build HTTP client: {}", e))?;
96    let mut req = match endpoint.method.as_str() {
97        "GET" => client.get(&url),
98        "POST" => client.post(&url),
99        "PUT" => client.put(&url),
100        "DELETE" => client.delete(&url),
101        "PATCH" => client.patch(&url),
102        "HEAD" => client.head(&url),
103        other => return Err(format!("Unsupported method: {}", other)),
104    };
105
106    // Auth
107    match auth {
108        ApiAuth::None => {}
109        ApiAuth::Bearer(token) => {
110            req = req.header("Authorization", format!("Bearer {}", token));
111        }
112        ApiAuth::Basic(credentials) => {
113            let encoded = simple_base64(credentials.as_bytes());
114            req = req.header("Authorization", format!("Basic {}", encoded));
115        }
116        ApiAuth::Header(name, value) => {
117            req = req.header(name, value);
118        }
119    }
120
121    // Body
122    if let Some(body_val) = body {
123        req = req
124            .header("Content-Type", "application/json")
125            .json(body_val);
126    }
127
128    let response = req.send().await.map_err(|e| format!("HTTP error: {}", e))?;
129    let status = response.status();
130    let text = response
131        .text()
132        .await
133        .map_err(|e| format!("Read error: {}", e))?;
134
135    if status.is_success() {
136        Ok(text)
137    } else {
138        Err(format!("HTTP {} — {}", status, truncate(&text, 500)))
139    }
140}
141
/// Encode `data` as standard base64 (`A-Za-z0-9+/` alphabet, `=` padding).
///
/// Kept local so that Basic auth headers need no extra dependency.
fn simple_base64(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Pick the 6-bit group of `packed` starting at bit `shift` and map it to its symbol.
    let sextet = |packed: u32, shift: u32| char::from(ALPHABET[((packed >> shift) & 0x3F) as usize]);

    let mut out = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the low 24 bits, zero-filling any missing bytes.
        let packed = group
            .iter()
            .chain(std::iter::repeat(&0u8))
            .take(3)
            .fold(0u32, |acc, &b| (acc << 8) | u32::from(b));

        out.push(sextet(packed, 18));
        out.push(sextet(packed, 12));
        out.push(if group.len() >= 2 { sextet(packed, 6) } else { '=' });
        out.push(if group.len() == 3 { sextet(packed, 0) } else { '=' });
    }
    out
}
166
/// Return the longest prefix of `s` that is at most `max` bytes and ends on a UTF-8
/// character boundary.
///
/// Plain `&s[..max]` panics when `max` lands inside a multi-byte character, and the
/// input here can be arbitrary response text. This version instead drops the partial
/// character.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Index 0 is always a char boundary, so the search always succeeds.
    let end = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openapi::spec::{Endpoint, Param, ParamLocation};

    /// Declare a string-typed endpoint parameter.
    fn param(name: &str, location: ParamLocation, required: bool, description: &str) -> Param {
        Param {
            name: name.into(),
            location,
            required,
            param_type: "string".into(),
            description: description.into(),
        }
    }

    /// Collect `(name, value)` pairs into the map shape `build_url` expects.
    fn params_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// "List issues" endpoint: two required path params and one optional query param.
    fn issue_endpoint() -> Endpoint {
        Endpoint {
            name: "repos_owner_repo_issues_get".into(),
            method: "GET".into(),
            path: "/repos/{owner}/{repo}/issues".into(),
            description: "List issues".into(),
            params: vec![
                param("owner", ParamLocation::Path, true, ""),
                param("repo", ParamLocation::Path, true, ""),
                param("state", ParamLocation::Query, false, "open/closed/all"),
            ],
        }
    }

    #[test]
    fn build_url_substitutes_path_params() {
        let values = params_of(&[("owner", "rust-lang"), ("repo", "rust")]);
        let url = build_url("https://api.github.com", &issue_endpoint(), &values).unwrap();
        assert_eq!(url, "https://api.github.com/repos/rust-lang/rust/issues");
    }

    #[test]
    fn build_url_with_query_params() {
        let values = params_of(&[("owner", "foo"), ("repo", "bar"), ("state", "open")]);
        let url = build_url("https://api.github.com", &issue_endpoint(), &values).unwrap();
        assert!(url.contains("?state=open"));
    }

    #[test]
    fn build_url_missing_required_param() {
        let err = build_url("https://api.github.com", &issue_endpoint(), &params_of(&[]))
            .unwrap_err();
        assert!(err.contains("Missing required parameter: owner"));
    }

    #[test]
    fn build_url_trailing_slash_base() {
        let bare = Endpoint {
            name: "test".into(),
            method: "GET".into(),
            path: "/test".into(),
            description: "".into(),
            params: Vec::new(),
        };
        let url = build_url("https://example.com/", &bare, &HashMap::new()).unwrap();
        assert_eq!(url, "https://example.com/test");
    }

    #[test]
    fn urlencod_special_chars() {
        let cases = [
            ("hello world", "hello%20world"),
            ("a&b=c", "a%26b%3Dc"),
            ("foo+bar", "foo%2Bbar"),
            ("path/to/file", "path%2Fto%2Ffile"),
            ("user@host.com", "user%40host.com"),
            ("~-_.", "~-_."),
            ("🚀", "%F0%9F%9A%80"),
        ];
        for (input, expected) in cases {
            assert_eq!(urlencod(input), expected);
        }
    }

    #[test]
    fn test_simple_base64() {
        let cases: [(&[u8], &str); 7] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in cases {
            assert_eq!(simple_base64(input), expected);
        }
    }
}