Skip to main content

steer_core/api/
util.rs

1use crate::api::error::ApiError;
2
3/// Normalize a chat completions URL.
4/// Ensures the URL ends with the correct path for chat completions.
5pub fn normalize_chat_url(base_url: Option<&str>, default_url: &str) -> String {
6    let base_url = base_url.map_or_else(|| default_url.to_string(), |s| s.to_string());
7
8    // If URL already ends with chat/completions, return as-is
9    if base_url.ends_with("/chat/completions") || base_url.ends_with("/v1/chat/completions") {
10        return base_url;
11    }
12
13    // Parse the URL to better handle path segments
14    if let Ok(mut parsed) = url::Url::parse(&base_url) {
15        let path = parsed.path().trim_end_matches('/');
16
17        // If the path already contains /v1, append just /chat/completions
18        if path.ends_with("/v1") {
19            parsed.set_path(&format!("{path}/chat/completions"));
20        } else if path.is_empty() || path == "/" {
21            // No meaningful path, add /v1/chat/completions
22            parsed.set_path("/v1/chat/completions");
23        } else {
24            // Has some path but not /v1, append /v1/chat/completions
25            parsed.set_path(&format!("{path}/v1/chat/completions"));
26        }
27        parsed.to_string()
28    } else {
29        // Fallback for non-URL strings (shouldn't happen with valid base URLs)
30        if base_url.ends_with('/') {
31            format!("{base_url}v1/chat/completions")
32        } else {
33            format!("{base_url}/v1/chat/completions")
34        }
35    }
36}
37
38/// Normalize a responses URL.
39/// Ensures the URL ends with the correct path for Responses API.
40pub fn normalize_responses_url(base_url: Option<&str>, default_url: &str) -> String {
41    let base_url = base_url.map_or_else(|| default_url.to_string(), |s| s.to_string());
42
43    if base_url.ends_with("/responses") || base_url.ends_with("/v1/responses") {
44        return base_url;
45    }
46
47    if let Ok(mut parsed) = url::Url::parse(&base_url) {
48        let path = parsed.path().trim_end_matches('/');
49        if path.ends_with("/v1") {
50            parsed.set_path(&format!("{path}/responses"));
51        } else if path.is_empty() || path == "/" {
52            parsed.set_path("/v1/responses");
53        } else {
54            parsed.set_path(&format!("{path}/v1/responses"));
55        }
56        parsed.to_string()
57    } else if base_url.ends_with('/') {
58        format!("{base_url}v1/responses")
59    } else {
60        format!("{base_url}/v1/responses")
61    }
62}
63
64pub fn map_http_status_to_api_error(provider: &str, status_code: u16, details: String) -> ApiError {
65    match status_code {
66        401 | 403 => ApiError::AuthenticationFailed {
67            provider: provider.to_string(),
68            details,
69        },
70        408 => ApiError::Timeout {
71            provider: provider.to_string(),
72        },
73        429 => ApiError::RateLimited {
74            provider: provider.to_string(),
75            details,
76        },
77        409 => ApiError::ServerError {
78            provider: provider.to_string(),
79            status_code,
80            details,
81        },
82        400..=499 => ApiError::InvalidRequest {
83            provider: provider.to_string(),
84            details,
85        },
86        500..=599 => ApiError::ServerError {
87            provider: provider.to_string(),
88            status_code,
89            details,
90        },
91        _ => ApiError::Unknown {
92            provider: provider.to_string(),
93            details,
94        },
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_chat_url() {
        assert_eq!(
            normalize_chat_url(Some("https://api.example.com"), ""),
            "https://api.example.com/v1/chat/completions"
        );

        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/"), ""),
            "https://api.example.com/v1/chat/completions"
        );

        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/v1"), ""),
            "https://api.example.com/v1/chat/completions"
        );

        // A trailing slash after /v1 is ignored.
        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/v1/"), ""),
            "https://api.example.com/v1/chat/completions"
        );

        // A custom path prefix without /v1 gets the full versioned suffix.
        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/openai"), ""),
            "https://api.example.com/openai/v1/chat/completions"
        );

        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/chat/completions"), ""),
            "https://api.example.com/chat/completions"
        );

        assert_eq!(
            normalize_chat_url(Some("https://api.example.com/v1/chat/completions"), ""),
            "https://api.example.com/v1/chat/completions"
        );

        assert_eq!(
            normalize_chat_url(None, "https://default.com/v1/chat/completions"),
            "https://default.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_normalize_responses_url() {
        assert_eq!(
            normalize_responses_url(Some("https://api.example.com"), ""),
            "https://api.example.com/v1/responses"
        );

        assert_eq!(
            normalize_responses_url(Some("https://api.example.com/"), ""),
            "https://api.example.com/v1/responses"
        );

        assert_eq!(
            normalize_responses_url(Some("https://api.example.com/v1"), ""),
            "https://api.example.com/v1/responses"
        );

        assert_eq!(
            normalize_responses_url(Some("https://api.example.com/openai"), ""),
            "https://api.example.com/openai/v1/responses"
        );

        assert_eq!(
            normalize_responses_url(Some("https://api.example.com/v1/responses"), ""),
            "https://api.example.com/v1/responses"
        );

        assert_eq!(
            normalize_responses_url(None, "https://default.com/v1/responses"),
            "https://default.com/v1/responses"
        );
    }

    #[test]
    fn test_map_http_status_to_api_error() {
        let provider = "test";

        assert!(matches!(
            map_http_status_to_api_error(provider, 401, "auth".to_string()),
            ApiError::AuthenticationFailed { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 403, "forbidden".to_string()),
            ApiError::AuthenticationFailed { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 408, "timeout".to_string()),
            ApiError::Timeout { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 409, "conflict".to_string()),
            ApiError::ServerError {
                status_code: 409,
                ..
            }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 429, "rate".to_string()),
            ApiError::RateLimited { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 400, "bad request".to_string()),
            ApiError::InvalidRequest { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 404, "not found".to_string()),
            ApiError::InvalidRequest { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 499, "client closed".to_string()),
            ApiError::InvalidRequest { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 500, "internal".to_string()),
            ApiError::ServerError {
                status_code: 500,
                ..
            }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 503, "server".to_string()),
            ApiError::ServerError {
                status_code: 503,
                ..
            }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 599, "edge".to_string()),
            ApiError::ServerError {
                status_code: 599,
                ..
            }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 302, "redirect".to_string()),
            ApiError::Unknown { .. }
        ));
        assert!(matches!(
            map_http_status_to_api_error(provider, 600, "out of range".to_string()),
            ApiError::Unknown { .. }
        ));
    }
}