Skip to main content

scirs2_text/
dialog.rs

1//! Dialogue system components
2//!
3//! Provides rule-based building blocks for constructing conversational agents:
4//!
5//! - [`DialogState`] – conversation context, entity map, and slot map.
6//! - [`IntentClassifier`] – pattern-matching intent recognition.
7//! - [`EntityExtractor`] – rule-based extraction of dates, numbers, names, and locations.
8//! - [`SlotFiller`] – template-based slot value extraction.
9//! - [`DialogPolicy`] – simple state-machine dialog management.
10//! - [`DialogAct`] – enum of high-level dialog acts.
11//! - [`response_template`] – generate natural-language responses from acts and slots.
12//!
13//! All components are 100% Pure Rust with no external NLP models.
14
15use crate::error::{Result, TextError};
16use std::collections::HashMap;
17
18// ---------------------------------------------------------------------------
19// DialogAct
20// ---------------------------------------------------------------------------
21
22/// High-level dialog act categories used by the dialog policy and response generator.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub enum DialogAct {
25    /// Opening greeting.
26    Greet,
27    /// Request for information or an action from the user.
28    Request,
29    /// Inform the user of a fact or status.
30    Inform,
31    /// Ask the user to confirm a value or action.
32    Confirm,
33    /// Reject a proposed value or action.
34    Reject,
35    /// Closing farewell.
36    Goodbye,
37    /// System does not understand the utterance.
38    Unknown,
39}
40
41impl std::fmt::Display for DialogAct {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        let label = match self {
44            Self::Greet => "GREET",
45            Self::Request => "REQUEST",
46            Self::Inform => "INFORM",
47            Self::Confirm => "CONFIRM",
48            Self::Reject => "REJECT",
49            Self::Goodbye => "GOODBYE",
50            Self::Unknown => "UNKNOWN",
51        };
52        write!(f, "{label}")
53    }
54}
55
56// ---------------------------------------------------------------------------
57// DialogState
58// ---------------------------------------------------------------------------
59
60/// Complete state of an ongoing dialogue.
61///
62/// Carries the raw utterance history, a map of extracted entities, and a map
63/// of domain-specific slot values filled so far.
64///
65/// # Example
66///
67/// ```rust
68/// use scirs2_text::dialog::{DialogState, DialogAct};
69///
70/// let mut state = DialogState::new();
71/// state.add_utterance("Hello, I want to book a flight to Paris.");
72/// state.set_slot("destination", "Paris");
73/// assert_eq!(state.get_slot("destination"), Some("Paris"));
74/// ```
75#[derive(Debug, Clone, Default)]
76pub struct DialogState {
77    /// Raw utterance history (user turns only).
78    pub context: Vec<String>,
79    /// Named entities extracted so far.  Maps entity type label → value.
80    pub entities: HashMap<String, String>,
81    /// Domain-specific slot values.
82    pub slots: HashMap<String, String>,
83    /// Current dialog act (last recognised).
84    pub current_act: Option<DialogAct>,
85    /// Number of turns completed.
86    pub turn_count: usize,
87}
88
89impl DialogState {
90    /// Create an empty `DialogState`.
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Append a user utterance to the context history.
96    pub fn add_utterance(&mut self, utterance: &str) {
97        self.context.push(utterance.to_string());
98        self.turn_count += 1;
99    }
100
101    /// Set a slot value.
102    pub fn set_slot(&mut self, slot: &str, value: &str) {
103        self.slots.insert(slot.to_string(), value.to_string());
104    }
105
106    /// Get a slot value by name.
107    pub fn get_slot(&self, slot: &str) -> Option<&str> {
108        self.slots.get(slot).map(|s| s.as_str())
109    }
110
111    /// Set an entity value.
112    pub fn set_entity(&mut self, entity_type: &str, value: &str) {
113        self.entities
114            .insert(entity_type.to_string(), value.to_string());
115    }
116
117    /// Get an entity value by type label.
118    pub fn get_entity(&self, entity_type: &str) -> Option<&str> {
119        self.entities.get(entity_type).map(|s| s.as_str())
120    }
121
122    /// Reset all state (slots, entities, context).
123    pub fn reset(&mut self) {
124        *self = Self::default();
125    }
126
127    /// Return the last utterance, or `None` if the context is empty.
128    pub fn last_utterance(&self) -> Option<&str> {
129        self.context.last().map(|s| s.as_str())
130    }
131
132    /// Return `true` if the required `slots` are all filled.
133    pub fn slots_filled(&self, required: &[&str]) -> bool {
134        required.iter().all(|s| self.slots.contains_key(*s))
135    }
136}
137
138// ---------------------------------------------------------------------------
139// IntentClassifier
140// ---------------------------------------------------------------------------
141
142/// Pattern-matching intent classifier.
143///
144/// Each intent is represented by a name and a list of keyword patterns.  An
145/// utterance is matched by counting how many patterns contain at least one
146/// word from the utterance; the intent with the most matches wins.
147///
148/// # Example
149///
150/// ```rust
151/// use scirs2_text::dialog::{IntentClassifier, classify_intent};
152///
153/// let mut clf = IntentClassifier::new();
154/// clf.add_intent("book_flight", vec!["book", "flight", "fly", "ticket"]);
155/// clf.add_intent("cancel", vec!["cancel", "undo", "remove"]);
156///
157/// let (intent, confidence) = classify_intent("I want to book a flight", &clf);
158/// assert_eq!(intent, "book_flight");
159/// assert!(confidence > 0.0);
160/// ```
161#[derive(Debug, Clone, Default)]
162pub struct IntentClassifier {
163    /// Registered intent names in registration order.
164    pub intents: Vec<String>,
165    /// Patterns for each intent (parallel to `intents`).
166    ///
167    /// Each element is a list of keyword strings that evidence the intent.
168    pub patterns: Vec<Vec<String>>,
169}
170
171impl IntentClassifier {
172    /// Create an empty classifier.
173    pub fn new() -> Self {
174        Self::default()
175    }
176
177    /// Register a new intent with its keyword patterns.
178    ///
179    /// All pattern strings are lower-cased and stored as-is; matching is
180    /// performed case-insensitively.
181    pub fn add_intent(&mut self, name: &str, patterns: Vec<&str>) {
182        self.intents.push(name.to_string());
183        self.patterns
184            .push(patterns.into_iter().map(|p| p.to_lowercase()).collect());
185    }
186
187    /// Return the number of registered intents.
188    pub fn len(&self) -> usize {
189        self.intents.len()
190    }
191
192    /// Return `true` when no intents are registered.
193    pub fn is_empty(&self) -> bool {
194        self.intents.is_empty()
195    }
196}
197
198/// Classify an utterance using `classifier`.
199///
200/// Returns `(intent_name, confidence)`.  Confidence is the normalised fraction
201/// of the winning intent's patterns that matched at least one token in the
202/// utterance.  When no intent is registered or no patterns match, returns
203/// `("unknown", 0.0)`.
204pub fn classify_intent(utterance: &str, classifier: &IntentClassifier) -> (String, f64) {
205    if classifier.intents.is_empty() {
206        return ("unknown".to_string(), 0.0);
207    }
208
209    let utt_lower = utterance.to_lowercase();
210    let utt_tokens: Vec<&str> = utt_lower
211        .split(|c: char| !c.is_alphanumeric())
212        .filter(|t| !t.is_empty())
213        .collect();
214
215    let mut best_intent = "unknown".to_string();
216    let mut best_score = 0.0_f64;
217    let mut best_matches = 0usize;
218
219    for (intent_idx, patterns) in classifier.patterns.iter().enumerate() {
220        if patterns.is_empty() {
221            continue;
222        }
223        let total = patterns.len();
224        let matches = patterns
225            .iter()
226            .filter(|pat| {
227                // A pattern matches if any token in the utterance starts with or equals the pattern.
228                utt_tokens.iter().any(|tok| {
229                    *tok == pat.as_str()
230                        || tok.starts_with(pat.as_str())
231                        || utt_lower.contains(pat.as_str())
232                })
233            })
234            .count();
235
236        let score = matches as f64 / total as f64;
237        if matches > best_matches || (matches == best_matches && score > best_score) {
238            best_matches = matches;
239            best_score = score;
240            best_intent = classifier.intents[intent_idx].clone();
241        }
242    }
243
244    if best_matches == 0 {
245        ("unknown".to_string(), 0.0)
246    } else {
247        (best_intent, best_score)
248    }
249}
250
251// ---------------------------------------------------------------------------
252// EntityExtractor
253// ---------------------------------------------------------------------------
254
255/// Recognised entity type labels returned by [`EntityExtractor`].
256#[derive(Debug, Clone, PartialEq, Eq, Hash)]
257pub enum EntityKind {
258    /// A date expression (e.g. "January 15" or "15/01/2024").
259    Date,
260    /// A cardinal number or decimal.
261    Number,
262    /// A probable proper name (heuristic: consecutive capitalised tokens).
263    Name,
264    /// A location keyword match (heuristic: following "in", "to", "from" etc.).
265    Location,
266    /// Custom user-defined entity kind.
267    Custom(String),
268}
269
270impl std::fmt::Display for EntityKind {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        match self {
273            Self::Date => write!(f, "DATE"),
274            Self::Number => write!(f, "NUMBER"),
275            Self::Name => write!(f, "NAME"),
276            Self::Location => write!(f, "LOCATION"),
277            Self::Custom(s) => write!(f, "CUSTOM({})", s),
278        }
279    }
280}
281
282/// A single entity extracted from an utterance.
283#[derive(Debug, Clone)]
284pub struct ExtractedEntity {
285    /// The matched text.
286    pub text: String,
287    /// Entity kind.
288    pub kind: EntityKind,
289    /// Character start offset (byte index).
290    pub start: usize,
291    /// Character end offset (byte index, exclusive).
292    pub end: usize,
293}
294
295/// Rule-based entity extractor for dialogue systems.
296///
297/// Extracts dates, numbers, potential names (capitalised phrases), and
298/// location hints without relying on trained models.
299#[derive(Debug, Default)]
300pub struct EntityExtractor {
301    /// Additional custom gazetteer entries (lowercase term → EntityKind).
302    gazetteer: Vec<(String, EntityKind)>,
303}
304
305impl EntityExtractor {
306    /// Create a new `EntityExtractor` with default rules.
307    pub fn new() -> Self {
308        Self::default()
309    }
310
311    /// Add a custom gazetteer entry.  Matching is case-insensitive.
312    pub fn add_gazetteer_entry(&mut self, term: &str, kind: EntityKind) {
313        self.gazetteer.push((term.to_lowercase(), kind));
314    }
315
316    /// Extract entities from `utterance`.
317    ///
318    /// The extraction order is: gazetteer, dates, numbers, names (consecutive
319    /// capitalised tokens), location hints.  Overlapping spans are not
320    /// deduplicated; callers should post-process if needed.
321    pub fn extract(&self, utterance: &str) -> Vec<ExtractedEntity> {
322        let mut entities: Vec<ExtractedEntity> = Vec::new();
323
324        self.extract_gazetteer(utterance, &mut entities);
325        self.extract_dates(utterance, &mut entities);
326        self.extract_numbers(utterance, &mut entities);
327        self.extract_names(utterance, &mut entities);
328        self.extract_locations(utterance, &mut entities);
329
330        // Sort by start position.
331        entities.sort_by_key(|e| e.start);
332        entities
333    }
334
335    /// Match gazetteer entries (exact, case-insensitive substring search).
336    fn extract_gazetteer(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
337        let text_lower = text.to_lowercase();
338        for (term, kind) in &self.gazetteer {
339            let mut search_start = 0usize;
340            while let Some(offset) = text_lower[search_start..].find(term.as_str()) {
341                let abs_start = search_start + offset;
342                let abs_end = abs_start + term.len();
343                out.push(ExtractedEntity {
344                    text: text[abs_start..abs_end].to_string(),
345                    kind: kind.clone(),
346                    start: abs_start,
347                    end: abs_end,
348                });
349                search_start = abs_end;
350            }
351        }
352    }
353
354    /// Extract date expressions using simple patterns.
355    ///
356    /// Recognises:
357    /// - `DD/MM/YYYY` or `MM/DD/YYYY` (slash-separated numbers).
358    /// - Month name followed by an optional day number.
359    fn extract_dates(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
360        // Slash-delimited dates: digits / digits (/ digits)?
361        let mut i = 0;
362        let bytes = text.as_bytes();
363        let len = bytes.len();
364
365        while i < len {
366            // Skip non-digit characters.
367            if !bytes[i].is_ascii_digit() {
368                i += 1;
369                continue;
370            }
371            // Consume digit run.
372            let start = i;
373            while i < len && bytes[i].is_ascii_digit() {
374                i += 1;
375            }
376            // Look for /digits(/digits)?
377            if i < len && bytes[i] == b'/' {
378                let slash1 = i;
379                i += 1;
380                let seg2_start = i;
381                while i < len && bytes[i].is_ascii_digit() {
382                    i += 1;
383                }
384                if i > seg2_start {
385                    let end = if i < len && bytes[i] == b'/' {
386                        i += 1; // consume second slash
387                        let seg3_start = i;
388                        while i < len && bytes[i].is_ascii_digit() {
389                            i += 1;
390                        }
391                        if i > seg3_start {
392                            i
393                        } else {
394                            // Backtrack the slash.
395                            i = slash1 + 1 + (i - slash1 - 1);
396                            slash1
397                        }
398                    } else {
399                        i
400                    };
401                    let matched = &text[start..end];
402                    if matched.contains('/') {
403                        out.push(ExtractedEntity {
404                            text: matched.to_string(),
405                            kind: EntityKind::Date,
406                            start,
407                            end,
408                        });
409                    }
410                }
411                continue;
412            }
413        }
414
415        // Month-name patterns.
416        let months = [
417            "january",
418            "february",
419            "march",
420            "april",
421            "may",
422            "june",
423            "july",
424            "august",
425            "september",
426            "october",
427            "november",
428            "december",
429            "jan",
430            "feb",
431            "mar",
432            "apr",
433            "jun",
434            "jul",
435            "aug",
436            "sep",
437            "oct",
438            "nov",
439            "dec",
440        ];
441        let text_lower = text.to_lowercase();
442        for month in &months {
443            let mut search_pos = 0usize;
444            while let Some(offset) = text_lower[search_pos..].find(month) {
445                let abs_start = search_pos + offset;
446                let abs_end = abs_start + month.len();
447
448                // Make sure it's a word boundary (not mid-word).
449                let before_ok =
450                    abs_start == 0 || !text.as_bytes()[abs_start - 1].is_ascii_alphanumeric();
451                let after_ok =
452                    abs_end >= text.len() || !text.as_bytes()[abs_end].is_ascii_alphanumeric();
453
454                if before_ok && after_ok {
455                    // Optionally consume a following number (day).
456                    let mut end = abs_end;
457                    let rest = &text[abs_end..];
458                    let after_space: &str = rest.trim_start_matches(' ');
459                    let day_len: usize = after_space
460                        .chars()
461                        .take_while(|c| c.is_ascii_digit())
462                        .map(|c| c.len_utf8())
463                        .sum();
464                    if day_len > 0 {
465                        let spaces = rest.len() - after_space.len();
466                        end += spaces + day_len;
467                    }
468
469                    // Also try to consume a following 4-digit year.
470                    let rest2 = &text[end..];
471                    let after_space2: &str = rest2.trim_start_matches(' ');
472                    let year_candidate: String = after_space2
473                        .chars()
474                        .take_while(|c| c.is_ascii_digit())
475                        .collect();
476                    if year_candidate.len() == 4 {
477                        let spaces2 = rest2.len() - after_space2.len();
478                        end += spaces2 + 4;
479                    }
480
481                    out.push(ExtractedEntity {
482                        text: text[abs_start..end].to_string(),
483                        kind: EntityKind::Date,
484                        start: abs_start,
485                        end,
486                    });
487                }
488
489                search_pos = abs_end;
490            }
491        }
492    }
493
494    /// Extract cardinal numbers (integers and decimals).
495    fn extract_numbers(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
496        let mut i = 0;
497        let bytes = text.as_bytes();
498        let len = bytes.len();
499
500        while i < len {
501            if !bytes[i].is_ascii_digit() {
502                i += 1;
503                continue;
504            }
505            let start = i;
506            while i < len && bytes[i].is_ascii_digit() {
507                i += 1;
508            }
509            // Optionally consume a decimal part.
510            if i < len && (bytes[i] == b'.' || bytes[i] == b',') {
511                let sep = i;
512                i += 1;
513                let frac_start = i;
514                while i < len && bytes[i].is_ascii_digit() {
515                    i += 1;
516                }
517                if i == frac_start {
518                    // Nothing after separator — backtrack.
519                    i = sep;
520                }
521            }
522            // Ensure not part of a slash-date already extracted (heuristic: skip
523            // entries that overlap with a date entity).
524            let end = i;
525            let candidate = &text[start..end];
526            let already_date = out
527                .iter()
528                .any(|e| e.kind == EntityKind::Date && e.start <= start && e.end >= end);
529            if !already_date {
530                out.push(ExtractedEntity {
531                    text: candidate.to_string(),
532                    kind: EntityKind::Number,
533                    start,
534                    end,
535                });
536            }
537        }
538    }
539
540    /// Extract probable proper names: runs of two or more consecutive
541    /// capitalised tokens (words beginning with an uppercase letter) that are
542    /// not preceded by a sentence-initial position.
543    fn extract_names(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
544        // Split on whitespace while tracking byte offsets.
545        let mut word_spans: Vec<(usize, usize, &str)> = Vec::new();
546        let mut pos = 0usize;
547        for word in text.split_ascii_whitespace() {
548            // Find the word in text starting from pos.
549            if let Some(offset) = text[pos..].find(word) {
550                let start = pos + offset;
551                let end = start + word.len();
552                word_spans.push((start, end, word));
553                pos = end;
554            }
555        }
556
557        // Find runs of capitalised tokens (heuristic proper-name detection).
558        let mut i = 0usize;
559        while i < word_spans.len() {
560            let (start, _, word) = word_spans[i];
561            // Strip leading punctuation.
562            let first_alpha = word.chars().find(|c| c.is_alphabetic());
563            let is_cap = first_alpha.map(|c| c.is_uppercase()).unwrap_or(false);
564
565            if !is_cap {
566                i += 1;
567                continue;
568            }
569
570            // Consume the run.
571            let run_start = start;
572            let mut j = i;
573            while j < word_spans.len() {
574                let (_, _, w) = word_spans[j];
575                let fc = w.chars().find(|c| c.is_alphabetic());
576                if fc.map(|c| c.is_uppercase()).unwrap_or(false) {
577                    j += 1;
578                } else {
579                    break;
580                }
581            }
582
583            // Only emit runs of 2+ tokens.
584            if j - i >= 2 {
585                let (_, run_end, _) = word_spans[j - 1];
586                let name_text = &text[run_start..run_end];
587                out.push(ExtractedEntity {
588                    text: name_text.to_string(),
589                    kind: EntityKind::Name,
590                    start: run_start,
591                    end: run_end,
592                });
593                i = j;
594            } else {
595                i += 1;
596            }
597        }
598    }
599
600    /// Extract location hints using positional keywords ("in", "to", "from",
601    /// "at", "near", "between") followed by a capitalised token.
602    fn extract_locations(&self, text: &str, out: &mut Vec<ExtractedEntity>) {
603        let location_triggers = ["in ", "to ", "from ", "at ", "near ", "between "];
604        for trigger in &location_triggers {
605            let text_lower = text.to_lowercase();
606            let mut search_pos = 0usize;
607            while let Some(offset) = text_lower[search_pos..].find(trigger) {
608                let abs_trigger_start = search_pos + offset;
609                let candidate_start = abs_trigger_start + trigger.len();
610                if candidate_start >= text.len() {
611                    break;
612                }
613
614                // Consume the capitalised word (or phrase of capitalised words).
615                let rest = &text[candidate_start..];
616                let mut loc_end = candidate_start;
617                for word in rest.split_ascii_whitespace() {
618                    let first_char = word
619                        .trim_matches(|c: char| !c.is_alphabetic())
620                        .chars()
621                        .next();
622                    if first_char.map(|c| c.is_uppercase()).unwrap_or(false) {
623                        loc_end += word.len() + 1; // +1 for the space
624                    } else {
625                        break;
626                    }
627                }
628                // Trim trailing separator.
629                let loc_end = loc_end.min(text.len());
630                if loc_end > candidate_start {
631                    let loc_text = text[candidate_start..loc_end].trim().to_string();
632                    if !loc_text.is_empty() {
633                        let actual_end = candidate_start + loc_text.len();
634                        out.push(ExtractedEntity {
635                            text: loc_text,
636                            kind: EntityKind::Location,
637                            start: candidate_start,
638                            end: actual_end,
639                        });
640                    }
641                }
642
643                search_pos = candidate_start;
644            }
645        }
646    }
647}
648
649// ---------------------------------------------------------------------------
650// SlotFiller
651// ---------------------------------------------------------------------------
652
653/// Template-based slot filler.
654///
655/// A slot template is a string like `"fly from {origin} to {destination}"`.
656/// The filler extracts the values of `{origin}` and `{destination}` by
657/// matching the literal parts of the template against the utterance.
658///
659/// # Example
660///
661/// ```rust
662/// use scirs2_text::dialog::SlotFiller;
663///
664/// let sf = SlotFiller::new();
665/// let slots = sf.fill("book a flight from London to Paris",
666///                      "flight from {origin} to {destination}").unwrap();
667/// assert_eq!(slots.get("origin").map(|s| s.as_str()), Some("London"));
668/// assert_eq!(slots.get("destination").map(|s| s.as_str()), Some("Paris"));
669/// ```
670#[derive(Debug, Default, Clone)]
671pub struct SlotFiller;
672
673impl SlotFiller {
674    /// Create a new `SlotFiller`.
675    pub fn new() -> Self {
676        Self
677    }
678
679    /// Fill slots defined by `template` from `utterance`.
680    ///
681    /// Template syntax: literal text with `{slot_name}` placeholders.
682    ///
683    /// Returns a map of slot names to their extracted values, or an error if
684    /// the template cannot be parsed.
685    pub fn fill(&self, utterance: &str, template: &str) -> Result<HashMap<String, String>> {
686        // Parse the template into alternating literals and slot names.
687        let parts = parse_template(template)?;
688        let mut slots: HashMap<String, String> = HashMap::new();
689
690        // Try to match the utterance against the template parts.
691        let utt_lower = utterance.to_lowercase();
692        let mut search_pos = 0usize;
693
694        let n = parts.len();
695        let mut pi = 0usize;
696
697        while pi < n {
698            match &parts[pi] {
699                TemplatePart::Literal(lit) => {
700                    let lit_lower = lit.to_lowercase();
701                    if lit_lower.is_empty() {
702                        pi += 1;
703                        continue;
704                    }
705                    if let Some(offset) = utt_lower[search_pos..].find(lit_lower.as_str()) {
706                        search_pos += offset + lit.len();
707                        pi += 1;
708                    } else {
709                        // Literal not found; stop matching.
710                        break;
711                    }
712                }
713                TemplatePart::Slot(slot_name) => {
714                    // The slot value runs up to the next literal (or end of string).
715                    let next_literal: Option<&str> = parts[pi + 1..].iter().find_map(|p| {
716                        if let TemplatePart::Literal(s) = p {
717                            if !s.is_empty() {
718                                Some(s.as_str())
719                            } else {
720                                None
721                            }
722                        } else {
723                            None
724                        }
725                    });
726
727                    let value_end = if let Some(next_lit) = next_literal {
728                        let next_lit_lower = next_lit.to_lowercase();
729                        utt_lower[search_pos..]
730                            .find(next_lit_lower.as_str())
731                            .map(|o| search_pos + o)
732                            .unwrap_or(utt_lower.len())
733                    } else {
734                        utt_lower.len()
735                    };
736
737                    let raw_value = utterance[search_pos..value_end].trim().to_string();
738                    if !raw_value.is_empty() {
739                        slots.insert(slot_name.clone(), raw_value);
740                    }
741                    search_pos = value_end;
742                    pi += 1;
743                }
744            }
745        }
746
747        Ok(slots)
748    }
749}
750
751/// Internal template part.
752#[derive(Debug)]
753enum TemplatePart {
754    Literal(String),
755    Slot(String),
756}
757
758/// Parse a template string into a vector of [`TemplatePart`] items.
759fn parse_template(template: &str) -> Result<Vec<TemplatePart>> {
760    let mut parts: Vec<TemplatePart> = Vec::new();
761    let mut chars = template.char_indices().peekable();
762    let mut buf = String::new();
763
764    while let Some((_, ch)) = chars.next() {
765        if ch == '{' {
766            // Flush literal buffer.
767            if !buf.is_empty() {
768                parts.push(TemplatePart::Literal(std::mem::take(&mut buf)));
769            }
770            // Read slot name until '}'.
771            let mut slot_name = String::new();
772            let mut closed = false;
773            for (_, sc) in chars.by_ref() {
774                if sc == '}' {
775                    closed = true;
776                    break;
777                }
778                slot_name.push(sc);
779            }
780            if !closed {
781                return Err(TextError::InvalidInput(
782                    "Unclosed '{' in slot template".to_string(),
783                ));
784            }
785            if slot_name.is_empty() {
786                return Err(TextError::InvalidInput(
787                    "Empty slot name '{}' in template".to_string(),
788                ));
789            }
790            parts.push(TemplatePart::Slot(slot_name));
791        } else {
792            buf.push(ch);
793        }
794    }
795
796    if !buf.is_empty() {
797        parts.push(TemplatePart::Literal(buf));
798    }
799
800    Ok(parts)
801}
802
803// ---------------------------------------------------------------------------
804// DialogPolicy
805// ---------------------------------------------------------------------------
806
807/// States of the built-in dialog state machine.
808#[derive(Debug, Clone, PartialEq, Eq, Hash)]
809pub enum PolicyState {
810    /// Initial state (no turns yet).
811    Initial,
812    /// Greeting has been exchanged.
813    Greeted,
814    /// System is collecting slot values.
815    SlotCollection,
816    /// All required slots are filled; awaiting confirmation.
817    PendingConfirmation,
818    /// Transaction confirmed and executed.
819    Confirmed,
820    /// Dialog has ended.
821    Ended,
822}
823
824impl std::fmt::Display for PolicyState {
825    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
826        let s = match self {
827            Self::Initial => "INITIAL",
828            Self::Greeted => "GREETED",
829            Self::SlotCollection => "SLOT_COLLECTION",
830            Self::PendingConfirmation => "PENDING_CONFIRMATION",
831            Self::Confirmed => "CONFIRMED",
832            Self::Ended => "ENDED",
833        };
834        write!(f, "{s}")
835    }
836}
837
838/// A recommended system action emitted by the policy.
839#[derive(Debug, Clone)]
840pub struct PolicyAction {
841    /// The recommended dialog act.
842    pub act: DialogAct,
843    /// Which slot to request next (if act == Request).
844    pub request_slot: Option<String>,
845    /// Slots to confirm (if act == Confirm).
846    pub confirm_slots: Vec<String>,
847}
848
849/// Simple state-machine dialog policy.
850///
851/// Drives a slot-filling task dialog through greeting → slot collection →
852/// confirmation → end.
853///
854/// # Example
855///
856/// ```rust
857/// use scirs2_text::dialog::{DialogPolicy, DialogState, DialogAct};
858///
859/// let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
860/// let mut state = DialogState::new();
861///
862/// let action = policy.next_action(&state);
863/// assert_eq!(action.act, DialogAct::Greet);
864/// ```
865pub struct DialogPolicy {
866    /// Slots that must be filled before confirming.
867    required_slots: Vec<String>,
868    /// Current state-machine state.
869    policy_state: PolicyState,
870}
871
872impl DialogPolicy {
873    /// Create a new `DialogPolicy` requiring the specified slots.
874    pub fn new(required_slots: Vec<String>) -> Self {
875        Self {
876            required_slots,
877            policy_state: PolicyState::Initial,
878        }
879    }
880
881    /// Current state-machine state.
882    pub fn state(&self) -> &PolicyState {
883        &self.policy_state
884    }
885
886    /// Compute the next recommended action given the current `dialog_state`.
887    ///
888    /// This also advances the internal state machine.
889    pub fn next_action(&mut self, dialog_state: &DialogState) -> PolicyAction {
890        match self.policy_state {
891            PolicyState::Initial => {
892                self.policy_state = PolicyState::Greeted;
893                PolicyAction {
894                    act: DialogAct::Greet,
895                    request_slot: None,
896                    confirm_slots: Vec::new(),
897                }
898            }
899            PolicyState::Greeted | PolicyState::SlotCollection => {
900                // Find the first unfilled required slot.
901                let missing = self
902                    .required_slots
903                    .iter()
904                    .find(|s| !dialog_state.slots.contains_key(*s))
905                    .cloned();
906
907                if let Some(slot) = missing {
908                    self.policy_state = PolicyState::SlotCollection;
909                    PolicyAction {
910                        act: DialogAct::Request,
911                        request_slot: Some(slot),
912                        confirm_slots: Vec::new(),
913                    }
914                } else {
915                    // All slots filled.
916                    self.policy_state = PolicyState::PendingConfirmation;
917                    PolicyAction {
918                        act: DialogAct::Confirm,
919                        request_slot: None,
920                        confirm_slots: self.required_slots.clone(),
921                    }
922                }
923            }
924            PolicyState::PendingConfirmation => {
925                // Check the last utterance for yes/no.
926                let confirmed = dialog_state
927                    .last_utterance()
928                    .map(|u| {
929                        let ul = u.to_lowercase();
930                        ul.contains("yes")
931                            || ul.contains("correct")
932                            || ul.contains("right")
933                            || ul.contains("confirm")
934                    })
935                    .unwrap_or(false);
936
937                if confirmed {
938                    self.policy_state = PolicyState::Confirmed;
939                    PolicyAction {
940                        act: DialogAct::Inform,
941                        request_slot: None,
942                        confirm_slots: Vec::new(),
943                    }
944                } else {
945                    // Assume rejection / restart slot collection.
946                    self.policy_state = PolicyState::SlotCollection;
947                    PolicyAction {
948                        act: DialogAct::Reject,
949                        request_slot: None,
950                        confirm_slots: Vec::new(),
951                    }
952                }
953            }
954            PolicyState::Confirmed => {
955                self.policy_state = PolicyState::Ended;
956                PolicyAction {
957                    act: DialogAct::Goodbye,
958                    request_slot: None,
959                    confirm_slots: Vec::new(),
960                }
961            }
962            PolicyState::Ended => PolicyAction {
963                act: DialogAct::Goodbye,
964                request_slot: None,
965                confirm_slots: Vec::new(),
966            },
967        }
968    }
969
970    /// Reset the policy to its initial state.
971    pub fn reset(&mut self) {
972        self.policy_state = PolicyState::Initial;
973    }
974}
975
976// ---------------------------------------------------------------------------
977// response_template
978// ---------------------------------------------------------------------------
979
980/// Generate a natural-language response for the given `act` and `slots`.
981///
982/// Slot values are substituted into the response where the placeholder
983/// `{slot_name}` appears.  Unknown slot references are left as-is.
984///
985/// # Example
986///
987/// ```rust
988/// use scirs2_text::dialog::{response_template, DialogAct};
989/// use std::collections::HashMap;
990///
991/// let mut slots = HashMap::new();
992/// slots.insert("destination".to_string(), "Paris".to_string());
993///
994/// let response = response_template(DialogAct::Inform, &slots);
995/// assert!(!response.is_empty());
996/// ```
997pub fn response_template(act: DialogAct, slots: &HashMap<String, String>) -> String {
998    let template = match act {
999        DialogAct::Greet => "Hello! How can I help you today?".to_string(),
1000        DialogAct::Request => {
1001            // Pick the first slot that has a value in the slot map as a hint,
1002            // otherwise fall back to a generic request.
1003            let slot_hint = slots
1004                .keys()
1005                .next()
1006                .map(|s| s.as_str())
1007                .unwrap_or("information");
1008            format!("Could you please provide the {slot_hint}?")
1009        }
1010        DialogAct::Inform => {
1011            if slots.is_empty() {
1012                "I have processed your request successfully.".to_string()
1013            } else {
1014                let details: Vec<String> = slots.iter().map(|(k, v)| format!("{k}: {v}")).collect();
1015                format!("Here is the information: {}.", details.join(", "))
1016            }
1017        }
1018        DialogAct::Confirm => {
1019            if slots.is_empty() {
1020                "Can you please confirm your request?".to_string()
1021            } else {
1022                let details: Vec<String> =
1023                    slots.iter().map(|(k, v)| format!("{k} = {v}")).collect();
1024                format!(
1025                    "Just to confirm, you would like to proceed with {}. Is that correct?",
1026                    details.join(", ")
1027                )
1028            }
1029        }
1030        DialogAct::Reject => {
1031            "I'm sorry, that does not match what we have. Let's try again.".to_string()
1032        }
1033        DialogAct::Goodbye => "Thank you for using our service. Goodbye!".to_string(),
1034        DialogAct::Unknown => {
1035            "I'm sorry, I didn't understand that. Could you rephrase?".to_string()
1036        }
1037    };
1038
1039    // Substitute slot values into the template.
1040    let mut result = template;
1041    for (key, value) in slots {
1042        let placeholder = format!("{{{key}}}");
1043        result = result.replace(&placeholder, value);
1044    }
1045    result
1046}
1047
1048// ---------------------------------------------------------------------------
1049// Tests
1050// ---------------------------------------------------------------------------
1051
#[cfg(test)]
mod tests {
    use super::*;

