Skip to main content

scirs2_text/
question_answering.rs

1//! Extractive question answering over plain-text documents.
2//!
3//! This module provides rule-based, overlap-based, and TF-IDF-based extractive
4//! QA that locates the most plausible answer span for a given question inside a
5//! context document – all without external ML weights.
6//!
7//! # Example
8//!
9//! ```rust
10//! use scirs2_text::question_answering::{QAContext, QAMethod, extract_answer};
11//!
12//! let doc = "Marie Curie was born in Warsaw on 7 November 1867. \
13//!            She won two Nobel Prizes.";
14//! let spans = extract_answer("When was Marie Curie born?", doc, 3);
15//! assert!(!spans.is_empty());
16//! ```
17
18use crate::error::{Result, TextError};
19use std::collections::{HashMap, HashSet};
20
21// ---------------------------------------------------------------------------
22// Public types
23// ---------------------------------------------------------------------------
24
/// Method used to score candidate answer spans.
///
/// The enum is `Copy`, so a single value can be reused across many queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QAMethod {
    /// Rank candidate sentences by TF-IDF cosine similarity with the question,
    /// then extract the best named-entity-aware span within the top sentence.
    TfIdf,
    /// Rank by exact bigram overlap between the question and each sentence.
    BigramOverlap,
    /// Use cosine similarity of averaged word-embedding vectors when
    /// embeddings are attached to `QAContext`; without embeddings this falls
    /// back to bigram overlap.
    WordEmbeddingMatch,
}
37
/// A single question type category inferred from the question wh-word.
///
/// The enum is `Copy`, `Eq` and `Hash`, so it can be stored in sets or used
/// as a map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuestionType {
    /// "Who" – expects a person / organisation
    Who,
    /// "What" – general entity or definition
    What,
    /// "When" – temporal expression
    When,
    /// "Where" – location
    Where,
    /// "Why" – reason / cause
    Why,
    /// "How" – manner / quantity
    How,
    /// Could not be classified
    Unknown,
}
56
/// An extracted answer span with provenance and confidence.
///
/// Derives `PartialEq` so spans can be compared directly (e.g. in tests or
/// when de-duplicating results).
#[derive(Debug, Clone, PartialEq)]
pub struct AnswerSpan {
    /// Byte-level start offset in the *original* document text.
    pub start: usize,
    /// Byte-level end offset (exclusive) in the *original* document text.
    pub end: usize,
    /// The answer text itself.
    pub text: String,
    /// Confidence score in [0, 1].
    pub confidence: f64,
    /// Which sentence (0-indexed) the span was drawn from.
    pub sentence_index: usize,
}
71
72/// A tokenised, indexed context document ready for repeated QA queries.
73pub struct QAContext {
74    /// Original document text.
75    pub text: String,
76    /// Sentences with their byte offsets in the original text.
77    sentences: Vec<SentenceRecord>,
78    /// Optional word embeddings (word → fixed-length float vector).
79    embeddings: Option<HashMap<String, Vec<f64>>>,
80}
81
82// ---------------------------------------------------------------------------
83// Internal helpers
84// ---------------------------------------------------------------------------
85
/// One sentence of a context document, plus its location and tokens.
#[derive(Clone, Debug)]
struct SentenceRecord {
    /// Sentence text as stored by the splitter (surrounding whitespace removed).
    text: String,
    /// Byte offset in the original document recorded for this sentence.
    start: usize,
    /// Tokens used for scoring (normally produced by `simple_tokenize`).
    tokens: Vec<String>,
}
93
/// Split `text` into lowercase word tokens.
///
/// Every character that is not alphanumeric (in the Unicode sense) acts as a
/// separator; the empty fragments produced by runs of separators are dropped.
fn simple_tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
101
/// Return the set of English stop-words ignored by TF-IDF scoring.
///
/// Entries are lowercase, matching the output of `simple_tokenize`. A fresh
/// set is built on every call.
fn stop_words() -> HashSet<&'static str> {
    const WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "shall", "should", "may", "might", "must", "can",
        "could", "to", "of", "in", "on", "at", "by", "for", "with", "about", "against", "between",
        "into", "through", "during", "before", "after", "above", "below", "from", "up", "down",
        "out", "off", "over", "under", "again", "further", "then", "once", "and", "but", "or",
        "nor", "so", "yet", "both", "either", "neither", "not", "only", "own", "same", "than",
        "too", "very", "just", "i", "you", "he", "she", "it", "we", "they", "me", "him", "her",
        "us", "them", "my", "your", "his", "its", "our", "their", "what", "which", "who", "whom",
        "this", "that", "these", "those", "am", "s", "t",
    ];
    WORDS.iter().copied().collect()
}
119
120/// Classify the question type from the first recognised wh-word.
121pub fn classify_question(question: &str) -> QuestionType {
122    let lower = question.to_lowercase();
123    // Check word boundaries roughly by looking for whole-word occurrences.
124    for word in lower.split_whitespace() {
125        let w = word.trim_matches(|c: char| !c.is_alphabetic());
126        match w {
127            "who" | "whose" | "whom" => return QuestionType::Who,
128            "when" => return QuestionType::When,
129            "where" => return QuestionType::Where,
130            "why" => return QuestionType::Why,
131            "how" => return QuestionType::How,
132            "what" | "which" => return QuestionType::What,
133            _ => {}
134        }
135    }
136    QuestionType::Unknown
137}
138
139// ---------------------------------------------------------------------------
140// QAContext implementation
141// ---------------------------------------------------------------------------
142
143impl QAContext {
144    /// Build a `QAContext` from a plain document string.
145    pub fn new(text: &str) -> Self {
146        let sentences = Self::split_sentences(text);
147        Self {
148            text: text.to_string(),
149            sentences,
150            embeddings: None,
151        }
152    }
153
154    /// Attach word embeddings for `WordEmbeddingMatch` queries.
155    pub fn with_embeddings(mut self, embeddings: HashMap<String, Vec<f64>>) -> Self {
156        self.embeddings = Some(embeddings);
157        self
158    }
159
160    // ------------------------------------------------------------------
161    // Sentence splitting
162    // ------------------------------------------------------------------
163
164    fn split_sentences(text: &str) -> Vec<SentenceRecord> {
165        let mut records = Vec::new();
166        let mut start = 0usize;
167        let bytes = text.as_bytes();
168        let len = bytes.len();
169
170        while start < len {
171            // Find the next sentence boundary: '.', '?', '!' followed by
172            // whitespace or end-of-string.
173            let mut end = start;
174            while end < len {
175                let b = bytes[end];
176                if b == b'.' || b == b'?' || b == b'!' {
177                    // Make sure we are at a valid char boundary before slicing.
178                    end += 1;
179                    // Consume any trailing whitespace so the next sentence
180                    // starts cleanly.
181                    while end < len && bytes[end] == b' ' {
182                        end += 1;
183                    }
184                    break;
185                }
186                end += 1;
187            }
188
189            let raw = text[start..end].trim();
190            if !raw.is_empty() {
191                records.push(SentenceRecord {
192                    text: raw.to_string(),
193                    start,
194                    tokens: simple_tokenize(raw),
195                });
196            }
197            start = end;
198        }
199
200        records
201    }
202
203    // ------------------------------------------------------------------
204    // TF-IDF ranking helpers
205    // ------------------------------------------------------------------
206
207    /// Build IDF table over the stored sentences.
208    fn build_idf(&self) -> HashMap<String, f64> {
209        let n = self.sentences.len() as f64;
210        let mut df: HashMap<String, usize> = HashMap::new();
211        for sent in &self.sentences {
212            let unique: HashSet<&String> = sent.tokens.iter().collect();
213            for tok in unique {
214                *df.entry(tok.clone()).or_insert(0) += 1;
215            }
216        }
217        df.into_iter()
218            .map(|(t, d)| (t, (1.0 + n / (1.0 + d as f64)).ln()))
219            .collect()
220    }
221
222    /// Score a single sentence against query tokens using TF-IDF cosine.
223    fn tfidf_score(
224        query_tokens: &[String],
225        sentence: &SentenceRecord,
226        idf: &HashMap<String, f64>,
227        stops: &HashSet<&'static str>,
228    ) -> f64 {
229        // Build TF for the sentence
230        let mut sent_tf: HashMap<String, f64> = HashMap::new();
231        for tok in &sentence.tokens {
232            *sent_tf.entry(tok.clone()).or_insert(0.0) += 1.0;
233        }
234        let sent_len = sentence.tokens.len().max(1) as f64;
235
236        let mut dot = 0.0f64;
237        let mut q_norm = 0.0f64;
238        let mut s_norm = 0.0f64;
239
240        // Compute vectors over query tokens only (sparse dot product)
241        let query_freq: HashMap<&String, f64> = {
242            let mut m = HashMap::new();
243            for t in query_tokens {
244                if !stops.contains(t.as_str()) {
245                    *m.entry(t).or_insert(0.0) += 1.0;
246                }
247            }
248            m
249        };
250
251        for (tok, &qf) in &query_freq {
252            let idf_val = idf.get(*tok).copied().unwrap_or(0.0);
253            let q_tfidf = (qf / query_tokens.len().max(1) as f64) * idf_val;
254            let s_tfidf = sent_tf.get(*tok).copied().unwrap_or(0.0) / sent_len * idf_val;
255            dot += q_tfidf * s_tfidf;
256            q_norm += q_tfidf * q_tfidf;
257            s_norm += s_tfidf * s_tfidf;
258        }
259
260        if q_norm > 0.0 && s_norm > 0.0 {
261            dot / (q_norm.sqrt() * s_norm.sqrt())
262        } else {
263            0.0
264        }
265    }
266
267    // ------------------------------------------------------------------
268    // Bigram overlap
269    // ------------------------------------------------------------------
270
271    fn bigram_overlap_score(query_tokens: &[String], sentence: &SentenceRecord) -> f64 {
272        if query_tokens.len() < 2 || sentence.tokens.len() < 2 {
273            // Fall back to unigram overlap
274            let q_set: HashSet<&String> = query_tokens.iter().collect();
275            let s_set: HashSet<&String> = sentence.tokens.iter().collect();
276            let inter = q_set.intersection(&s_set).count();
277            return inter as f64 / q_set.len().max(1) as f64;
278        }
279
280        let q_bigrams: HashSet<(&String, &String)> =
281            query_tokens.windows(2).map(|w| (&w[0], &w[1])).collect();
282        let s_bigrams: HashSet<(&String, &String)> =
283            sentence.tokens.windows(2).map(|w| (&w[0], &w[1])).collect();
284
285        let inter = q_bigrams.intersection(&s_bigrams).count();
286        let union = q_bigrams.union(&s_bigrams).count();
287        if union == 0 {
288            0.0
289        } else {
290            inter as f64 / union as f64
291        }
292    }
293
294    // ------------------------------------------------------------------
295    // Embedding cosine
296    // ------------------------------------------------------------------
297
298    fn embedding_score(
299        query_tokens: &[String],
300        sentence: &SentenceRecord,
301        embeddings: &HashMap<String, Vec<f64>>,
302    ) -> f64 {
303        let q_vec = Self::average_embedding(query_tokens, embeddings);
304        let s_vec = Self::average_embedding(&sentence.tokens, embeddings);
305        match (q_vec, s_vec) {
306            (Some(q), Some(s)) => cosine_sim(&q, &s),
307            _ => 0.0,
308        }
309    }
310
311    fn average_embedding(
312        tokens: &[String],
313        embeddings: &HashMap<String, Vec<f64>>,
314    ) -> Option<Vec<f64>> {
315        let vecs: Vec<&Vec<f64>> = tokens.iter().filter_map(|t| embeddings.get(t)).collect();
316        if vecs.is_empty() {
317            return None;
318        }
319        let dim = vecs[0].len();
320        let mut sum = vec![0.0f64; dim];
321        for v in &vecs {
322            for (s, &x) in sum.iter_mut().zip(v.iter()) {
323                *s += x;
324            }
325        }
326        let n = vecs.len() as f64;
327        Some(sum.into_iter().map(|x| x / n).collect())
328    }
329
330    // ------------------------------------------------------------------
331    // Named-entity-aware span extraction
332    // ------------------------------------------------------------------
333
334    /// Extract the best sub-span from a sentence for the given question type.
335    ///
336    /// Applies lightweight regex heuristics to prefer date/time/person/location
337    /// phrases when the question type suggests one.
338    fn extract_best_span(
339        sentence: &SentenceRecord,
340        q_type: &QuestionType,
341        doc_text: &str,
342    ) -> Option<(usize, usize, f64)> {
343        // Patterns: (regex_literal, applies_to_question_types, bonus)
344        let patterns: &[(&str, &[QuestionType], f64)] = &[
345            // Dates / years
346            (
347                r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b",
348                &[QuestionType::When],
349                0.3,
350            ),
351            (
352                r"\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b",
353                &[QuestionType::When],
354                0.3,
355            ),
356            (r"\b\d{4}\b", &[QuestionType::When], 0.15),
357            // Capitalised phrases (likely NEs)
358            (
359                r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b",
360                &[QuestionType::Who, QuestionType::Where, QuestionType::What],
361                0.2,
362            ),
363            // Location indicators
364            (
365                r"\bin\s+[A-Z][a-zA-Z]+(?:,\s*[A-Z][a-zA-Z]+)*\b",
366                &[QuestionType::Where],
367                0.25,
368            ),
369        ];
370
371        // Find the sentence's absolute offset in the original doc.
372        let sent_start_in_doc = sentence.start;
373        let sent_end_in_doc = sent_start_in_doc + sentence.text.len();
374        // Clamp to doc bounds to be safe.
375        let sent_end_in_doc = sent_end_in_doc.min(doc_text.len());
376
377        let mut best: Option<(usize, usize, f64)> = None;
378
379        for (pattern_str, qtypes, bonus) in patterns {
380            let applies = qtypes.iter().any(|qt| qt == q_type);
381            if !applies && *q_type != QuestionType::Unknown && *q_type != QuestionType::What {
382                continue;
383            }
384
385            // We use a manual match rather than the regex crate to keep
386            // compile times down – but we DO use the regex crate which is
387            // already a dependency.  Build lazily.
388            if let Ok(re) = regex::Regex::new(pattern_str) {
389                for m in re.find_iter(&sentence.text) {
390                    let abs_start = sent_start_in_doc + m.start();
391                    let abs_end = sent_start_in_doc + m.end();
392                    // Check bounds
393                    if abs_end > sent_end_in_doc {
394                        continue;
395                    }
396                    let score = 0.5 + bonus;
397                    if best.is_none_or(|(_, _, s)| score > s) {
398                        best = Some((abs_start, abs_end, score));
399                    }
400                }
401            }
402        }
403
404        best
405    }
406
407    // ------------------------------------------------------------------
408    // Core QA routine
409    // ------------------------------------------------------------------
410
411    /// Find the single best answer span for `question` using the given method.
412    pub fn find_answer_span(&self, question: &str, method: QAMethod) -> Result<Option<AnswerSpan>> {
413        if self.sentences.is_empty() {
414            return Ok(None);
415        }
416
417        let q_tokens = simple_tokenize(question);
418        if q_tokens.is_empty() {
419            return Err(TextError::InvalidInput(
420                "Question must not be empty".to_string(),
421            ));
422        }
423
424        let q_type = classify_question(question);
425        let stops = stop_words();
426
427        // Score every sentence.
428        let mut scored: Vec<(usize, f64)> = self
429            .sentences
430            .iter()
431            .enumerate()
432            .map(|(i, sent)| {
433                let base = match &method {
434                    QAMethod::TfIdf => {
435                        let idf = self.build_idf();
436                        Self::tfidf_score(&q_tokens, sent, &idf, &stops)
437                    }
438                    QAMethod::BigramOverlap => Self::bigram_overlap_score(&q_tokens, sent),
439                    QAMethod::WordEmbeddingMatch => {
440                        if let Some(emb) = &self.embeddings {
441                            Self::embedding_score(&q_tokens, sent, emb)
442                        } else {
443                            // Fallback to bigram overlap if no embeddings.
444                            Self::bigram_overlap_score(&q_tokens, sent)
445                        }
446                    }
447                };
448                (i, base)
449            })
450            .collect();
451
452        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
453
454        let (best_idx, base_score) = scored[0];
455        if base_score <= 0.0 {
456            return Ok(None);
457        }
458
459        let best_sent = &self.sentences[best_idx];
460
461        // Try to extract a named-entity-aware sub-span.
462        let span = Self::extract_best_span(best_sent, &q_type, &self.text);
463
464        let answer = if let Some((start, end, ne_bonus)) = span {
465            if start < end && end <= self.text.len() {
466                AnswerSpan {
467                    start,
468                    end,
469                    text: self.text[start..end].to_string(),
470                    confidence: (base_score + ne_bonus).min(1.0),
471                    sentence_index: best_idx,
472                }
473            } else {
474                // Fall back to full sentence
475                let start = best_sent.start;
476                let end = (best_sent.start + best_sent.text.len()).min(self.text.len());
477                AnswerSpan {
478                    start,
479                    end,
480                    text: best_sent.text.clone(),
481                    confidence: base_score,
482                    sentence_index: best_idx,
483                }
484            }
485        } else {
486            // Return the whole sentence as the answer span.
487            let start = best_sent.start;
488            let end = (best_sent.start + best_sent.text.len()).min(self.text.len());
489            AnswerSpan {
490                start,
491                end,
492                text: best_sent.text.clone(),
493                confidence: base_score,
494                sentence_index: best_idx,
495            }
496        };
497
498        Ok(Some(answer))
499    }
500
501    // ------------------------------------------------------------------
502    // Multi-answer extraction
503    // ------------------------------------------------------------------
504
505    /// Rank all sentences and return the top-`k` answer spans.
506    pub fn find_top_k(
507        &self,
508        question: &str,
509        method: QAMethod,
510        k: usize,
511    ) -> Result<Vec<AnswerSpan>> {
512        if self.sentences.is_empty() || k == 0 {
513            return Ok(Vec::new());
514        }
515
516        let q_tokens = simple_tokenize(question);
517        if q_tokens.is_empty() {
518            return Err(TextError::InvalidInput(
519                "Question must not be empty".to_string(),
520            ));
521        }
522
523        let q_type = classify_question(question);
524        let stops = stop_words();
525        let idf = self.build_idf();
526
527        let mut scored: Vec<(usize, f64)> = self
528            .sentences
529            .iter()
530            .enumerate()
531            .map(|(i, sent)| {
532                let base = match &method {
533                    QAMethod::TfIdf => Self::tfidf_score(&q_tokens, sent, &idf, &stops),
534                    QAMethod::BigramOverlap => Self::bigram_overlap_score(&q_tokens, sent),
535                    QAMethod::WordEmbeddingMatch => {
536                        if let Some(emb) = &self.embeddings {
537                            Self::embedding_score(&q_tokens, sent, emb)
538                        } else {
539                            Self::bigram_overlap_score(&q_tokens, sent)
540                        }
541                    }
542                };
543                (i, base)
544            })
545            .collect();
546
547        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
548
549        let mut answers = Vec::new();
550        for (idx, base_score) in scored.into_iter().take(k) {
551            if base_score <= 0.0 {
552                break;
553            }
554            let sent = &self.sentences[idx];
555            let span = Self::extract_best_span(sent, &q_type, &self.text);
556
557            let answer = if let Some((start, end, bonus)) = span {
558                if start < end && end <= self.text.len() {
559                    AnswerSpan {
560                        start,
561                        end,
562                        text: self.text[start..end].to_string(),
563                        confidence: (base_score + bonus).min(1.0),
564                        sentence_index: idx,
565                    }
566                } else {
567                    let s = sent.start;
568                    let e = (sent.start + sent.text.len()).min(self.text.len());
569                    AnswerSpan {
570                        start: s,
571                        end: e,
572                        text: sent.text.clone(),
573                        confidence: base_score,
574                        sentence_index: idx,
575                    }
576                }
577            } else {
578                let s = sent.start;
579                let e = (sent.start + sent.text.len()).min(self.text.len());
580                AnswerSpan {
581                    start: s,
582                    end: e,
583                    text: sent.text.clone(),
584                    confidence: base_score,
585                    sentence_index: idx,
586                }
587            };
588
589            answers.push(answer);
590        }
591
592        Ok(answers)
593    }
594}
595
596// ---------------------------------------------------------------------------
597// Free-standing helpers
598// ---------------------------------------------------------------------------
599
600/// Rank sentences in `context_sentences` by TF-IDF similarity to `query_tokens`.
601///
602/// Returns a vector of scores parallel to `context_sentences`.
603pub fn tf_idf_similarity(query_tokens: &[String], context_sentences: &[Vec<String>]) -> Vec<f64> {
604    if context_sentences.is_empty() || query_tokens.is_empty() {
605        return vec![0.0; context_sentences.len()];
606    }
607
608    let n = context_sentences.len() as f64;
609    let stops = stop_words();
610
611    // Build IDF
612    let mut df: HashMap<String, usize> = HashMap::new();
613    for sent in context_sentences {
614        let unique: HashSet<&String> = sent.iter().collect();
615        for tok in unique {
616            *df.entry(tok.clone()).or_insert(0) += 1;
617        }
618    }
619    let idf: HashMap<String, f64> = df
620        .into_iter()
621        .map(|(t, d)| (t, (1.0 + n / (1.0 + d as f64)).ln()))
622        .collect();
623
624    context_sentences
625        .iter()
626        .map(|sent| {
627            let record = SentenceRecord {
628                text: sent.join(" "),
629                start: 0,
630                tokens: sent.clone(),
631            };
632            QAContext::tfidf_score(query_tokens, &record, &idf, &stops)
633        })
634        .collect()
635}
636
637/// Convenience function: extract up to `top_k` answers from `document` for
638/// `question` using TF-IDF by default.
639pub fn extract_answer(question: &str, document: &str, top_k: usize) -> Vec<AnswerSpan> {
640    let ctx = QAContext::new(document);
641    ctx.find_top_k(question, QAMethod::TfIdf, top_k)
642        .unwrap_or_default()
643}
644
/// Cosine similarity between two equal-length float vectors.
///
/// Returns 0.0 when the lengths differ, the inputs are empty, or either vector
/// has zero (or NaN) magnitude.
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    // Mismatched or empty inputs have no meaningful angle.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    if !(norm_a > 0.0 && norm_b > 0.0) {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    dot / (norm_a * norm_b)
}
659
660// ---------------------------------------------------------------------------
661// Tests
662// ---------------------------------------------------------------------------
663
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared five-sentence biography used as the context document.
    const DOC: &str = "Marie Curie was born in Warsaw on 7 November 1867. \
         She conducted pioneering research on radioactivity. \
         In 1903 she won the Nobel Prize in Physics. \
         She also won the Nobel Prize in Chemistry in 1911. \
         Paris became her home after she moved from Poland.";

    #[test]
    fn test_classify_question() {
        let cases = [
            ("Who discovered radium?", QuestionType::Who),
            ("When was she born?", QuestionType::When),
            ("Where did she live?", QuestionType::Where),
            ("How did she win?", QuestionType::How),
            ("What is radioactivity?", QuestionType::What),
            ("Why is science important?", QuestionType::Why),
        ];
        for (question, expected) in cases {
            assert_eq!(classify_question(question), expected, "misclassified {question:?}");
        }
    }

    #[test]
    fn test_extract_answer_tfidf() {
        let answers = extract_answer("When was Marie Curie born?", DOC, 3);
        let top = answers.first().expect("expected at least one answer");
        // The birth-date sentence (or a date inside it) should rank first.
        let lower = top.text.to_lowercase();
        let mentions_birth =
            lower.contains("born") || top.text.contains("1867") || lower.contains("november");
        assert!(mentions_birth, "unexpected top answer: {:?}", top.text);
        assert!(top.confidence > 0.0);
    }

    #[test]
    fn test_find_answer_span_bigram() {
        let ctx = QAContext::new(DOC);
        let span = ctx
            .find_answer_span("What prize did she win in Physics?", QAMethod::BigramOverlap)
            .expect("QA failed")
            .expect("should have a span");
        let lower = span.text.to_lowercase();
        assert!(lower.contains("physics") || span.text.contains("1903") || lower.contains("prize"));
    }

    #[test]
    fn test_find_top_k() {
        let answers = QAContext::new(DOC)
            .find_top_k("Nobel Prize", QAMethod::TfIdf, 2)
            .expect("top-k failed");
        assert!(answers.len() <= 2);
    }

    #[test]
    fn test_embedding_fallback_without_embeddings() {
        // No embeddings attached: WordEmbeddingMatch must fall back to bigram
        // overlap and still succeed without panicking.
        let outcome = QAContext::new(DOC)
            .find_answer_span("Where did she live?", QAMethod::WordEmbeddingMatch);
        if let Err(e) = outcome {
            panic!("QA failed: {e:?}");
        }
    }

    #[test]
    fn test_tf_idf_similarity_standalone() {
        let query = simple_tokenize("Nobel Prize winner");
        let sentences: Vec<Vec<String>> = [
            "She won the Nobel Prize in Physics",
            "Marie Curie was born in Warsaw",
            "Nobel Prize in Chemistry was awarded",
        ]
        .iter()
        .map(|s| simple_tokenize(s))
        .collect();
        let scores = tf_idf_similarity(&query, &sentences);
        assert_eq!(scores.len(), 3);
        // Nobel-prize sentences should beat the birth sentence.
        assert!(scores[0] > scores[1] || scores[2] > scores[1]);
    }

    #[test]
    fn test_answer_span_bounds() {
        let answers = QAContext::new(DOC)
            .find_top_k("Marie Curie radioactivity", QAMethod::TfIdf, 5)
            .expect("failed");
        for ans in &answers {
            // Offsets must be valid byte offsets into DOC and agree with `text`.
            assert!(ans.start <= ans.end);
            assert!(ans.end <= DOC.len());
            assert_eq!(ans.text, DOC[ans.start..ans.end]);
        }
    }

    #[test]
    fn test_empty_document() {
        let answer = QAContext::new("")
            .find_answer_span("Who is here?", QAMethod::TfIdf)
            .expect("QA failed");
        assert!(answer.is_none());
    }
}