// shodh_memory/query_parsing/llm_parser.rs
//! LLM-Based Query Parser
//!
//! Uses a local LLM server (Ollama, LM Studio, etc.) via HTTP API for query parsing.
//! Provides better temporal reasoning and entity extraction than rule-based parsing.

use super::parser_trait::*;
use chrono::{DateTime, NaiveDate, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;

13/// LLM-based query parser using a local HTTP API (Ollama, LM Studio, etc.)
14pub struct LlmParser {
15    client: reqwest::blocking::Client,
16    endpoint: String,
17    model: String,
18    generation_lock: Mutex<()>,
19}
20
21/// Request format for Ollama API
22#[derive(Debug, Serialize)]
23struct OllamaRequest {
24    model: String,
25    prompt: String,
26    stream: bool,
27    options: OllamaOptions,
28}
29
30#[derive(Debug, Serialize)]
31struct OllamaOptions {
32    temperature: f32,
33    num_predict: i32,
34}
35
36/// Response format from Ollama API
37#[derive(Debug, Deserialize)]
38struct OllamaResponse {
39    response: String,
40}
41
42/// Request format for OpenAI-compatible APIs (LM Studio, vLLM, etc.)
43#[derive(Debug, Serialize)]
44struct OpenAIRequest {
45    model: String,
46    messages: Vec<OpenAIMessage>,
47    temperature: f32,
48    max_tokens: i32,
49}
50
51#[derive(Debug, Serialize)]
52struct OpenAIMessage {
53    role: String,
54    content: String,
55}
56
57/// Response format from OpenAI-compatible APIs
58#[derive(Debug, Deserialize)]
59struct OpenAIResponse {
60    choices: Vec<OpenAIChoice>,
61}
62
63#[derive(Debug, Deserialize)]
64struct OpenAIChoice {
65    message: OpenAIMessageResponse,
66}
67
68#[derive(Debug, Deserialize)]
69struct OpenAIMessageResponse {
70    content: String,
71}
72
73/// Expected JSON output format from the LLM
74#[derive(Debug, Deserialize, Serialize)]
75struct LlmOutput {
76    entities: Vec<LlmEntity>,
77    events: Vec<String>,
78    modifiers: Vec<String>,
79    temporal: LlmTemporal,
80    is_attribute_query: bool,
81    attribute_entity: Option<String>,
82    attribute_name: Option<String>,
83    confidence: f32,
84}
85
86#[derive(Debug, Deserialize, Serialize)]
87struct LlmEntity {
88    text: String,
89    #[serde(rename = "type")]
90    entity_type: String,
91    negated: bool,
92}
93
94#[derive(Debug, Deserialize, Serialize)]
95struct LlmTemporal {
96    has_temporal_intent: bool,
97    intent: String,
98    relative_refs: Vec<LlmRelativeRef>,
99    resolved_dates: Vec<String>,
100}
101
102#[derive(Debug, Deserialize, Serialize)]
103struct LlmRelativeRef {
104    text: String,
105    resolved_date: Option<String>,
106    direction: String,
107}
108
/// Flavor of HTTP API exposed by the LLM server.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApiType {
    /// Ollama's native API (the default).
    #[default]
    Ollama,
    /// OpenAI-compatible API (LM Studio, vLLM, text-generation-webui, etc.).
    OpenAI,
}
118
119impl LlmParser {
120    /// Create a new LLM parser with Ollama backend
121    ///
122    /// # Arguments
123    /// * `endpoint` - Base URL (e.g., "http://localhost:11434" for Ollama)
124    /// * `model` - Model name (e.g., "qwen2.5:1.5b", "llama3.2:1b")
125    pub fn new(endpoint: &str, model: &str) -> Self {
126        Self::with_api_type(endpoint, model, ApiType::Ollama)
127    }
128
129    /// Create a new LLM parser with specified API type
130    pub fn with_api_type(endpoint: &str, model: &str, _api_type: ApiType) -> Self {
131        let client = reqwest::blocking::Client::builder()
132            .timeout(Duration::from_secs(30))
133            .build()
134            .expect("Failed to create HTTP client");
135
136        Self {
137            client,
138            endpoint: endpoint.trim_end_matches('/').to_string(),
139            model: model.to_string(),
140            generation_lock: Mutex::new(()),
141        }
142    }
143
144    /// Build the prompt for query parsing
145    fn build_prompt(&self, query: &str, context_date: Option<DateTime<Utc>>) -> String {
146        let date_context = context_date
147            .map(|d| format!("Today's date: {}", d.format("%B %d, %Y")))
148            .unwrap_or_else(|| "Today's date: unknown".to_string());
149
150        format!(
151            r#"You are a query parser. Extract structured information from the query.
152Output ONLY valid JSON, no explanation or markdown.
153
154{date_context}
155
156Parse this query: "{query}"
157
158Output this exact JSON structure:
159{{"entities":[{{"text":"name","type":"person|place|thing|event|time","negated":false}}],"events":["verb"],"modifiers":["adjective"],"temporal":{{"has_temporal_intent":true,"intent":"when_question|specific_time|ordering|duration|none","relative_refs":[{{"text":"last year","resolved_date":"2024-01-01","direction":"past"}}],"resolved_dates":["2024-01-01"]}},"is_attribute_query":false,"attribute_entity":null,"attribute_name":null,"confidence":0.9}}"#
160        )
161    }
162
163    /// Generate using Ollama API
164    fn generate_ollama(&self, prompt: &str) -> Result<String, String> {
165        let request = OllamaRequest {
166            model: self.model.clone(),
167            prompt: prompt.to_string(),
168            stream: false,
169            options: OllamaOptions {
170                temperature: 0.1,
171                num_predict: 512,
172            },
173        };
174
175        let url = format!("{}/api/generate", self.endpoint);
176
177        let response = self
178            .client
179            .post(&url)
180            .json(&request)
181            .send()
182            .map_err(|e| format!("HTTP request failed: {}", e))?;
183
184        if !response.status().is_success() {
185            return Err(format!("API returned status: {}", response.status()));
186        }
187
188        let ollama_response: OllamaResponse = response
189            .json()
190            .map_err(|e| format!("Failed to parse response: {}", e))?;
191
192        Ok(ollama_response.response)
193    }
194
195    /// Generate using OpenAI-compatible API
196    fn generate_openai(&self, prompt: &str) -> Result<String, String> {
197        let request = OpenAIRequest {
198            model: self.model.clone(),
199            messages: vec![OpenAIMessage {
200                role: "user".to_string(),
201                content: prompt.to_string(),
202            }],
203            temperature: 0.1,
204            max_tokens: 512,
205        };
206
207        let url = format!("{}/v1/chat/completions", self.endpoint);
208
209        let response = self
210            .client
211            .post(&url)
212            .json(&request)
213            .send()
214            .map_err(|e| format!("HTTP request failed: {}", e))?;
215
216        if !response.status().is_success() {
217            return Err(format!("API returned status: {}", response.status()));
218        }
219
220        let openai_response: OpenAIResponse = response
221            .json()
222            .map_err(|e| format!("Failed to parse response: {}", e))?;
223
224        openai_response
225            .choices
226            .first()
227            .map(|c| c.message.content.clone())
228            .ok_or_else(|| "No response from API".to_string())
229    }
230
231    /// Try Ollama first, fall back to OpenAI-compatible API
232    fn generate(&self, prompt: &str) -> Result<String, String> {
233        // Try Ollama first
234        if let Ok(response) = self.generate_ollama(prompt) {
235            return Ok(response);
236        }
237
238        // Fall back to OpenAI-compatible API
239        self.generate_openai(prompt)
240    }
241
242    /// Parse the LLM output JSON into ParsedQuery
243    fn parse_output(&self, output: &str, original_query: &str) -> ParsedQuery {
244        let json_str = extract_json(output);
245
246        match serde_json::from_str::<LlmOutput>(&json_str) {
247            Ok(llm_out) => self.convert_llm_output(llm_out, original_query),
248            Err(e) => {
249                tracing::warn!("Failed to parse LLM output: {}, raw: {}", e, output);
250                ParsedQuery::empty(original_query)
251            }
252        }
253    }
254
255    /// Convert LLM output to ParsedQuery
256    fn convert_llm_output(&self, llm_out: LlmOutput, original_query: &str) -> ParsedQuery {
257        let entities: Vec<Entity> = llm_out
258            .entities
259            .into_iter()
260            .map(|e| Entity {
261                text: e.text.clone(),
262                stem: stem_word(&e.text),
263                entity_type: parse_entity_type(&e.entity_type),
264                ic_weight: 1.0,
265                negated: e.negated,
266            })
267            .collect();
268
269        let events: Vec<Event> = llm_out
270            .events
271            .into_iter()
272            .map(|e| Event {
273                text: e.clone(),
274                stem: stem_word(&e),
275                ic_weight: 0.7,
276            })
277            .collect();
278
279        let relative_refs: Vec<RelativeTimeRef> = llm_out
280            .temporal
281            .relative_refs
282            .into_iter()
283            .map(|r| RelativeTimeRef {
284                text: r.text,
285                resolved: r
286                    .resolved_date
287                    .and_then(|d| NaiveDate::parse_from_str(&d, "%Y-%m-%d").ok()),
288                direction: parse_direction(&r.direction),
289                unit: TimeUnit::Unknown,
290                offset: 1,
291            })
292            .collect();
293
294        let resolved_dates: Vec<NaiveDate> = llm_out
295            .temporal
296            .resolved_dates
297            .iter()
298            .filter_map(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
299            .collect();
300
301        let attribute = if llm_out.is_attribute_query {
302            llm_out.attribute_entity.map(|entity| AttributeQuery {
303                entity,
304                attribute: llm_out.attribute_name.unwrap_or_default(),
305                synonyms: Vec::new(),
306            })
307        } else {
308            None
309        };
310
311        let mut ic_weights = HashMap::new();
312        for e in &entities {
313            ic_weights.insert(e.text.to_lowercase(), e.ic_weight);
314        }
315        for e in &events {
316            ic_weights.insert(e.text.to_lowercase(), e.ic_weight);
317        }
318
319        ParsedQuery {
320            original: original_query.to_string(),
321            entities,
322            events,
323            modifiers: llm_out.modifiers,
324            temporal: TemporalInfo {
325                has_temporal_intent: llm_out.temporal.has_temporal_intent,
326                intent: parse_temporal_intent(&llm_out.temporal.intent),
327                relative_refs,
328                resolved_dates,
329                absolute_dates: Vec::new(),
330            },
331            is_attribute_query: llm_out.is_attribute_query,
332            attribute,
333            compounds: Vec::new(),
334            ic_weights,
335            confidence: llm_out.confidence,
336        }
337    }
338
339    /// Check if the LLM server is reachable
340    pub fn is_server_available(&self) -> bool {
341        // Try Ollama health check
342        if self
343            .client
344            .get(format!("{}/api/tags", self.endpoint))
345            .send()
346            .map(|r| r.status().is_success())
347            .unwrap_or(false)
348        {
349            return true;
350        }
351
352        // Try OpenAI-compatible models endpoint
353        self.client
354            .get(format!("{}/v1/models", self.endpoint))
355            .send()
356            .map(|r| r.status().is_success())
357            .unwrap_or(false)
358    }
359}
360
361impl QueryParser for LlmParser {
362    fn parse(&self, query: &str, context_date: Option<DateTime<Utc>>) -> ParsedQuery {
363        let _lock = self.generation_lock.lock();
364
365        let prompt = self.build_prompt(query, context_date);
366
367        match self.generate(&prompt) {
368            Ok(output) => self.parse_output(&output, query),
369            Err(e) => {
370                tracing::error!("LLM generation failed: {}", e);
371                ParsedQuery::empty(query)
372            }
373        }
374    }
375
376    fn name(&self) -> &'static str {
377        "LlmParser"
378    }
379
380    fn is_available(&self) -> bool {
381        self.is_server_available()
382    }
383}
384
/// Extract the first top-level JSON object from potentially messy LLM output.
///
/// Strips markdown code fences, then scans for the first balanced
/// `{`...`}` span. If no opening brace exists, the cleaned text is
/// returned as-is. If the braces never balance (truncated output),
/// everything from the first `{` onward is returned so the caller's
/// JSON parser reports the real error (previously this case yielded
/// an empty string).
///
/// NOTE(review): braces inside JSON string values are still counted;
/// a value like `"}"` would end the scan early.
fn extract_json(output: &str) -> String {
    // Remove markdown code blocks if present.
    let cleaned = output
        .trim()
        .trim_start_matches("```json")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim();

    // Find the first { and its matching }.
    if let Some(start) = cleaned.find('{') {
        let mut depth: i32 = 0;
        // char_indices yields BYTE offsets, which are required for
        // slicing; the previous chars().enumerate() index mis-sliced
        // (dropping the closing brace) or panicked whenever the JSON
        // contained multi-byte UTF-8 characters.
        for (i, c) in cleaned[start..].char_indices() {
            match c {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    if depth == 0 {
                        return cleaned[start..start + i + c.len_utf8()].to_string();
                    }
                }
                _ => {}
            }
        }
        // Unbalanced braces: return from the first '{' to the end.
        cleaned[start..].to_string()
    } else {
        cleaned.to_string()
    }
}
417
418/// Simple stemming using rust_stemmers
419fn stem_word(word: &str) -> String {
420    use rust_stemmers::{Algorithm, Stemmer};
421    let stemmer = Stemmer::create(Algorithm::English);
422    stemmer.stem(&word.to_lowercase()).to_string()
423}
424
425/// Parse entity type string
426fn parse_entity_type(s: &str) -> EntityType {
427    match s.to_lowercase().as_str() {
428        "person" => EntityType::Person,
429        "place" => EntityType::Place,
430        "thing" => EntityType::Thing,
431        "event" => EntityType::Event,
432        "time" => EntityType::Time,
433        _ => EntityType::Unknown,
434    }
435}
436
437/// Parse direction string
438fn parse_direction(s: &str) -> TimeDirection {
439    match s.to_lowercase().as_str() {
440        "past" => TimeDirection::Past,
441        "future" => TimeDirection::Future,
442        "current" => TimeDirection::Current,
443        _ => TimeDirection::Past,
444    }
445}
446
447/// Parse temporal intent string
448fn parse_temporal_intent(s: &str) -> TemporalIntent {
449    match s.to_lowercase().as_str() {
450        "when_question" => TemporalIntent::WhenQuestion,
451        "specific_time" => TemporalIntent::SpecificTime,
452        "ordering" => TemporalIntent::Ordering,
453        "duration" => TemporalIntent::Duration,
454        _ => TemporalIntent::None,
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object embedded in surrounding prose is extracted whole.
    #[test]
    fn test_extract_json() {
        let raw = r#"Here is the JSON: {"entities": [], "confidence": 0.9} and some more text"#;
        let extracted = extract_json(raw);
        assert!(extracted.starts_with('{'));
        assert!(extracted.ends_with('}'));
    }

    /// Markdown code fences around the JSON are stripped.
    #[test]
    fn test_extract_json_with_markdown() {
        let raw = "```json\n{\"entities\": [], \"confidence\": 0.9}\n```";
        assert_eq!(extract_json(raw), r#"{"entities": [], "confidence": 0.9}"#);
    }

    /// Entity-type parsing is case-insensitive and falls back to Unknown.
    #[test]
    fn test_parse_entity_type() {
        assert_eq!(parse_entity_type("person"), EntityType::Person);
        assert_eq!(parse_entity_type("PLACE"), EntityType::Place);
        assert_eq!(parse_entity_type("unknown_type"), EntityType::Unknown);
    }
}
485}