    // -- DialogState --

    #[test]
    fn test_dialog_state_slots() {
        let mut state = DialogState::new();
        state.set_slot("destination", "Paris");
        assert_eq!(state.get_slot("destination"), Some("Paris"));
        assert_eq!(state.get_slot("origin"), None);
    }

    #[test]
    fn test_dialog_state_entities() {
        let mut state = DialogState::new();
        state.set_entity("DATE", "January 15");
        assert_eq!(state.get_entity("DATE"), Some("January 15"));
    }

    #[test]
    fn test_dialog_state_utterances() {
        let mut state = DialogState::new();
        assert!(state.last_utterance().is_none());
        state.add_utterance("Hello");
        assert_eq!(state.last_utterance(), Some("Hello"));
        state.add_utterance("Goodbye");
        assert_eq!(state.last_utterance(), Some("Goodbye"));
        assert_eq!(state.turn_count, 2);
    }

    #[test]
    fn test_dialog_state_slots_filled() {
        let mut state = DialogState::new();
        state.set_slot("a", "1");
        state.set_slot("b", "2");
        assert!(state.slots_filled(&["a", "b"]));
        assert!(!state.slots_filled(&["a", "b", "c"]));
    }

    #[test]
    fn test_dialog_state_reset() {
        let mut state = DialogState::new();
        state.set_slot("x", "y");
        state.add_utterance("hello");
        state.reset();
        assert!(state.slots.is_empty());
        assert!(state.context.is_empty());
        assert_eq!(state.turn_count, 0);
    }

