semantic_query/
core.rs

1use crate::error::{QueryResolverError, AIError};
2use crate::json_utils;
3use serde::de::DeserializeOwned;
4use std::collections::HashMap;
5use std::fmt::Debug;
6use async_trait::async_trait;
7use tracing::{info, warn, error, debug, instrument};
8use schemars::{JsonSchema, schema_for};
9
10/// Low-level client trait that only requires implementing ask_raw.
11/// This trait can be used as dyn LowLevelClient for dynamic dispatch.
12/// JSON processing is handled by utility functions with a convenience method.
13#[async_trait]
14pub trait LowLevelClient: Send + Sync + Debug{
15    /// The only method that implementations must provide
16    async fn ask_raw(&self, prompt: String) -> Result<String, AIError>;
17    
18    /// Simple JSON extraction from a prompt response (default implementation)
19    async fn ask_json(&self, prompt: String) -> Result<String, AIError> {
20        let raw_response = self.ask_raw(prompt).await?;
21        Ok(json_utils::find_json(&raw_response))
22    }
23    
24    /// Clone this client into a boxed trait object
25    fn clone_box(&self) -> Box<dyn LowLevelClient>;
26}
27
28// Implement Clone for Box<dyn LowLevelClient>
29impl Clone for Box<dyn LowLevelClient> {
30    fn clone(&self) -> Self {
31        self.clone_box()
32    }
33}
34
35// Implement LowLevelClient for Box<dyn LowLevelClient>
36#[async_trait]
37impl LowLevelClient for Box<dyn LowLevelClient> {
38    async fn ask_raw(&self, prompt: String) -> Result<String, AIError> {
39        self.as_ref().ask_raw(prompt).await
40    }
41    
42    fn clone_box(&self) -> Box<dyn LowLevelClient> {
43        self.as_ref().clone_box()
44    }
45}
46
47
48
/// Retry limits used by `QueryResolver`, keyed by failure category.
///
/// The resolver looks up categories such as `"rate_limit"`, `"http_error"`, `"api_error"` and
/// `"json_parse_error"`; any category without an entry in `max_retries` falls back to
/// `default_max_retries`. A limit is the number of retries allowed after the first attempt.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retry limit per failure category.
    pub max_retries: HashMap<String, usize>,
    /// Retry limit for categories that have no entry in `max_retries`.
    pub default_max_retries: usize,
}

