zeroentropy_community/
client.rs1use crate::error::{Error, Result};
2use reqwest::{Client as HttpClient, Response};
3use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::time::Duration;
6
/// Base URL of the public ZeroEntropy API; endpoint paths are appended to it verbatim.
const DEFAULT_BASE_URL: &str = "https://api.zeroentropy.dev/v1";
/// Default total timeout given to the underlying HTTP client.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
/// Default number of retries allowed after the initial attempt of a request.
const DEFAULT_MAX_RETRIES: u32 = 2;

11#[derive(Clone)]
13pub struct Client {
14 http_client: HttpClient,
15 api_key: String,
16 base_url: String,
17 max_retries: u32,
18}
19
20impl Client {
21 pub fn new(api_key: impl Into<String>) -> Result<Self> {
33 Self::builder().api_key(api_key).build()
34 }
35
36 pub fn from_env() -> Result<Self> {
40 let api_key = std::env::var("ZEROENTROPY_API_KEY")
41 .map_err(|_| Error::InvalidApiKey)?;
42 Self::new(api_key)
43 }
44
45 pub fn builder() -> ClientBuilder {
47 ClientBuilder::default()
48 }
49
50 pub(crate) async fn post<T, R>(&self, endpoint: &str, body: &T) -> Result<R>
52 where
53 T: Serialize + ?Sized,
54 R: DeserializeOwned,
55 {
56 let url = format!("{}{}", self.base_url, endpoint);
57
58 let mut attempts = 0;
59 loop {
60 let response = self
61 .http_client
62 .post(&url)
63 .header("Authorization", format!("Bearer {}", self.api_key))
64 .header("Content-Type", "application/json")
65 .json(body)
66 .send()
67 .await?;
68
69 let status = response.status();
70
71 if attempts < self.max_retries && Self::should_retry(status.as_u16()) {
73 attempts += 1;
74 let delay = Self::calculate_retry_delay(attempts);
75 tokio::time::sleep(delay).await;
76 continue;
77 }
78
79 return Self::handle_response(response).await;
80 }
81 }
82
83 async fn handle_response<R: DeserializeOwned>(response: Response) -> Result<R> {
85 let status = response.status();
86
87 if status.is_success() {
88 Ok(response.json().await?)
89 } else {
90 let status_code = status.as_u16();
91 let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
92
93 let message = serde_json::from_str::<serde_json::Value>(&error_text)
95 .ok()
96 .and_then(|v| v.get("message").and_then(|m| m.as_str()).map(String::from))
97 .unwrap_or(error_text);
98
99 Err(Error::from_status(status_code, message))
100 }
101 }
102
103 fn should_retry(status: u16) -> bool {
105 matches!(status, 408 | 409 | 429) || status >= 500
106 }
107
108 fn calculate_retry_delay(attempt: u32) -> Duration {
110 let base_delay = 500; let max_delay = 8000; let delay = base_delay * 2_u64.pow(attempt - 1);
113 Duration::from_millis(delay.min(max_delay))
114 }
115}
116
/// Option-by-option configuration for a [`Client`], usually obtained from `Client::builder()`.
///
/// Every option is optional. `build` fills unset ones from environment variables (API key,
/// base URL) or from built-in defaults.
#[derive(Default)]
#[must_use = "a ClientBuilder does nothing until `build` is called"]
pub struct ClientBuilder {
    /// Explicit API key; `None` means fall back to `ZEROENTROPY_API_KEY`.
    api_key: Option<String>,
    /// Explicit base URL; `None` means fall back to `ZEROENTROPY_BASE_URL`, then the default.
    base_url: Option<String>,
    /// HTTP client timeout; `None` means the default.
    timeout: Option<Duration>,
    /// Retry budget per request; `None` means the default.
    max_retries: Option<u32>,
}

// Hand-written instead of derived so a configured API key is never printed.
impl std::fmt::Debug for ClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientBuilder")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

126impl ClientBuilder {
127 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
129 self.api_key = Some(api_key.into());
130 self
131 }
132
133 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
135 self.base_url = Some(base_url.into());
136 self
137 }
138
139 pub fn timeout(mut self, timeout: Duration) -> Self {
141 self.timeout = Some(timeout);
142 self
143 }
144
145 pub fn max_retries(mut self, max_retries: u32) -> Self {
147 self.max_retries = Some(max_retries);
148 self
149 }
150
151 pub fn build(self) -> Result<Client> {
153 let api_key = self.api_key
154 .or_else(|| std::env::var("ZEROENTROPY_API_KEY").ok())
155 .ok_or(Error::InvalidApiKey)?;
156
157 let base_url = self.base_url
158 .or_else(|| std::env::var("ZEROENTROPY_BASE_URL").ok())
159 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
160
161 let timeout = self.timeout.unwrap_or(DEFAULT_TIMEOUT);
162 let max_retries = self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
163
164 let http_client = HttpClient::builder()
165 .timeout(timeout)
166 .build()?;
167
168 Ok(Client {
169 http_client,
170 api_key,
171 base_url,
172 max_retries,
173 })
174 }
175}