    // -- IntentClassifier --

    #[test]
    fn test_classify_intent_basic() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("book_flight", vec!["book", "flight", "fly", "ticket"]);
        clf.add_intent("cancel", vec!["cancel", "undo", "delete"]);

        let (intent, conf) = classify_intent("I want to book a flight", &clf);
        assert_eq!(intent, "book_flight");
        assert!(conf > 0.0);
    }

    #[test]
    fn test_classify_intent_unknown() {
        let clf = IntentClassifier::new();
        let (intent, conf) = classify_intent("hello", &clf);
        assert_eq!(intent, "unknown");
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_classify_intent_no_match() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("book_flight", vec!["book", "flight"]);
        let (intent, conf) = classify_intent("tell me the weather", &clf);
        assert_eq!(intent, "unknown");
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_classify_intent_case_insensitive() {
        let mut clf = IntentClassifier::new();
        clf.add_intent("greet", vec!["hello", "hi", "hey"]);
        let (intent, _conf) = classify_intent("HELLO there", &clf);
        assert_eq!(intent, "greet");
    }

    // -- EntityExtractor --

    #[test]
    fn test_extract_numbers() {
        let ext = EntityExtractor::new();
        let entities = ext.extract("I need 3 tickets and 12.5 kg baggage");
        let numbers: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Number)
            .map(|e| e.text.as_str())
            .collect();
        assert!(numbers.contains(&"3"), "Missing '3': {:?}", numbers);
        assert!(numbers.contains(&"12.5"), "Missing '12.5': {:?}", numbers);
    }

    #[test]
    fn test_extract_date_month_name() {
        let ext = EntityExtractor::new();
        let entities = ext.extract("The flight is on January 15");
        let dates: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Date)
            .map(|e| e.text.as_str())
            .collect();
        assert!(!dates.is_empty(), "Expected at least one date entity");
        assert!(
            dates.iter().any(|d| d.contains("January")),
            "Expected 'January' in dates: {:?}",
            dates
        );
    }

    #[test]
    fn test_extract_gazetteer() {
        let mut ext = EntityExtractor::new();
        ext.add_gazetteer_entry("london", EntityKind::Location);
        let entities = ext.extract("I want to travel to London");
        let locs: Vec<&str> = entities
            .iter()
            .filter(|e| e.kind == EntityKind::Location)
            .map(|e| e.text.as_str())
            .collect();
        assert!(!locs.is_empty(), "Expected location entity");
    }

    // -- SlotFiller --

    #[test]
    fn test_slot_filler_basic() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill(
                "book a flight from London to Paris",
                "flight from {origin} to {destination}",
            )
            .expect("fill should succeed");
        assert_eq!(slots.get("origin").map(|s| s.as_str()), Some("London"));
        assert_eq!(slots.get("destination").map(|s| s.as_str()), Some("Paris"));
    }

    #[test]
    fn test_slot_filler_single_slot() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill("my name is Alice", "my name is {name}")
            .expect("fill should succeed");
        assert_eq!(slots.get("name").map(|s| s.as_str()), Some("Alice"));
    }

    #[test]
    fn test_slot_filler_unclosed_brace_error() {
        let sf = SlotFiller::new();
        let result = sf.fill("hello world", "hello {world");
        assert!(result.is_err(), "Expected error for unclosed brace");
    }

    #[test]
    fn test_slot_filler_no_match() {
        let sf = SlotFiller::new();
        let slots = sf
            .fill(
                "completely different text",
                "flight from {origin} to {destination}",
            )
            .expect("should not error");
        // Slots should be empty since the literal prefix didn't match.
        assert!(
            !slots.contains_key("origin") && !slots.contains_key("destination"),
            "Expected no slots when template does not match"
        );
    }

    // -- DialogPolicy --

    #[test]
    fn test_policy_initial_greet() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let state = DialogState::new();
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Greet);
    }

    #[test]
    fn test_policy_requests_missing_slot() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let state = DialogState::new();
        policy.next_action(&state); // Greet
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Request);
        assert!(action.request_slot.is_some());
    }

    #[test]
    fn test_policy_requests_slots_in_declared_order() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet
        let first = policy.next_action(&state);
        assert_eq!(first.request_slot.as_deref(), Some("origin"));
        state.set_slot("origin", "London");
        let second = policy.next_action(&state);
        assert_eq!(second.request_slot.as_deref(), Some("destination"));
        state.set_slot("destination", "Paris");
        let confirm = policy.next_action(&state);
        assert_eq!(confirm.act, DialogAct::Confirm);
        assert_eq!(
            confirm.confirm_slots,
            vec!["origin".to_string(), "destination".to_string()]
        );
    }

    #[test]
    fn test_policy_confirms_when_slots_filled() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string(), "destination".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet
        state.set_slot("origin", "London");
        state.set_slot("destination", "Paris");
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Confirm);
    }

    #[test]
    fn test_policy_informs_after_confirmation() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet
        state.set_slot("origin", "London");
        policy.next_action(&state); // Confirm
        state.add_utterance("yes");
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Inform);
    }

    #[test]
    fn test_policy_rejects_on_negative_reply() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet
        state.set_slot("origin", "London");
        policy.next_action(&state); // Confirm
        state.add_utterance("no");
        let action = policy.next_action(&state);
        assert_eq!(action.act, DialogAct::Reject);
        assert_eq!(*policy.state(), PolicyState::SlotCollection);
    }

    #[test]
    fn test_policy_goodbye_at_end() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet (→ Greeted)
        state.set_slot("origin", "London");
        policy.next_action(&state); // Confirm (→ PendingConfirmation)
        state.add_utterance("yes");
        policy.next_action(&state); // Inform (→ Confirmed)
        let action = policy.next_action(&state); // Goodbye (→ Ended)
        assert_eq!(action.act, DialogAct::Goodbye);
    }

    #[test]
    fn test_policy_ended_keeps_saying_goodbye() {
        let mut policy = DialogPolicy::new(vec!["origin".to_string()]);
        let mut state = DialogState::new();
        policy.next_action(&state); // Greet
        state.set_slot("origin", "London");
        policy.next_action(&state); // Confirm
        state.add_utterance("yes");
        policy.next_action(&state); // Inform
        policy.next_action(&state); // Goodbye (→ Ended)
        assert_eq!(*policy.state(), PolicyState::Ended);
        let again = policy.next_action(&state);
        assert_eq!(again.act, DialogAct::Goodbye);
        assert_eq!(*policy.state(), PolicyState::Ended);
    }

    #[test]
    fn test_policy_reset() {
        let mut policy = DialogPolicy::new(vec!["slot_a".to_string()]);
        let state = DialogState::new();
        policy.next_action(&state);
        assert_ne!(*policy.state(), PolicyState::Initial);
        policy.reset();
        assert_eq!(*policy.state(), PolicyState::Initial);
    }

    #[test]
    fn test_policy_state_display() {
        assert_eq!(PolicyState::Initial.to_string(), "INITIAL");
        assert_eq!(PolicyState::SlotCollection.to_string(), "SLOT_COLLECTION");
        assert_eq!(
            PolicyState::PendingConfirmation.to_string(),
            "PENDING_CONFIRMATION"
        );
    }

    // -- response_template --

    #[test]
    fn test_response_greet() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Greet, &slots);
        assert!(!response.is_empty());
        let lower = response.to_lowercase();
        assert!(
            lower.contains("hello") || lower.contains("hi") || lower.contains("help"),
            "Greet response should be a greeting: '{response}'"
        );
    }

    #[test]
    fn test_response_inform_with_slots() {
        let mut slots: HashMap<String, String> = HashMap::new();
        slots.insert("destination".to_string(), "Paris".to_string());
        let response = response_template(DialogAct::Inform, &slots);
        assert!(
            response.contains("Paris"),
            "Response should contain 'Paris': '{response}'"
        );
    }

    #[test]
    fn test_response_request_without_slots() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Request, &slots);
        assert_eq!(response, "Could you please provide the information?");
    }

    #[test]
    fn test_response_goodbye() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Goodbye, &slots);
        let lower = response.to_lowercase();
        assert!(
            lower.contains("goodbye") || lower.contains("bye") || lower.contains("thank"),
            "Goodbye response unexpected: '{response}'"
        );
    }

    #[test]
    fn test_response_confirm_with_slots() {
        let mut slots: HashMap<String, String> = HashMap::new();
        slots.insert("origin".to_string(), "London".to_string());
        slots.insert("destination".to_string(), "Tokyo".to_string());
        let response = response_template(DialogAct::Confirm, &slots);
        assert!(!response.is_empty());
    }

    #[test]
    fn test_response_unknown() {
        let slots: HashMap<String, String> = HashMap::new();
        let response = response_template(DialogAct::Unknown, &slots);
        assert!(!response.is_empty());
    }

    // -- DialogAct display --

    #[test]
    fn test_dialog_act_display() {
        assert_eq!(DialogAct::Greet.to_string(), "GREET");
        assert_eq!(DialogAct::Goodbye.to_string(), "GOODBYE");
        assert_eq!(DialogAct::Unknown.to_string(), "UNKNOWN");
    }
}