1use super::parser_trait::*;
7use chrono::{DateTime, NaiveDate, Utc};
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12
/// Query parser that delegates parsing to an LLM served over HTTP.
///
/// Speaks both Ollama's `/api/generate` and the OpenAI-compatible
/// `/v1/chat/completions` protocols (see `generate`, which tries Ollama
/// first and falls back to OpenAI).
pub struct LlmParser {
    // Blocking HTTP client; built with a 30s per-request timeout in `with_api_type`.
    client: reqwest::blocking::Client,
    // Base URL of the LLM server, stored without a trailing slash.
    endpoint: String,
    // Model identifier sent with every request.
    model: String,
    // Serializes generations so at most one request is in flight per parser.
    generation_lock: Mutex<()>,
}
20
/// Request body for Ollama's `POST /api/generate`.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    // Always `false` here: the full response is read in one shot.
    stream: bool,
    options: OllamaOptions,
}

/// Generation options nested inside an [`OllamaRequest`].
#[derive(Debug, Serialize)]
struct OllamaOptions {
    temperature: f32,
    // Maximum number of tokens to generate (Ollama's `num_predict`).
    num_predict: i32,
}

/// Minimal view of Ollama's generate response; only the generated text is kept.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    response: String,
}
41
/// Request body for an OpenAI-compatible `POST /v1/chat/completions`.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f32,
    max_tokens: i32,
}

/// A single chat message (here always role "user") in an [`OpenAIRequest`].
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

/// Minimal view of the chat-completions response; only choice text is read.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}
72
/// JSON schema the LLM is instructed (in `build_prompt`) to emit.
///
/// Deserialized from the model's raw output in `parse_output` and converted
/// into a `ParsedQuery` by `convert_llm_output`.
#[derive(Debug, Deserialize, Serialize)]
struct LlmOutput {
    entities: Vec<LlmEntity>,
    events: Vec<String>,
    modifiers: Vec<String>,
    temporal: LlmTemporal,
    is_attribute_query: bool,
    // Only meaningful when `is_attribute_query` is true.
    attribute_entity: Option<String>,
    attribute_name: Option<String>,
    confidence: f32,
}

/// An entity as reported by the LLM.
#[derive(Debug, Deserialize, Serialize)]
struct LlmEntity {
    text: String,
    // `type` is a Rust keyword, hence the serde rename.
    #[serde(rename = "type")]
    entity_type: String,
    negated: bool,
}

/// Temporal information block of [`LlmOutput`].
#[derive(Debug, Deserialize, Serialize)]
struct LlmTemporal {
    has_temporal_intent: bool,
    // Free-text intent label; mapped to `TemporalIntent` by `parse_temporal_intent`.
    intent: String,
    relative_refs: Vec<LlmRelativeRef>,
    // Dates as "%Y-%m-%d" strings; unparsable entries are dropped during conversion.
    resolved_dates: Vec<String>,
}

/// A relative time expression ("last year") the LLM resolved.
#[derive(Debug, Deserialize, Serialize)]
struct LlmRelativeRef {
    text: String,
    // "%Y-%m-%d" if the model resolved it, otherwise absent.
    resolved_date: Option<String>,
    // Free-text direction label; mapped via `parse_direction`.
    direction: String,
}
108
/// Which wire protocol the parser should speak.
///
/// NOTE(review): `with_api_type` currently ignores this value — `generate`
/// always tries the Ollama endpoint first and falls back to the OpenAI one.
/// Confirm whether the selector should be stored and honored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApiType {
    #[default]
    Ollama,
    OpenAI,
}
118
impl LlmParser {
    /// Create a parser for an Ollama server at `endpoint` using `model`.
    pub fn new(endpoint: &str, model: &str) -> Self {
        Self::with_api_type(endpoint, model, ApiType::Ollama)
    }

