Skip to main content

shodh_memory/query_parsing/
rule_based.rs

//! Rule-Based Query Parser
//!
//! Wraps the existing YAKE/regex-based `query_parser` module behind the `QueryParser` trait.
//! This is the default, battle-tested implementation.
5
6use super::parser_trait::*;
7use crate::memory::query_parser as legacy;
8use chrono::{DateTime, Datelike, NaiveDate, Utc};
9
/// Query parser driven by YAKE keyword extraction, regex patterns, and heuristics.
///
/// The type carries no state. The private unit field prevents construction via a
/// struct literal outside this module; use [`RuleBasedParser::new`] or `Default`.
pub struct RuleBasedParser {
    // Zero-sized marker; holds no data.
    _private: (),
}
14
15impl RuleBasedParser {
16    /// Create a new rule-based parser
17    pub fn new() -> Self {
18        Self { _private: () }
19    }
20}
21
22impl Default for RuleBasedParser {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl QueryParser for RuleBasedParser {
29    fn parse(&self, query: &str, context_date: Option<DateTime<Utc>>) -> ParsedQuery {
30        // Use the existing analyze_query function
31        let analysis = legacy::analyze_query(query);
32
33        // Detect temporal intent
34        let temporal_intent = legacy::detect_temporal_intent(query);
35        let has_temporal = !matches!(temporal_intent, legacy::TemporalIntent::None);
36
37        // Extract temporal references
38        let temporal_refs = legacy::extract_temporal_refs(query);
39        let relative_refs = extract_relative_refs(query, context_date);
40
41        // Resolve relative dates if context provided
42        let resolved_dates = if context_date.is_some() {
43            relative_refs.iter().filter_map(|r| r.resolved).collect()
44        } else {
45            Vec::new()
46        };
47
48        // Extract absolute dates from temporal refs
49        let absolute_dates: Vec<NaiveDate> = temporal_refs.refs.iter().map(|r| r.date).collect();
50
51        // Convert focal entities
52        let entities: Vec<Entity> = analysis
53            .focal_entities
54            .iter()
55            .map(|e| Entity {
56                text: e.text.clone(),
57                stem: e.stem.clone(),
58                entity_type: detect_entity_type(&e.text),
59                ic_weight: e.ic_weight,
60                negated: e.negated,
61            })
62            .collect();
63
64        // Convert relational context to events
65        let events: Vec<Event> = analysis
66            .relational_context
67            .iter()
68            .map(|r| Event {
69                text: r.text.clone(),
70                stem: r.stem.clone(),
71                ic_weight: r.ic_weight,
72            })
73            .collect();
74
75        // Get modifiers
76        let modifiers: Vec<String> = analysis
77            .discriminative_modifiers
78            .iter()
79            .map(|m| m.text.clone())
80            .collect();
81
82        // Check for attribute query
83        let (is_attribute_query, attribute) = match legacy::detect_attribute_query(query) {
84            Some(aq) => (
85                true,
86                Some(AttributeQuery {
87                    entity: aq.entity.clone(),
88                    attribute: aq.attribute.clone(),
89                    synonyms: aq.attribute_synonyms.clone(),
90                }),
91            ),
92            None => (false, None),
93        };
94
95        // Get IC weights
96        let ic_weights = analysis.to_ic_weights();
97
98        ParsedQuery {
99            original: query.to_string(),
100            entities,
101            events,
102            modifiers,
103            temporal: TemporalInfo {
104                has_temporal_intent: has_temporal,
105                intent: convert_temporal_intent(temporal_intent),
106                relative_refs,
107                resolved_dates,
108                absolute_dates,
109            },
110            is_attribute_query,
111            attribute,
112            compounds: analysis.compound_nouns.clone(),
113            ic_weights,
114            confidence: 0.85, // Rule-based has consistent but not perfect accuracy
115        }
116    }
117
118    fn name(&self) -> &'static str {
119        "RuleBasedParser"
120    }
121}
122
123/// Convert legacy TemporalIntent to new format
124fn convert_temporal_intent(intent: legacy::TemporalIntent) -> TemporalIntent {
125    match intent {
126        legacy::TemporalIntent::WhenQuestion => TemporalIntent::WhenQuestion,
127        legacy::TemporalIntent::SpecificTime => TemporalIntent::SpecificTime,
128        legacy::TemporalIntent::Ordering => TemporalIntent::Ordering,
129        legacy::TemporalIntent::Duration => TemporalIntent::Duration,
130        legacy::TemporalIntent::None => TemporalIntent::None,
131    }
132}
133
134/// Detect entity type from text (basic heuristics)
135fn detect_entity_type(text: &str) -> EntityType {
136    let text_lower = text.to_lowercase();
137
138    // Check if it starts with capital (likely proper noun / person)
139    if text
140        .chars()
141        .next()
142        .map(|c| c.is_uppercase())
143        .unwrap_or(false)
144        && !text.chars().all(|c| c.is_uppercase())
145    {
146        // Common person name patterns
147        let first_word = text.split_whitespace().next().unwrap_or("");
148        if is_likely_person_name(first_word) {
149            return EntityType::Person;
150        }
151    }
152
153    // Time-related words
154    if [
155        "morning",
156        "evening",
157        "afternoon",
158        "night",
159        "day",
160        "week",
161        "month",
162        "year",
163    ]
164    .iter()
165    .any(|t| text_lower.contains(t))
166    {
167        return EntityType::Time;
168    }
169
170    // Event-related words
171    if [
172        "meeting", "party", "wedding", "concert", "race", "trip", "vacation",
173    ]
174    .iter()
175    .any(|e| text_lower.contains(e))
176    {
177        return EntityType::Event;
178    }
179
180    EntityType::Unknown
181}
182
/// Check whether a single word is likely to be a person's name.
///
/// Returns `true` when all of the following hold:
/// - the word is between 2 and 20 characters long (counted in Unicode scalar values,
///   not UTF-8 bytes, so non-ASCII names are not rejected for their encoded size);
/// - its first character is uppercase;
/// - it is not one of a fixed list of common capitalised non-names (articles,
///   question words, weekday names, month names), compared ASCII case-insensitively.
fn is_likely_person_name(word: &str) -> bool {
    // Capitalised words that commonly start an entity but are not names.
    const NON_NAMES: &[&str] = &[
        "The",
        "This",
        "That",
        "What",
        "When",
        "Where",
        "Who",
        "How",
        "Why",
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ];

    // Length is measured in characters; `str::len` would count bytes.
    let char_count = word.chars().count();
    if !(2..=20).contains(&char_count) {
        return false;
    }

    let starts_upper = word.chars().next().map_or(false, char::is_uppercase);
    if !starts_upper {
        return false;
    }

    !NON_NAMES.iter().any(|n| n.eq_ignore_ascii_case(word))
}
229
230/// Extract relative time references from query
231fn extract_relative_refs(query: &str, context_date: Option<DateTime<Utc>>) -> Vec<RelativeTimeRef> {
232    let query_lower = query.to_lowercase();
233    let mut refs = Vec::new();
234
235    // Patterns for relative time references
236    let patterns = [
237        ("last year", TimeDirection::Past, TimeUnit::Year, 1),
238        ("last month", TimeDirection::Past, TimeUnit::Month, 1),
239        ("last week", TimeDirection::Past, TimeUnit::Week, 1),
240        ("last saturday", TimeDirection::Past, TimeUnit::Day, -1), // Special handling
241        ("last sunday", TimeDirection::Past, TimeUnit::Day, -1),
242        ("last friday", TimeDirection::Past, TimeUnit::Day, -1),
243        ("yesterday", TimeDirection::Past, TimeUnit::Day, 1),
244        ("next year", TimeDirection::Future, TimeUnit::Year, 1),
245        ("next month", TimeDirection::Future, TimeUnit::Month, 1),
246        ("next week", TimeDirection::Future, TimeUnit::Week, 1),
247        ("tomorrow", TimeDirection::Future, TimeUnit::Day, 1),
248        ("two weeks ago", TimeDirection::Past, TimeUnit::Week, 2),
249        ("three weeks ago", TimeDirection::Past, TimeUnit::Week, 3),
250        ("two months ago", TimeDirection::Past, TimeUnit::Month, 2),
251        ("a week ago", TimeDirection::Past, TimeUnit::Week, 1),
252        ("a month ago", TimeDirection::Past, TimeUnit::Month, 1),
253        ("a year ago", TimeDirection::Past, TimeUnit::Year, 1),
254    ];
255
256    for (pattern, direction, unit, offset) in patterns {
257        if query_lower.contains(pattern) {
258            let resolved = context_date
259                .and_then(|ctx| resolve_relative_date(ctx, direction, unit, offset, pattern));
260
261            refs.push(RelativeTimeRef {
262                text: pattern.to_string(),
263                resolved,
264                direction,
265                unit,
266                offset,
267            });
268        }
269    }
270
271    refs
272}
273
274/// Resolve a relative date reference to an absolute date
275fn resolve_relative_date(
276    context: DateTime<Utc>,
277    direction: TimeDirection,
278    unit: TimeUnit,
279    offset: i32,
280    pattern: &str,
281) -> Option<NaiveDate> {
282    use chrono::Duration;
283
284    let base_date = context.date_naive();
285
286    // Handle "last <weekday>" specially
287    if pattern.starts_with("last ") && pattern.len() > 5 {
288        let weekday_str = &pattern[5..];
289        if let Some(target_weekday) = parse_weekday(weekday_str) {
290            // Find the most recent occurrence of this weekday before context date
291            let current_weekday = base_date.weekday();
292            let days_back = (current_weekday.num_days_from_monday() as i32
293                - target_weekday.num_days_from_monday() as i32
294                + 7)
295                % 7;
296            let days_back = if days_back == 0 { 7 } else { days_back };
297            return Some(base_date - Duration::days(days_back as i64));
298        }
299    }
300
301    let result = match (direction, unit) {
302        (TimeDirection::Past, TimeUnit::Day) => base_date - Duration::days(offset as i64),
303        (TimeDirection::Past, TimeUnit::Week) => base_date - Duration::weeks(offset as i64),
304        (TimeDirection::Past, TimeUnit::Month) => {
305            // Approximate month subtraction
306            let months_back = offset as i64;
307            let new_month = (base_date.month() as i64 - months_back - 1).rem_euclid(12) + 1;
308            let year_offset = (base_date.month() as i64 - months_back - 1).div_euclid(12);
309            NaiveDate::from_ymd_opt(
310                base_date.year() + year_offset as i32,
311                new_month as u32,
312                base_date.day().min(28),
313            )?
314        }
315        (TimeDirection::Past, TimeUnit::Year) => NaiveDate::from_ymd_opt(
316            base_date.year() - offset,
317            base_date.month(),
318            base_date.day(),
319        )?,
320        (TimeDirection::Future, TimeUnit::Day) => base_date + Duration::days(offset as i64),
321        (TimeDirection::Future, TimeUnit::Week) => base_date + Duration::weeks(offset as i64),
322        (TimeDirection::Future, TimeUnit::Month) => {
323            let months_forward = offset as i64;
324            let new_month = (base_date.month() as i64 + months_forward - 1).rem_euclid(12) + 1;
325            let year_offset = (base_date.month() as i64 + months_forward - 1).div_euclid(12);
326            NaiveDate::from_ymd_opt(
327                base_date.year() + year_offset as i32,
328                new_month as u32,
329                base_date.day().min(28),
330            )?
331        }
332        (TimeDirection::Future, TimeUnit::Year) => NaiveDate::from_ymd_opt(
333            base_date.year() + offset,
334            base_date.month(),
335            base_date.day(),
336        )?,
337        _ => return None,
338    };
339
340    Some(result)
341}
342
343/// Parse a weekday string
344fn parse_weekday(s: &str) -> Option<chrono::Weekday> {
345    use chrono::Weekday;
346    match s.to_lowercase().as_str() {
347        "monday" | "mon" => Some(Weekday::Mon),
348        "tuesday" | "tue" | "tues" => Some(Weekday::Tue),
349        "wednesday" | "wed" => Some(Weekday::Wed),
350        "thursday" | "thu" | "thur" | "thurs" => Some(Weekday::Thu),
351        "friday" | "fri" => Some(Weekday::Fri),
352        "saturday" | "sat" => Some(Weekday::Sat),
353        "sunday" | "sun" => Some(Weekday::Sun),
354        _ => None,
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// A possessive "what is X's Y" question is flagged as an attribute query and
    /// yields at least one entity.
    #[test]
    fn test_parse_basic_query() {
        let result = RuleBasedParser::new().parse("What is Caroline's relationship status?", None);

        assert!(result.is_attribute_query);
        assert!(!result.entities.is_empty());
    }

    /// A "when did ..." question carries temporal intent of kind `WhenQuestion`.
    #[test]
    fn test_parse_temporal_query() {
        let result = RuleBasedParser::new().parse("When did Melanie paint a sunrise?", None);

        assert!(result.temporal.has_temporal_intent);
        assert_eq!(result.temporal.intent, TemporalIntent::WhenQuestion);
    }

    /// "last year" is resolved against the supplied context date.
    #[test]
    fn test_resolve_last_year() {
        let now = chrono::Utc.with_ymd_and_hms(2023, 5, 8, 12, 0, 0).unwrap();
        let result = RuleBasedParser::new().parse("Melanie painted it last year", Some(now));

        let relative = &result.temporal.relative_refs;
        assert!(!relative.is_empty());

        let first = &relative[0];
        let expected = NaiveDate::from_ymd_opt(2022, 5, 8).unwrap();
        assert_eq!(first.text, "last year");
        assert_eq!(first.resolved, Some(expected));
    }
}