impl Default for RetryConfig {
    /// One retry for rate limits, API errors and HTTP errors, two for JSON parse failures, and
    /// one for every other category.
    fn default() -> Self {
        let max_retries: HashMap<String, usize> = [
            ("rate_limit", 1),
            ("api_error", 1),
            ("http_error", 1),
            ("json_parse_error", 2),
        ]
        .iter()
        .map(|&(category, limit)| (category.to_string(), limit))
        .collect();

        Self {
            max_retries,
            default_max_retries: 1,
        }
    }
}
69
70
71#[derive(Clone)]
72/// Query resolver that wraps a LowLevelClient and provides all generic methods.
73/// This allows for flexible composition - you can have arrays of dyn LowLevelClient
74/// and wrap them in QueryResolver as needed.
75pub struct QueryResolver<C: LowLevelClient> {
76    client: C,
77    config: RetryConfig,
78}
79
80impl<C: LowLevelClient> QueryResolver<C> {
81    pub fn new(client: C, config: RetryConfig) -> Self {
82        info!(default_max_retries = config.default_max_retries, "Creating new QueryResolver with retry config");
83        Self { client, config }
84    }
85    
86    /// Get a reference to the underlying client
87    pub fn client(&self) -> &C {
88        &self.client
89    }
90    
91    /// Get a reference to the retry configuration
92    pub fn config(&self) -> &RetryConfig {
93        &self.config
94    }
95    
96    /// Update the retry configuration
97    pub fn with_config(mut self, config: RetryConfig) -> Self {
98        self.config = config;
99        self
100    }
101    
102    /// Query with retry logic and automatic JSON parsing
103    #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
104    pub async fn query_raw<T>(&self, prompt: String) -> Result<T, QueryResolverError>
105    where
106        T: DeserializeOwned + Send,
107    {
108        info!(prompt_len = prompt.len(), "Starting query");
109        let result = self.ask_with_retry(prompt).await;
110        match &result {
111            Ok(_) => info!("Query completed successfully"),
112            Err(e) => error!(error = %e, "Query failed"),
113        }
114        result
115    }
116    
117    /// Query with automatic schema-aware prompt augmentation
118    #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
119    pub async fn query<T>(&self, prompt: String) -> Result<T, QueryResolverError>
120    where
121        T: DeserializeOwned + JsonSchema + Send,
122    {
123        info!(prompt_len = prompt.len(), "Starting schema-aware query");
124        let result = self.ask_with_schema(prompt).await;
125        match &result {
126            Ok(_) => info!("Schema-aware query completed successfully"),
127            Err(e) => error!(error = %e, "Schema-aware query failed"),
128        }
129        result
130    }
131    
132    /// Internal method for retry logic with JSON parsing
133    #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
134    async fn ask_with_retry<T>(&self, prompt: String) -> Result<T, QueryResolverError>
135    where
136        T: DeserializeOwned + Send,
137    {
138        let mut attempt = 0;
139        let mut context = String::new();
140        
141        info!(attempt = 0, max_retries = self.config.default_max_retries, "Starting retry loop for prompt");
142        
143        loop {
144            let full_prompt = if context.is_empty() {
145                prompt.clone()
146            } else {
147                format!("{}\n\nPrevious attempt failed: {}\nPlease fix the issue and respond with valid JSON.", prompt, context)
148            };
149            
150            debug!(attempt = attempt + 1, prompt_len = full_prompt.len(), "Making API call");
151            match self.client.ask_json(full_prompt.clone()).await {
152                Ok(response) => {
153                    debug!(response_len = response.len(), "Received API response");
154                    match serde_json::from_str::<T>(&response) {
155                        Ok(parsed) => {
156                            info!(attempt = attempt + 1, "Successfully parsed JSON response");
157                            return Ok(parsed);
158                        },
159                        Err(json_err) => {
160                            warn!(
161                                error = %json_err, 
162                                response_preview = &response[..response.len().min(200)],
163                                "Initial JSON parsing failed, trying advanced extraction"
164                            );
165                            
166                            // Try advanced JSON extraction on the raw response
167                            if let Ok(raw_response) = self.client.ask_raw(full_prompt.clone()).await {
168                                if let Some(extracted_json) = json_utils::extract_json_advanced(&raw_response) {
169                                    debug!(extracted_len = extracted_json.len(), "Trying to parse extracted JSON after initial failure");
170                                    match serde_json::from_str::<T>(&extracted_json) {
171                                        Ok(parsed) => {
172                                            info!(attempt = attempt + 1, "Successfully parsed extracted JSON after initial deserialization failure");
173                                            return Ok(parsed);
174                                        },
175                                        Err(extracted_err) => {
176                                            warn!(
177                                                error = %extracted_err,
178                                                extracted_preview = &extracted_json[..extracted_json.len().min(200)],
179                                                "Advanced extraction also failed to parse"
180                                            );
181                                        }
182                                    }
183                                } else {
184                                    warn!("Advanced extraction could not find valid JSON in raw response");
185                                }
186                            }
187                            
188                            // If we're at max retries, return the error
189                            let max_retries = self.config.max_retries.get("json_parse_error")
190                                .unwrap_or(&self.config.default_max_retries);
191                            
192                            if attempt >= *max_retries {
193                                error!(
194                                    error = %json_err,
195                                    attempt = attempt + 1,
196                                    max_retries = max_retries,
197                                    "Max retries exceeded for JSON parsing"
198                                );
199                                return Err(QueryResolverError::JsonDeserialization(json_err, response));
200                            }
201                            
202                            // Otherwise, retry with context about the JSON parsing failure
203                            warn!(
204                                attempt = attempt + 1,
205                                max_retries = max_retries,
206                                "Retrying due to JSON parsing failure"
207                            );
208                            context = format!("JSON parsing failed: {}. Response was: {}", 
209                                             json_err, 
210                                             &response[..response.len().min(500)]);
211                            attempt += 1;
212                        }
213                    }
214                }
215                Err(ai_error) => {
216                    warn!(error = %ai_error, attempt = attempt + 1, "API call failed");
217                    let error_type = match &ai_error {
218                        AIError::Claude(claude_err) => match claude_err {
219                            crate::error::ClaudeError::RateLimit => "rate_limit",
220                            crate::error::ClaudeError::Http(_) => "http_error",
221                            crate::error::ClaudeError::Api(_) => "api_error",
222                            _ => "other",
223                        },
224                        AIError::OpenAI(openai_err) => match openai_err {
225                            crate::error::OpenAIError::RateLimit => "rate_limit",
226                            crate::error::OpenAIError::Http(_) => "http_error", 
227                            crate::error::OpenAIError::Api(_) => "api_error",
228                            _ => "other",
229                        },
230                        AIError::DeepSeek(deepseek_err) => match deepseek_err {
231                            crate::error::DeepSeekError::RateLimit => "rate_limit",
232                            crate::error::DeepSeekError::Http(_) => "http_error",
233                            crate::error::DeepSeekError::Api(_) => "api_error",
234                            _ => "other",
235                        },
236                        AIError::Mock(_) => "mock_error",
237                    };
238                    
239                    let max_retries = self.config.max_retries.get(error_type)
240                        .unwrap_or(&self.config.default_max_retries);
241                    
242                    if attempt >= *max_retries {
243                        error!(
244                            error = %ai_error, 
245                            error_type = error_type,
246                            max_retries = max_retries,
247                            "Max retries exceeded for API error"
248                        );
249                        return Err(QueryResolverError::Ai(ai_error));
250                    }
251                    
252                    info!(
253                        error_type = error_type,
254                        attempt = attempt + 1,
255                        max_retries = max_retries,
256                        "Retrying after API error"
257                    );
258                    context = format!("API call failed: {}", ai_error);
259                    attempt += 1;
260                }
261            }
262        }
263    }
264    
265    /// Generate a JSON schema for the return type and append it to the prompt
266    pub fn augment_prompt_with_schema<T>(&self, prompt: String) -> String
267    where
268        T: JsonSchema,
269    {
270        let schema = schema_for!(T);
271        let schema_json = serde_json::to_string_pretty(&schema)
272            .unwrap_or_else(|_| "{}".to_string());
273        
274        debug!(schema_len = schema_json.len(), "Generated JSON schema for return type");
275        
276        format!(
277            r#"{prompt}
278You are tasked with generating a value satisfying a schema. First I will give you an example exchange then I will provide the schema of interest
279Example Schema:
280{{
281    "type": "object",
282    "properties": {{
283        "name": {{"type": "string"}},
284        "age": {{"type": "integer", "minimum": 0}},
285        "email": {{"type": "string"}},
286        "isActive": {{"type": "boolean"}},
287        "hobbies": {{"type": "array", "items": {{"type": "string"}}}}
288    }},
289    "required": ["name", "age", "email", "isActive"]
290}}
291Example response:
292{{"name": "Alice Smith", "age": 28, "email": "alice@example.com", "isActive": true, "hobbies": ["reading", "cooking"]}}
293Please provide a response matching this schema
294```json
295{schema_json}
296```
297"#
298        )
299    }
300    
301    /// Ask with automatic schema-aware prompt augmentation
302    #[instrument(skip(self, prompt), fields(prompt_len = prompt.len()))]
303    async fn ask_with_schema<T>(&self, prompt: String) -> Result<T, QueryResolverError>
304    where
305        T: DeserializeOwned + JsonSchema + Send,
306    {
307        info!("Starting schema-aware query");
308        let augmented_prompt = self.augment_prompt_with_schema::<T>(prompt);
309        debug!(augmented_prompt_len = augmented_prompt.len(), "Generated schema-augmented prompt");
310        self.ask_with_retry(augmented_prompt).await
311    }
312}
313
314