    /// Create a parser with an explicit API flavor.
    ///
    /// NOTE(review): `_api_type` is currently unused — the struct stores no
    /// flavor and `generate` always probes Ollama first. Confirm intent.
    ///
    /// # Panics
    /// Panics if the blocking HTTP client cannot be constructed.
    pub fn with_api_type(endpoint: &str, model: &str, _api_type: ApiType) -> Self {
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            // Normalize so the `format!("{}/...")` joins below never double a '/'.
            endpoint: endpoint.trim_end_matches('/').to_string(),
            model: model.to_string(),
            generation_lock: Mutex::new(()),
        }
    }

    /// Build the instruction prompt: today's date (if known), the query, and
    /// the exact JSON schema the model must emit (see `LlmOutput`).
    fn build_prompt(&self, query: &str, context_date: Option<DateTime<Utc>>) -> String {
        let date_context = context_date
            .map(|d| format!("Today's date: {}", d.format("%B %d, %Y")))
            .unwrap_or_else(|| "Today's date: unknown".to_string());

        format!(
            r#"You are a query parser. Extract structured information from the query.
Output ONLY valid JSON, no explanation or markdown.

{date_context}

Parse this query: "{query}"

Output this exact JSON structure:
{{"entities":[{{"text":"name","type":"person|place|thing|event|time","negated":false}}],"events":["verb"],"modifiers":["adjective"],"temporal":{{"has_temporal_intent":true,"intent":"when_question|specific_time|ordering|duration|none","relative_refs":[{{"text":"last year","resolved_date":"2024-01-01","direction":"past"}}],"resolved_dates":["2024-01-01"]}},"is_attribute_query":false,"attribute_entity":null,"attribute_name":null,"confidence":0.9}}"#
        )
    }

    /// Send `prompt` to Ollama's `POST {endpoint}/api/generate` and return the
    /// raw generated text, or a human-readable error string.
    fn generate_ollama(&self, prompt: &str) -> Result<String, String> {
        let request = OllamaRequest {
            model: self.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
            options: OllamaOptions {
                // Low temperature: we want deterministic, schema-shaped output.
                temperature: 0.1,
                num_predict: 512,
            },
        };

        let url = format!("{}/api/generate", self.endpoint);

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        if !response.status().is_success() {
            return Err(format!("API returned status: {}", response.status()));
        }

        let ollama_response: OllamaResponse = response
            .json()
            .map_err(|e| format!("Failed to parse response: {}", e))?;

        Ok(ollama_response.response)
    }

    /// Send `prompt` to an OpenAI-compatible `POST {endpoint}/v1/chat/completions`
    /// and return the first choice's content, or a human-readable error string.
    fn generate_openai(&self, prompt: &str) -> Result<String, String> {
        let request = OpenAIRequest {
            model: self.model.clone(),
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
            temperature: 0.1,
            max_tokens: 512,
        };

        let url = format!("{}/v1/chat/completions", self.endpoint);

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        if !response.status().is_success() {
            return Err(format!("API returned status: {}", response.status()));
        }

        let openai_response: OpenAIResponse = response
            .json()
            .map_err(|e| format!("Failed to parse response: {}", e))?;

        openai_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| "No response from API".to_string())
    }

    /// Generate text for `prompt`: try the Ollama protocol first and, on any
    /// error, fall back to the OpenAI protocol (only the OpenAI error is
    /// surfaced — the Ollama failure is silently discarded).
    fn generate(&self, prompt: &str) -> Result<String, String> {
        if let Ok(response) = self.generate_ollama(prompt) {
            return Ok(response);
        }

        self.generate_openai(prompt)
    }

    /// Deserialize the model's raw output into a `ParsedQuery`.
    ///
    /// Extracts the first balanced JSON object (see `extract_json`), then
    /// falls back to `ParsedQuery::empty` (with a warning) if it doesn't match
    /// the `LlmOutput` schema.
    fn parse_output(&self, output: &str, original_query: &str) -> ParsedQuery {
        let json_str = extract_json(output);

        match serde_json::from_str::<LlmOutput>(&json_str) {
            Ok(llm_out) => self.convert_llm_output(llm_out, original_query),
            Err(e) => {
                tracing::warn!("Failed to parse LLM output: {}, raw: {}", e, output);
                ParsedQuery::empty(original_query)
            }
        }
    }

    /// Convert the LLM's schema (`LlmOutput`) into the parser-trait types.
    fn convert_llm_output(&self, llm_out: LlmOutput, original_query: &str) -> ParsedQuery {
        // Entities carry a fixed IC weight of 1.0.
        let entities: Vec<Entity> = llm_out
            .entities
            .into_iter()
            .map(|e| Entity {
                text: e.text.clone(),
                stem: stem_word(&e.text),
                entity_type: parse_entity_type(&e.entity_type),
                ic_weight: 1.0,
                negated: e.negated,
            })
            .collect();

        // Events are weighted below entities (0.7).
        let events: Vec<Event> = llm_out
            .events
            .into_iter()
            .map(|e| Event {
                text: e.clone(),
                stem: stem_word(&e),
                ic_weight: 0.7,
            })
            .collect();

        // The LLM already resolved the dates, so unit/offset are placeholders
        // (TimeUnit::Unknown, offset 1). Unparsable dates become `None`.
        let relative_refs: Vec<RelativeTimeRef> = llm_out
            .temporal
            .relative_refs
            .into_iter()
            .map(|r| RelativeTimeRef {
                text: r.text,
                resolved: r
                    .resolved_date
                    .and_then(|d| NaiveDate::parse_from_str(&d, "%Y-%m-%d").ok()),
                direction: parse_direction(&r.direction),
                unit: TimeUnit::Unknown,
                offset: 1,
            })
            .collect();

        // Silently drop dates that don't parse as %Y-%m-%d.
        let resolved_dates: Vec<NaiveDate> = llm_out
            .temporal
            .resolved_dates
            .iter()
            .filter_map(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
            .collect();

        // An attribute query needs both the flag and an entity; a missing
        // attribute name degrades to "".
        let attribute = if llm_out.is_attribute_query {
            llm_out.attribute_entity.map(|entity| AttributeQuery {
                entity,
                attribute: llm_out.attribute_name.unwrap_or_default(),
                synonyms: Vec::new(),
            })
        } else {
            None
        };

        // IC weights are keyed by lowercased surface text; events can
        // overwrite entities on a key collision.
        let mut ic_weights = HashMap::new();
        for e in &entities {
            ic_weights.insert(e.text.to_lowercase(), e.ic_weight);
        }
        for e in &events {
            ic_weights.insert(e.text.to_lowercase(), e.ic_weight);
        }

        ParsedQuery {
            original: original_query.to_string(),
            entities,
            events,
            modifiers: llm_out.modifiers,
            temporal: TemporalInfo {
                has_temporal_intent: llm_out.temporal.has_temporal_intent,
                intent: parse_temporal_intent(&llm_out.temporal.intent),
                relative_refs,
                resolved_dates,
                // The LLM schema has no absolute-dates field.
                absolute_dates: Vec::new(),
            },
            is_attribute_query: llm_out.is_attribute_query,
            attribute,
            compounds: Vec::new(),
            ic_weights,
            confidence: llm_out.confidence,
        }
    }

    /// Probe the server: first Ollama's `/api/tags`, then the OpenAI-style
    /// `/v1/models`. True if either answers with a success status.
    /// Note: each failed probe may take up to the client's 30s timeout.
    pub fn is_server_available(&self) -> bool {
        if self
            .client
            .get(format!("{}/api/tags", self.endpoint))
            .send()
            .map(|r| r.status().is_success())
            .unwrap_or(false)
        {
            return true;
        }

        self.client
            .get(format!("{}/v1/models", self.endpoint))
            .send()
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }
}
360
361impl QueryParser for LlmParser {
362 fn parse(&self, query: &str, context_date: Option<DateTime<Utc>>) -> ParsedQuery {
363 let _lock = self.generation_lock.lock();
364
365 let prompt = self.build_prompt(query, context_date);
366
367 match self.generate(&prompt) {
368 Ok(output) => self.parse_output(&output, query),
369 Err(e) => {
370 tracing::error!("LLM generation failed: {}", e);
371 ParsedQuery::empty(query)
372 }
373 }
374 }
375
376 fn name(&self) -> &'static str {
377 "LlmParser"
378 }
379
380 fn is_available(&self) -> bool {
381 self.is_server_available()
382 }
383}
384
/// Extract the first top-level JSON object from raw LLM output.
///
/// Strips optional markdown code fences (```json … ```), then scans for a
/// brace-balanced object starting at the first `{`.
///
/// Fixes over the previous version:
/// - uses byte offsets (`char_indices`) instead of a character count, so the
///   slice is correct (and cannot panic) when multi-byte UTF-8 appears before
///   the closing brace;
/// - ignores braces inside JSON string literals (including escaped quotes),
///   so payloads like `{"a": "}"}` are sliced correctly;
/// - if the braces never balance, returns the cleaned text as-is (instead of
///   an empty string) so serde can report a meaningful parse error.
fn extract_json(output: &str) -> String {
    let cleaned = output
        .trim()
        .trim_start_matches("```json")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim();

    let Some(start) = cleaned.find('{') else {
        // No object at all: hand the cleaned text to the caller unchanged.
        return cleaned.to_string();
    };

    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (i, c) in cleaned[start..].char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    // `i` is a byte offset and '}' is one byte wide.
                    return cleaned[start..start + i + 1].to_string();
                }
            }
            _ => {}
        }
    }

    // Unbalanced braces: return the cleaned text and let serde reject it.
    cleaned.to_string()
}
417
418fn stem_word(word: &str) -> String {
420 use rust_stemmers::{Algorithm, Stemmer};
421 let stemmer = Stemmer::create(Algorithm::English);
422 stemmer.stem(&word.to_lowercase()).to_string()
423}
424
425fn parse_entity_type(s: &str) -> EntityType {
427 match s.to_lowercase().as_str() {
428 "person" => EntityType::Person,
429 "place" => EntityType::Place,
430 "thing" => EntityType::Thing,
431 "event" => EntityType::Event,
432 "time" => EntityType::Time,
433 _ => EntityType::Unknown,
434 }
435}
436
437fn parse_direction(s: &str) -> TimeDirection {
439 match s.to_lowercase().as_str() {
440 "past" => TimeDirection::Past,
441 "future" => TimeDirection::Future,
442 "current" => TimeDirection::Current,
443 _ => TimeDirection::Past,
444 }
445}
446
447fn parse_temporal_intent(s: &str) -> TemporalIntent {
449 match s.to_lowercase().as_str() {
450 "when_question" => TemporalIntent::WhenQuestion,
451 "specific_time" => TemporalIntent::SpecificTime,
452 "ordering" => TemporalIntent::Ordering,
453 "duration" => TemporalIntent::Duration,
454 _ => TemporalIntent::None,
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    // extract_json should slice the first balanced object out of surrounding prose.
    #[test]
    fn test_extract_json() {
        let output = r#"Here is the JSON: {"entities": [], "confidence": 0.9} and some more text"#;
        let json = extract_json(output);
        assert!(json.starts_with('{'));
        assert!(json.ends_with('}'));
    }

    // Markdown code fences around the object must be stripped cleanly.
    #[test]
    fn test_extract_json_with_markdown() {
        let output = r#"```json
{"entities": [], "confidence": 0.9}
```"#;
        let json = extract_json(output);
        assert_eq!(json, r#"{"entities": [], "confidence": 0.9}"#);
    }

    // Labels match case-insensitively; unknown labels fall back to Unknown.
    #[test]
    fn test_parse_entity_type() {
        assert_eq!(parse_entity_type("person"), EntityType::Person);
        assert_eq!(parse_entity_type("PLACE"), EntityType::Place);
        assert_eq!(parse_entity_type("unknown_type"), EntityType::Unknown);
    }
}