// ultrafast_models_sdk/providers/http_client.rs
use crate::error::ProviderError;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
use reqwest::{Client, Method, RequestBuilder, Response};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::HashMap;
use std::time::Duration;

/// How [`HttpProviderClient::new`] attaches credentials to every request.
///
/// `Debug` is implemented by hand (instead of derived) so that secrets never end
/// up in logs: the bearer token and the custom header value print as `<redacted>`.
#[derive(Clone)]
pub enum AuthStrategy {
    /// Sent as `Authorization: Bearer <token>`.
    Bearer { token: String },
    /// Sent as a custom header `name: value` (for example an API-key header).
    Header { name: String, value: String },
    /// No authentication header is added.
    None,
}

impl std::fmt::Debug for AuthStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("token", &format_args!("<redacted>"))
                .finish(),
            // The header name is not secret and is useful when debugging config.
            Self::Header { name, .. } => f
                .debug_struct("Header")
                .field("name", name)
                .field("value", &format_args!("<redacted>"))
                .finish(),
            Self::None => f.write_str("None"),
        }
    }
}

15#[derive(Clone)]
16pub struct HttpProviderClient {
17 http: Client,
18 base_url: String,
19 default_headers: HeaderMap,
20}
21
22impl HttpProviderClient {
23 pub fn new(
24 timeout: Duration,
25 base_url: Option<String>,
26 default_base: &str,
27 headers: &HashMap<String, String>,
28 auth: AuthStrategy,
29 ) -> Result<Self, ProviderError> {
30 let http = Client::builder().timeout(timeout).build().map_err(|e| {
31 ProviderError::Configuration {
32 message: format!("Failed to create HTTP client: {e}"),
33 }
34 })?;
35
36 let mut default_headers = HeaderMap::new();
37
38 match auth {
39 AuthStrategy::Bearer { token } => {
40 default_headers.insert("Authorization", format!("Bearer {token}").parse().unwrap());
41 }
42 AuthStrategy::Header { name, value } => {
43 if let (Ok(name), Ok(value)) =
44 (name.parse::<reqwest::header::HeaderName>(), value.parse())
45 {
46 default_headers.insert(name, value);
47 }
48 }
49 AuthStrategy::None => {}
50 }
51
52 for (k, v) in headers {
53 if let (Ok(name), Ok(value)) = (k.parse::<reqwest::header::HeaderName>(), v.parse()) {
54 default_headers.insert(name, value);
55 }
56 }
57
58 let base_url = base_url.unwrap_or_else(|| default_base.to_string());
59
60 Ok(Self {
61 http,
62 base_url,
63 default_headers,
64 })
65 }
66
67 fn build_url(&self, path: &str) -> String {
68 if path.starts_with('/') {
69 format!("{}{}", self.base_url, path)
70 } else {
71 format!("{}/{}", self.base_url.trim_end_matches('/'), path)
72 }
73 }
74
75 fn build_headers(&self) -> HeaderMap {
76 self.default_headers.clone()
77 }
78
79 pub async fn post_json<TReq: Serialize, TResp: DeserializeOwned>(
80 &self,
81 path: &str,
82 body: &TReq,
83 ) -> Result<TResp, ProviderError> {
84 let url = self.build_url(path);
85 let resp = self
86 .http
87 .request(Method::POST, url)
88 .headers(self.build_headers())
89 .json(body)
90 .send()
91 .await?;
92
93 if !resp.status().is_success() {
94 return Err(map_error_response(resp).await);
95 }
96 Ok(resp.json::<TResp>().await?)
97 }
98
99 pub async fn post_json_raw<TReq: Serialize>(
100 &self,
101 path: &str,
102 body: &TReq,
103 ) -> Result<Response, ProviderError> {
104 let url = self.build_url(path);
105 let resp = self
106 .http
107 .request(Method::POST, url)
108 .headers(self.build_headers())
109 .json(body)
110 .send()
111 .await?;
112 Ok(resp)
113 }
114
115 pub async fn post_multipart(
116 &self,
117 path: &str,
118 form: reqwest::multipart::Form,
119 ) -> Result<Response, ProviderError> {
120 let url = self.build_url(path);
121 let resp = self
122 .http
123 .request(Method::POST, url)
124 .headers(self.build_headers())
125 .multipart(form)
126 .send()
127 .await?;
128 Ok(resp)
129 }
130
131 pub async fn get_json<TResp: DeserializeOwned>(
132 &self,
133 path: &str,
134 ) -> Result<TResp, ProviderError> {
135 let url = self.build_url(path);
136 let resp = self
137 .http
138 .request(Method::GET, url)
139 .headers(self.build_headers())
140 .send()
141 .await?;
142
143 if !resp.status().is_success() {
144 return Err(map_error_response(resp).await);
145 }
146 Ok(resp.json::<TResp>().await?)
147 }
148}
149
150pub async fn map_error_response(resp: Response) -> ProviderError {
151 let status = resp.status();
152 match resp.text().await {
153 Ok(body) => {
154 let message = serde_json::from_str::<serde_json::Value>(&body)
156 .ok()
157 .and_then(|v| v.get("error").cloned())
158 .and_then(|e| e.get("message").cloned())
159 .and_then(|m| m.as_str().map(|s| s.to_string()))
160 .unwrap_or_else(|| body.clone());
161
162 match status.as_u16() {
163 401 => ProviderError::InvalidApiKey,
164 404 => ProviderError::ModelNotFound {
165 model: "unknown".to_string(),
166 },
167 429 => ProviderError::RateLimit,
168 code => ProviderError::Api { code, message },
169 }
170 }
171 Err(_) => ProviderError::Api {
172 code: status.as_u16(),
173 message: "Failed to read error response".to_string(),
174 },
175 }
176}