Skip to main content

synth_ai/
client.rs

1use std::env;
2
3use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
4use serde::Serialize;
5use serde_json::Value;
6
7use crate::types::{Result, SynthError};
8
/// Backend root URL that [`SynthClient::from_env`] falls back to when the
/// `SYNTH_BACKEND_URL` environment variable does not supply one.
const DEFAULT_BASE_URL: &str = "https://api.usesynth.ai";
10
/// Selects which credential header(s) a request carries.
///
/// The API key itself lives on the client; this enum only picks the header
/// shape used to transmit it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthStyle {
    /// Send the key as `Authorization: Bearer <key>`.
    Bearer,
    /// Send the key as `X-API-Key: <key>`.
    ApiKey,
    /// Send both the `Authorization` and the `X-API-Key` header.
    Both,
}
17
18#[derive(Clone)]
19pub struct SynthClient {
20    base_url: String,
21    api_key: String,
22    http: reqwest::Client,
23}
24
25impl SynthClient {
26    pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
27        let base_url = base_url.into();
28        let api_key = api_key.into();
29        Self {
30            base_url: base_url.trim_end_matches('/').to_string(),
31            api_key,
32            http: reqwest::Client::new(),
33        }
34    }
35
36    pub fn from_env() -> Result<Self> {
37        let base_url = env::var("SYNTH_BACKEND_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
38        let api_key = env::var("SYNTH_API_KEY").map_err(|_| SynthError::MissingApiKey)?;
39        Ok(Self::new(base_url, api_key))
40    }
41
42    pub fn api_base(&self) -> String {
43        let trimmed = self.base_url.trim_end_matches('/');
44        if trimmed.ends_with("/api") {
45            trimmed.to_string()
46        } else {
47            format!("{trimmed}/api")
48        }
49    }
50
51    pub fn base_url(&self) -> &str {
52        &self.base_url
53    }
54
55    pub fn api_key(&self) -> &str {
56        &self.api_key
57    }
58
59    pub fn http(&self) -> &reqwest::Client {
60        &self.http
61    }
62
63    pub(crate) fn auth_headers(&self, auth: AuthStyle) -> HeaderMap {
64        let mut headers = HeaderMap::new();
65        if matches!(auth, AuthStyle::Bearer | AuthStyle::Both) {
66            let value = format!("Bearer {}", self.api_key);
67            if let Ok(hv) = HeaderValue::from_str(&value) {
68                headers.insert(AUTHORIZATION, hv);
69            }
70        }
71        if matches!(auth, AuthStyle::ApiKey | AuthStyle::Both) {
72            if let Ok(hv) = HeaderValue::from_str(&self.api_key) {
73                headers.insert("X-API-Key", hv);
74            }
75        }
76        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
77        headers
78    }
79
80    fn url(&self, path: &str) -> String {
81        if path.starts_with("http://") || path.starts_with("https://") {
82            return path.to_string();
83        }
84        let mut rel = path.trim_start_matches('/');
85        if rel.starts_with("api/") {
86            rel = &rel[4..];
87        }
88        format!("{}/{}", self.api_base(), rel)
89    }
90
91    pub async fn get_json(&self, path: &str, auth: AuthStyle) -> Result<Value> {
92        let url = self.url(path);
93        let resp = self
94            .http
95            .get(url)
96            .headers(self.auth_headers(auth))
97            .send()
98            .await?;
99        Self::json_or_error(resp).await
100    }
101
102    pub async fn post_json<T: Serialize + ?Sized>(
103        &self,
104        path: &str,
105        body: &T,
106        auth: AuthStyle,
107    ) -> Result<Value> {
108        let url = self.url(path);
109        let resp = self
110            .http
111            .post(url)
112            .headers(self.auth_headers(auth))
113            .json(body)
114            .send()
115            .await?;
116        Self::json_or_error(resp).await
117    }
118
119    pub async fn get_json_fallback(&self, paths: &[&str], auth: AuthStyle) -> Result<Value> {
120        let mut last_error = None;
121        for path in paths {
122            match self.get_json(path, auth).await {
123                Ok(val) => return Ok(val),
124                Err(err) => {
125                    if let SynthError::Api { status, .. } = &err {
126                        if *status == 404 {
127                            last_error = Some(err);
128                            continue;
129                        }
130                    }
131                    return Err(err);
132                }
133            }
134        }
135        Err(last_error.unwrap_or_else(|| {
136            SynthError::UnexpectedResponse("no fallback endpoints succeeded".to_string())
137        }))
138    }
139
140    pub async fn post_json_fallback<T: Serialize + ?Sized>(
141        &self,
142        paths: &[&str],
143        body: &T,
144        auth: AuthStyle,
145    ) -> Result<Value> {
146        let mut last_error = None;
147        for path in paths {
148            match self.post_json(path, body, auth).await {
149                Ok(val) => return Ok(val),
150                Err(err) => {
151                    if let SynthError::Api { status, .. } = &err {
152                        if *status == 404 {
153                            last_error = Some(err);
154                            continue;
155                        }
156                    }
157                    return Err(err);
158                }
159            }
160        }
161        Err(last_error.unwrap_or_else(|| {
162            SynthError::UnexpectedResponse("no fallback endpoints succeeded".to_string())
163        }))
164    }
165
166    async fn json_or_error(resp: reqwest::Response) -> Result<Value> {
167        let status = resp.status();
168        if status.is_success() {
169            return Ok(resp.json::<Value>().await?);
170        }
171        let body = resp.text().await.unwrap_or_default();
172        Err(SynthError::Api {
173            status: status.as_u16(),
174            body,
175        })
176    }
177}