ultrafast_models_sdk/providers/
http_client.rs

1use crate::error::ProviderError;
2use reqwest::{header::HeaderMap, Client, Method, Response};
3use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::collections::HashMap;
6use std::time::Duration;
7
/// How requests made by [`HttpProviderClient`] authenticate.
///
/// `Debug` is implemented by hand so that credentials are redacted and never end up
/// in logs or panic messages.
#[derive(Clone)]
pub enum AuthStrategy {
    /// Sent as `Authorization: Bearer <token>`.
    Bearer { token: String },
    /// Sent as a single custom header `name: value` (e.g. an API-key header).
    Header { name: String, value: String },
    /// No authentication header is added.
    None,
}

impl std::fmt::Debug for AuthStrategy {
    /// Formats the strategy with the secret parts (`token`, header `value`) replaced
    /// by a placeholder; the custom header *name* is not secret and is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Bearer { .. } => f.debug_struct("Bearer").field("token", &REDACTED).finish(),
            Self::Header { name, .. } => f
                .debug_struct("Header")
                .field("name", name)
                .field("value", &REDACTED)
                .finish(),
            Self::None => f.write_str("None"),
        }
    }
}
14
15#[derive(Clone)]
16pub struct HttpProviderClient {
17    http: Client,
18    base_url: String,
19    default_headers: HeaderMap,
20}
21
22impl HttpProviderClient {
23    pub fn new(
24        timeout: Duration,
25        base_url: Option<String>,
26        default_base: &str,
27        headers: &HashMap<String, String>,
28        auth: AuthStrategy,
29    ) -> Result<Self, ProviderError> {
30        let http = Client::builder().timeout(timeout).build().map_err(|e| {
31            ProviderError::Configuration {
32                message: format!("Failed to create HTTP client: {e}"),
33            }
34        })?;
35
36        let mut default_headers = HeaderMap::new();
37
38        match auth {
39            AuthStrategy::Bearer { token } => {
40                default_headers.insert("Authorization", format!("Bearer {token}").parse().unwrap());
41            }
42            AuthStrategy::Header { name, value } => {
43                if let (Ok(name), Ok(value)) =
44                    (name.parse::<reqwest::header::HeaderName>(), value.parse())
45                {
46                    default_headers.insert(name, value);
47                }
48            }
49            AuthStrategy::None => {}
50        }
51
52        for (k, v) in headers {
53            if let (Ok(name), Ok(value)) = (k.parse::<reqwest::header::HeaderName>(), v.parse()) {
54                default_headers.insert(name, value);
55            }
56        }
57
58        let base_url = base_url.unwrap_or_else(|| default_base.to_string());
59
60        Ok(Self {
61            http,
62            base_url,
63            default_headers,
64        })
65    }
66
67    fn build_url(&self, path: &str) -> String {
68        if path.starts_with('/') {
69            format!("{}{}", self.base_url, path)
70        } else {
71            format!("{}/{}", self.base_url.trim_end_matches('/'), path)
72        }
73    }
74
75    fn build_headers(&self) -> HeaderMap {
76        self.default_headers.clone()
77    }
78
79    pub async fn post_json<TReq: Serialize, TResp: DeserializeOwned>(
80        &self,
81        path: &str,
82        body: &TReq,
83    ) -> Result<TResp, ProviderError> {
84        let url = self.build_url(path);
85        let resp = self
86            .http
87            .request(Method::POST, url)
88            .headers(self.build_headers())
89            .json(body)
90            .send()
91            .await?;
92
93        if !resp.status().is_success() {
94            return Err(map_error_response(resp).await);
95        }
96        Ok(resp.json::<TResp>().await?)
97    }
98
99    pub async fn post_json_raw<TReq: Serialize>(
100        &self,
101        path: &str,
102        body: &TReq,
103    ) -> Result<Response, ProviderError> {
104        let url = self.build_url(path);
105        let resp = self
106            .http
107            .request(Method::POST, url)
108            .headers(self.build_headers())
109            .json(body)
110            .send()
111            .await?;
112        Ok(resp)
113    }
114
115    pub async fn post_multipart(
116        &self,
117        path: &str,
118        form: reqwest::multipart::Form,
119    ) -> Result<Response, ProviderError> {
120        let url = self.build_url(path);
121        let resp = self
122            .http
123            .request(Method::POST, url)
124            .headers(self.build_headers())
125            .multipart(form)
126            .send()
127            .await?;
128        Ok(resp)
129    }
130
131    pub async fn get_json<TResp: DeserializeOwned>(
132        &self,
133        path: &str,
134    ) -> Result<TResp, ProviderError> {
135        let url = self.build_url(path);
136        let resp = self
137            .http
138            .request(Method::GET, url)
139            .headers(self.build_headers())
140            .send()
141            .await?;
142
143        if !resp.status().is_success() {
144            return Err(map_error_response(resp).await);
145        }
146        Ok(resp.json::<TResp>().await?)
147    }
148}
149
150pub async fn map_error_response(resp: Response) -> ProviderError {
151    let status = resp.status();
152    match resp.text().await {
153        Ok(body) => {
154            // Try to pull a message from common JSON error shapes
155            let message = serde_json::from_str::<serde_json::Value>(&body)
156                .ok()
157                .and_then(|v| v.get("error").cloned())
158                .and_then(|e| e.get("message").cloned())
159                .and_then(|m| m.as_str().map(|s| s.to_string()))
160                .unwrap_or_else(|| body.clone());
161
162            match status.as_u16() {
163                401 => ProviderError::InvalidApiKey,
164                404 => ProviderError::ModelNotFound {
165                    model: "unknown".to_string(),
166                },
167                429 => ProviderError::RateLimit,
168                code => ProviderError::Api { code, message },
169            }
170        }
171        Err(_) => ProviderError::Api {
172            code: status.as_u16(),
173            message: "Failed to read error response".to_string(),
174        },
175    }
176}