1use crate::error::{QueryResolverError, AIError};
2use crate::json_utils;
3use serde::de::DeserializeOwned;
4use std::collections::HashMap;
5use std::fmt::Debug;
6use async_trait::async_trait;
7use tracing::{info, warn, error, debug, instrument};
8use schemars::{JsonSchema, schema_for};
9
10#[async_trait]
14pub trait LowLevelClient: Send + Sync + Debug{
15 async fn ask_raw(&self, prompt: String) -> Result<String, AIError>;
17
18 async fn ask_json(&self, prompt: String) -> Result<String, AIError> {
20 let raw_response = self.ask_raw(prompt).await?;
21 Ok(json_utils::find_json(&raw_response))
22 }
23
24 fn clone_box(&self) -> Box<dyn LowLevelClient>;
26}
27
28impl Clone for Box<dyn LowLevelClient> {
30 fn clone(&self) -> Self {
31 self.clone_box()
32 }
33}
34
35#[async_trait]
37impl LowLevelClient for Box<dyn LowLevelClient> {
38 async fn ask_raw(&self, prompt: String) -> Result<String, AIError> {
39 self.as_ref().ask_raw(prompt).await
40 }
41
42 fn clone_box(&self) -> Box<dyn LowLevelClient> {
43 self.as_ref().clone_box()
44 }
45}
46
47
48
/// Retry budgets for the query resolver, keyed by failure category.
///
/// A category with no entry in `max_retries` falls back to `default_max_retries`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Per-category retry limits, e.g. `"rate_limit"` or `"json_parse_error"`.
    pub max_retries: HashMap<String, usize>,
    /// Limit used for any category missing from `max_retries`.
    pub default_max_retries: usize,
}

impl Default for RetryConfig {
    /// JSON parse failures get a limit of 2; rate-limit, API and HTTP failures
    /// (and every unlisted category, via the default) get a limit of 1.
    fn default() -> Self {
        const DEFAULT_LIMITS: [(&str, usize); 4] = [
            ("rate_limit", 1),
            ("api_error", 1),
            ("http_error", 1),
            ("json_parse_error", 2),
        ];

        let max_retries: HashMap<String, usize> = DEFAULT_LIMITS
            .iter()
            .map(|&(category, limit)| (category.to_string(), limit))
            .collect();

        RetryConfig {
            max_retries,
            default_max_retries: 1,
        }
    }
}
69
70
71#[derive(Clone)]
72pub struct QueryResolver<C: LowLevelClient> {
76 client: C,
77 config: RetryConfig,
78}
79
80impl<C: LowLevelClient> QueryResolver<C> {
81 pub fn new(client: C, config: RetryConfig) -> Self {
82 info!(default_max_retries = config.default_max_retries, "Creating new QueryResolver with retry config");
83 Self { client, config }
84 }
85
86 pub fn client(&self) -> &C {
88 &self.client
89 }
90
91 pub fn config(&self) -> &RetryConfig {
93 &self.config
94 }
95
96 pub fn with_config(mut self, config: RetryConfig) -> Self {
98 self.config = config;
99 self
100 }
101
102 #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
104 pub async fn query_raw<T>(&self, prompt: String) -> Result<T, QueryResolverError>
105 where
106 T: DeserializeOwned + Send,
107 {
108 info!(prompt_len = prompt.len(), "Starting query");
109 let result = self.ask_with_retry(prompt).await;
110 match &result {
111 Ok(_) => info!("Query completed successfully"),
112 Err(e) => error!(error = %e, "Query failed"),
113 }
114 result
115 }
116
117 #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
119 pub async fn query<T>(&self, prompt: String) -> Result<T, QueryResolverError>
120 where
121 T: DeserializeOwned + JsonSchema + Send,
122 {
123 info!(prompt_len = prompt.len(), "Starting schema-aware query");
124 let result = self.ask_with_schema(prompt).await;
125 match &result {
126 Ok(_) => info!("Schema-aware query completed successfully"),
127 Err(e) => error!(error = %e, "Schema-aware query failed"),
128 }
129 result
130 }
131
132 #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
134 async fn ask_with_retry<T>(&self, prompt: String) -> Result<T, QueryResolverError>
135 where
136 T: DeserializeOwned + Send,
137 {
138 let mut attempt = 0;
139 let mut context = String::new();
140
141 info!(attempt = 0, max_retries = self.config.default_max_retries, "Starting retry loop for prompt");
142
143 loop {
144 let full_prompt = if context.is_empty() {
145 prompt.clone()
146 } else {
147 format!("{}\n\nPrevious attempt failed: {}\nPlease fix the issue and respond with valid JSON.", prompt, context)
148 };
149
150 debug!(attempt = attempt + 1, prompt_len = full_prompt.len(), "Making API call");
151 match self.client.ask_json(full_prompt.clone()).await {
152 Ok(response) => {
153 debug!(response_len = response.len(), "Received API response");
154 match serde_json::from_str::<T>(&response) {
155 Ok(parsed) => {
156 info!(attempt = attempt + 1, "Successfully parsed JSON response");
157 return Ok(parsed);
158 },
159 Err(json_err) => {
160 warn!(
161 error = %json_err,
162 response_preview = &response[..response.len().min(200)],
163 "Initial JSON parsing failed, trying advanced extraction"
164 );
165
166 if let Ok(raw_response) = self.client.ask_raw(full_prompt.clone()).await {
168 if let Some(extracted_json) = json_utils::extract_json_advanced(&raw_response) {
169 debug!(extracted_len = extracted_json.len(), "Trying to parse extracted JSON after initial failure");
170 match serde_json::from_str::<T>(&extracted_json) {
171 Ok(parsed) => {
172 info!(attempt = attempt + 1, "Successfully parsed extracted JSON after initial deserialization failure");
173 return Ok(parsed);
174 },
175 Err(extracted_err) => {
176 warn!(
177 error = %extracted_err,
178 extracted_preview = &extracted_json[..extracted_json.len().min(200)],
179 "Advanced extraction also failed to parse"
180 );
181 }
182 }
183 } else {
184 warn!("Advanced extraction could not find valid JSON in raw response");
185 }
186 }
187
188 let max_retries = self.config.max_retries.get("json_parse_error")
190 .unwrap_or(&self.config.default_max_retries);
191
192 if attempt >= *max_retries {
193 error!(
194 error = %json_err,
195 attempt = attempt + 1,
196 max_retries = max_retries,
197 "Max retries exceeded for JSON parsing"
198 );
199 return Err(QueryResolverError::JsonDeserialization(json_err, response));
200 }
201
202 warn!(
204 attempt = attempt + 1,
205 max_retries = max_retries,
206 "Retrying due to JSON parsing failure"
207 );
208 context = format!("JSON parsing failed: {}. Response was: {}",
209 json_err,
210 &response[..response.len().min(500)]);
211 attempt += 1;
212 }
213 }
214 }
215 Err(ai_error) => {
216 warn!(error = %ai_error, attempt = attempt + 1, "API call failed");
217 let error_type = match &ai_error {
218 AIError::Claude(claude_err) => match claude_err {
219 crate::error::ClaudeError::RateLimit => "rate_limit",
220 crate::error::ClaudeError::Http(_) => "http_error",
221 crate::error::ClaudeError::Api(_) => "api_error",
222 _ => "other",
223 },
224 AIError::OpenAI(openai_err) => match openai_err {
225 crate::error::OpenAIError::RateLimit => "rate_limit",
226 crate::error::OpenAIError::Http(_) => "http_error",
227 crate::error::OpenAIError::Api(_) => "api_error",
228 _ => "other",
229 },
230 AIError::DeepSeek(deepseek_err) => match deepseek_err {
231 crate::error::DeepSeekError::RateLimit => "rate_limit",
232 crate::error::DeepSeekError::Http(_) => "http_error",
233 crate::error::DeepSeekError::Api(_) => "api_error",
234 _ => "other",
235 },
236 AIError::Mock(_) => "mock_error",
237 };
238
239 let max_retries = self.config.max_retries.get(error_type)
240 .unwrap_or(&self.config.default_max_retries);
241
242 if attempt >= *max_retries {
243 error!(
244 error = %ai_error,
245 error_type = error_type,
246 max_retries = max_retries,
247 "Max retries exceeded for API error"
248 );
249 return Err(QueryResolverError::Ai(ai_error));
250 }
251
252 info!(
253 error_type = error_type,
254 attempt = attempt + 1,
255 max_retries = max_retries,
256 "Retrying after API error"
257 );
258 context = format!("API call failed: {}", ai_error);
259 attempt += 1;
260 }
261 }
262 }
263 }
264
265 pub fn augment_prompt_with_schema<T>(&self, prompt: String) -> String
267 where
268 T: JsonSchema,
269 {
270 let schema = schema_for!(T);
271 let schema_json = serde_json::to_string_pretty(&schema)
272 .unwrap_or_else(|_| "{}".to_string());
273
274 debug!(schema_len = schema_json.len(), "Generated JSON schema for return type");
275
276 format!(
277 r#"{prompt}
278You are tasked with generating a value satisfying a schema. First I will give you an example exchange then I will provide the schema of interest
279Example Schema:
280{{
281 "type": "object",
282 "properties": {{
283 "name": {{"type": "string"}},
284 "age": {{"type": "integer", "minimum": 0}},
285 "email": {{"type": "string"}},
286 "isActive": {{"type": "boolean"}},
287 "hobbies": {{"type": "array", "items": {{"type": "string"}}}}
288 }},
289 "required": ["name", "age", "email", "isActive"]
290}}
291Example response:
292{{"name": "Alice Smith", "age": 28, "email": "alice@example.com", "isActive": true, "hobbies": ["reading", "cooking"]}}
293Please provide a response matching this schema
294```json
295{schema_json}
296```
297"#
298 )
299 }
300
301 #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
303 async fn ask_with_schema<T>(&self, prompt: String) -> Result<T, QueryResolverError>
304 where
305 T: DeserializeOwned + JsonSchema + Send,
306 {
307 info!("Starting schema-aware query");
308 let augmented_prompt = self.augment_prompt_with_schema::<T>(prompt);
309 debug!(augmented_prompt_len = augmented_prompt.len(), "Generated schema-augmented prompt");
310 self.ask_with_retry(augmented_prompt).await
311 }
312}
313
314