scirs2_text/spelling/
error_model.rs

//! Error model for spelling correction using the noisy channel approach
//!
//! This module implements an error model for the noisy channel approach to spelling
//! correction. It models how a correct word can be transformed into a misspelling
//! through edit operations: insertion, deletion, substitution, and transposition.
//!
//! # Key Components
//!
//! - `ErrorModel`: Models the probability of different types of spelling errors
//! - `EditOp`: Represents a single edit operation (insertion, deletion, substitution, or transposition)
//!
//! # Example
//!
//! ```
//! use scirs2_text::spelling::ErrorModel;
//!
//! # fn main() {
//! // Create a default error model
//! let error_model = ErrorModel::default();
//!
//! // Calculate error probability (typo → correct)
//! let p1 = error_model.error_probability("recieve", "receive");
//! let p2 = error_model.error_probability("teh", "the");
//!
//! // Single-edit misspellings receive a non-zero probability
//! assert!(p1 > 0.0);
//! assert!(p2 > 0.0);
//!
//! // Identical words have probability 1.0
//! assert_eq!(error_model.error_probability("word", "word"), 1.0);
//! # }
//! ```
33
34use std::cmp::min;
35use std::collections::HashMap;
36
/// A single character-level edit relating a correct word to a typo.
///
/// As produced by [`ErrorModel::min_edit_operations`], deleted and
/// substituted-from characters come from the correct word, while inserted and
/// substituted-to characters come from the typo. That method also returns
/// `Substitute('?', '?')` as a placeholder when the two words' lengths differ
/// by more than the model's maximum edit distance.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EditOp {
    /// A character present in the correct word is missing from the typo.
    Delete(char),
    /// The typo contains an extra character.
    Insert(char),
    /// The first character (correct word) was replaced by the second (typo).
    Substitute(char, char),
    /// Two adjacent characters, given in the correct word's order, were swapped.
    Transpose(char, char),
}
49
/// Noisy-channel error model: one probability per kind of edit plus the
/// edit-distance cut-off used when relating a typo to a candidate correction.
#[derive(Clone, Debug)]
pub struct ErrorModel {
    /// Probability assigned to one deleted character.
    pub p_deletion: f64,
    /// Probability assigned to one inserted character.
    pub p_insertion: f64,
    /// Probability assigned to one substituted character.
    pub p_substitution: f64,
    /// Probability assigned to one swap of adjacent characters.
    pub p_transposition: f64,
    /// Per-character-pair confusion weights. Constructors in this module only
    /// initialise it to an empty map; nothing in this module reads it yet.
    _char_confusion: HashMap<(char, char), f64>,
    /// Largest edit distance the model will consider.
    max_edit_distance: usize,
}
66
67impl Default for ErrorModel {
68    fn default() -> Self {
69        Self {
70            p_deletion: 0.25,
71            p_insertion: 0.25,
72            p_substitution: 0.25,
73            p_transposition: 0.25,
74            _char_confusion: HashMap::new(),
75            max_edit_distance: 2, // Default max distance
76        }
77    }
78}
79
80impl ErrorModel {
81    /// Create a new error model with custom error probabilities
82    pub fn new(
83        p_deletion: f64,
84        p_insertion: f64,
85        p_substitution: f64,
86        p_transposition: f64,
87    ) -> Self {
88        // Normalize probabilities to sum to 1.0
89        let total = p_deletion + p_insertion + p_substitution + p_transposition;
90        Self {
91            p_deletion: p_deletion / total,
92            p_insertion: p_insertion / total,
93            p_substitution: p_substitution / total,
94            p_transposition: p_transposition / total,
95            _char_confusion: HashMap::new(),
96            max_edit_distance: 2,
97        }
98    }
99
100    /// Set the maximum edit distance to consider
101    pub fn with_max_distance(mut self, maxdistance: usize) -> Self {
102        self.max_edit_distance = maxdistance;
103        self
104    }
105
106    /// Calculate the error probability P(typo | correct)
107    pub fn error_probability(&self, typo: &str, correct: &str) -> f64 {
108        // Special case: identical words
109        if typo == correct {
110            return 1.0;
111        }
112
113        // Simple edit distance-based probability
114        let edit_distance = self.min_edit_operations(typo, correct);
115
116        match edit_distance.len() {
117            0 => 1.0, // No edits needed
118            1 => {
119                // Single edit
120                match edit_distance[0] {
121                    EditOp::Delete(_) => self.p_deletion,
122                    EditOp::Insert(_) => self.p_insertion,
123                    EditOp::Substitute(_, _) => self.p_substitution,
124                    EditOp::Transpose(_, _) => self.p_transposition,
125                }
126            }
127            n => {
128                // Multiple edits - calculate product of probabilities, with decay
129                let base_prob = 0.1f64.powi(n as i32 - 1);
130                let mut prob = base_prob;
131
132                for op in &edit_distance {
133                    match op {
134                        EditOp::Delete(_) => prob *= self.p_deletion,
135                        EditOp::Insert(_) => prob *= self.p_insertion,
136                        EditOp::Substitute(_, _) => prob *= self.p_substitution,
137                        EditOp::Transpose(_, _) => prob *= self.p_transposition,
138                    }
139                }
140
141                prob
142            }
143        }
144    }
145
146    /// Find the minimum edit operations to transform correct into typo
147    pub fn min_edit_operations(&self, typo: &str, correct: &str) -> Vec<EditOp> {
148        let typo_chars: Vec<char> = typo.chars().collect();
149        let correct_chars: Vec<char> = correct.chars().collect();
150
151        // Special case: identical strings
152        if typo == correct {
153            return vec![];
154        }
155
156        // Early return for length difference exceeding max edit distance
157        if (typo_chars.len() as isize - correct_chars.len() as isize).abs()
158            > self.max_edit_distance as isize
159        {
160            // Just return a placeholder operation since this exceeds our threshold
161            return vec![EditOp::Substitute('?', '?')];
162        }
163
164        // Try to detect the type of error
165        if correct_chars.len() == typo_chars.len() + 1 {
166            // Possible deletion
167            for i in 0..correct_chars.len() {
168                let mut test_chars = correct_chars.clone();
169                test_chars.remove(i);
170                if test_chars == typo_chars {
171                    return vec![EditOp::Delete(correct_chars[i])];
172                }
173            }
174        } else if correct_chars.len() + 1 == typo_chars.len() {
175            // Possible insertion
176            for i in 0..typo_chars.len() {
177                let mut test_chars = typo_chars.clone();
178                test_chars.remove(i);
179                if test_chars == correct_chars {
180                    return vec![EditOp::Insert(typo_chars[i])];
181                }
182            }
183        } else if correct_chars.len() == typo_chars.len() {
184            // Possible substitution or transposition
185            let mut diff_positions = Vec::new();
186
187            for i in 0..correct_chars.len() {
188                if correct_chars[i] != typo_chars[i] {
189                    diff_positions.push(i);
190                }
191            }
192
193            if diff_positions.len() == 1 {
194                // Single substitution
195                let i = diff_positions[0];
196                return vec![EditOp::Substitute(correct_chars[i], typo_chars[i])];
197            } else if diff_positions.len() == 2 && diff_positions[0] + 1 == diff_positions[1] {
198                let i = diff_positions[0];
199
200                // Check if it's a transposition
201                if correct_chars[i] == typo_chars[i + 1] && correct_chars[i + 1] == typo_chars[i] {
202                    return vec![EditOp::Transpose(correct_chars[i], correct_chars[i + 1])];
203                }
204            }
205        }
206
207        // Fallback: use Levenshtein to determine general edit distance with a more efficient algorithm
208        let mut operations = Vec::new();
209        let _distance = self.levenshtein_with_ops_efficient(correct, typo, &mut operations);
210        operations
211    }
212
213    /// Efficient implementation of Levenshtein distance with operations tracking
214    /// that uses only two rows of memory and implements early termination
215    fn levenshtein_with_ops_efficient(
216        &self,
217        s1: &str,
218        s2: &str,
219        operations: &mut Vec<EditOp>,
220    ) -> usize {
221        let chars1: Vec<char> = s1.chars().collect();
222        let chars2: Vec<char> = s2.chars().collect();
223        let len1 = chars1.len();
224        let len2 = chars2.len();
225
226        // Early return for exact match
227        if s1 == s2 {
228            return 0;
229        }
230
231        // Check if the difference in length exceeds maximum edit distance
232        if (len1 as isize - len2 as isize).abs() > self.max_edit_distance as isize {
233            return self.max_edit_distance + 1; // Exceed threshold
234        }
235
236        // Create a compact representation of the matrix using two rows
237        let mut prev_row = (0..=len2).collect::<Vec<_>>();
238        let mut curr_row = vec![0; len2 + 1];
239
240        // Use a separate matrix to track operations
241        // 0 = match/no op, 1 = insertion, 2 = deletion, 3 = substitution, 4 = transposition
242        let mut op_matrix = vec![vec![0; len2 + 1]; len1 + 1];
243
244        // Initialize first row (all insertions)
245        for j in 1..=len2 {
246            op_matrix[0][j] = 1; // Insertion
247        }
248
249        for i in 1..=len1 {
250            curr_row[0] = i;
251            op_matrix[i][0] = 2; // Deletion
252
253            for j in 1..=len2 {
254                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
255
256                // Calculate the costs of different operations
257                let del_cost = prev_row[j] + 1;
258                let ins_cost = curr_row[j - 1] + 1;
259                let sub_cost = prev_row[j - 1] + cost;
260
261                // Find minimum cost operation
262                curr_row[j] = min(min(del_cost, ins_cost), sub_cost);
263
264                // Track the operation
265                if curr_row[j] == del_cost {
266                    op_matrix[i][j] = 2; // Deletion
267                } else if curr_row[j] == ins_cost {
268                    op_matrix[i][j] = 1; // Insertion
269                } else if cost > 0 {
270                    op_matrix[i][j] = 3; // Substitution
271                } else {
272                    op_matrix[i][j] = 0; // Match
273                }
274
275                // Check for transposition
276                if i > 1
277                    && j > 1
278                    && chars1[i - 1] == chars2[j - 2]
279                    && chars1[i - 2] == chars2[j - 1]
280                {
281                    let trans_cost = prev_row[j - 2] + 1;
282                    if trans_cost < curr_row[j] {
283                        curr_row[j] = trans_cost;
284                        op_matrix[i][j] = 4; // Transposition
285                    }
286                }
287            }
288
289            // Early termination - if all values exceed max_edit_distance, stop
290            if curr_row.iter().all(|&c| c > self.max_edit_distance) {
291                return self.max_edit_distance + 1;
292            }
293
294            // Swap rows for next iteration
295            std::mem::swap(&mut prev_row, &mut curr_row);
296        }
297
298        // Backtrack to build the edit operations
299        let mut i = len1;
300        let mut j = len2;
301        let mut backtrack_ops = Vec::new();
302
303        while i > 0 || j > 0 {
304            match if i == 0 || j == 0 {
305                if i == 0 {
306                    1
307                } else {
308                    2
309                } // Special case for first row/column
310            } else {
311                op_matrix[i][j]
312            } {
313                0 => {
314                    // Match - no operation
315                    i -= 1;
316                    j -= 1;
317                }
318                1 => {
319                    // Insertion
320                    j -= 1;
321                    backtrack_ops.push(EditOp::Insert(chars2[j]));
322                }
323                2 => {
324                    // Deletion
325                    i -= 1;
326                    backtrack_ops.push(EditOp::Delete(chars1[i]));
327                }
328                3 => {
329                    // Substitution
330                    i -= 1;
331                    j -= 1;
332                    backtrack_ops.push(EditOp::Substitute(chars1[i], chars2[j]));
333                }
334                4 => {
335                    // Transposition
336                    i -= 2;
337                    j -= 2;
338                    backtrack_ops.push(EditOp::Transpose(chars1[i + 1], chars1[i + 2]));
339                }
340                _ => break, // Should not happen
341            }
342        }
343
344        // Reverse operations to get correct order
345        backtrack_ops.reverse();
346        operations.extend(backtrack_ops);
347
348        // Return the edit distance (final value in prev_row due to the swap)
349        prev_row[len2]
350    }
351
352    /// Legacy implementation of Levenshtein distance with operations tracking
353    pub fn levenshtein_with_ops(&self, s1: &str, s2: &str, operations: &mut Vec<EditOp>) -> usize {
354        let chars1: Vec<char> = s1.chars().collect();
355        let chars2: Vec<char> = s2.chars().collect();
356        let len1 = chars1.len();
357        let len2 = chars2.len();
358
359        // Create distance matrix
360        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
361
362        // Initialize first row and column
363        for (i, row) in matrix.iter_mut().enumerate().take(len1 + 1) {
364            row[0] = i;
365        }
366
367        for j in 0..=len2 {
368            matrix[0][j] = j;
369        }
370
371        // Fill matrix and track operations
372        for i in 1..=len1 {
373            for j in 1..=len2 {
374                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
375
376                matrix[i][j] = min(
377                    min(
378                        matrix[i - 1][j] + 1, // deletion
379                        matrix[i][j - 1] + 1, // insertion
380                    ),
381                    matrix[i - 1][j - 1] + cost, // substitution
382                );
383
384                // Check for transposition (if possible)
385                if i > 1
386                    && j > 1
387                    && chars1[i - 1] == chars2[j - 2]
388                    && chars1[i - 2] == chars2[j - 1]
389                {
390                    matrix[i][j] = min(
391                        matrix[i][j],
392                        matrix[i - 2][j - 2] + 1, // transposition
393                    );
394                }
395            }
396        }
397
398        // Backtrack to find operations
399        let mut i = len1;
400        let mut j = len2;
401
402        // Use a temporary vector to store operations in correct order
403        let mut temp_ops = Vec::new();
404
405        while i > 0 || j > 0 {
406            if i > 0 && j > 0 && chars1[i - 1] == chars2[j - 1] {
407                // No operation (match)
408                i -= 1;
409                j -= 1;
410            } else if i > 1
411                && j > 1
412                && chars1[i - 1] == chars2[j - 2]
413                && chars1[i - 2] == chars2[j - 1]
414                && matrix[i][j] == matrix[i - 2][j - 2] + 1
415            {
416                // Transposition
417                temp_ops.push(EditOp::Transpose(chars1[i - 2], chars1[i - 1]));
418                i -= 2;
419                j -= 2;
420            } else if i > 0 && j > 0 && matrix[i][j] == matrix[i - 1][j - 1] + 1 {
421                // Substitution
422                temp_ops.push(EditOp::Substitute(chars1[i - 1], chars2[j - 1]));
423                i -= 1;
424                j -= 1;
425            } else if i > 0 && matrix[i][j] == matrix[i - 1][j] + 1 {
426                // Deletion
427                temp_ops.push(EditOp::Delete(chars1[i - 1]));
428                i -= 1;
429            } else if j > 0 && matrix[i][j] == matrix[i][j - 1] + 1 {
430                // Insertion
431                temp_ops.push(EditOp::Insert(chars2[j - 1]));
432                j -= 1;
433            } else {
434                // Should not reach here
435                break;
436            }
437        }
438
439        // Reverse operations to get correct order
440        temp_ops.reverse();
441        operations.extend(temp_ops);
442
443        matrix[len1][len2]
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_model() {
        let model = ErrorModel::default();

        // (typo, correct) pairs covering each single-edit category.
        let single_edit_pairs = [
            ("cat", "cart"),  // 'r' deleted
            ("cart", "cat"),  // 'r' inserted
            ("cat", "cut"),   // 'u' -> 'a'
            ("form", "from"), // 'ro' -> 'or'
        ];
        for (typo, correct) in single_edit_pairs {
            assert!(model.error_probability(typo, correct) > 0.0);
        }

        // Identical words are certain.
        assert_eq!(model.error_probability("word", "word"), 1.0);
    }

    #[test]
    fn test_edit_operations() {
        let model = ErrorModel::default();

        // (typo, correct, the single expected operation)
        let cases = [
            ("cat", "cart", EditOp::Delete('r')),
            ("cart", "cat", EditOp::Insert('r')),
            ("cut", "cat", EditOp::Substitute('a', 'u')),
            ("from", "form", EditOp::Transpose('o', 'r')),
        ];
        for (typo, correct, expected) in cases {
            assert_eq!(model.min_edit_operations(typo, correct), vec![expected]);
        }
    }

    #[test]
    fn test_efficient_levenshtein() {
        let model = ErrorModel::default();

        // Run both implementations on one pair: ((legacy dist, ops), (efficient dist, ops)).
        let run_both = |a: &str, b: &str| {
            let mut legacy_ops = Vec::new();
            let mut fast_ops = Vec::new();
            let legacy = model.levenshtein_with_ops(a, b, &mut legacy_ops);
            let fast = model.levenshtein_with_ops_efficient(a, b, &mut fast_ops);
            ((legacy, legacy_ops), (fast, fast_ops))
        };

        // Identical strings: zero distance, no operations.
        let ((legacy, legacy_ops), (fast, fast_ops)) = run_both("hello", "hello");
        assert_eq!(legacy, 0);
        assert_eq!(fast, 0);
        assert!(legacy_ops.is_empty());
        assert!(fast_ops.is_empty());

        // Single substitution / insertion / deletion: distance 1 in both.
        for (a, b) in [("cat", "bat"), ("cat", "cats"), ("cats", "cat")] {
            let ((legacy, _), (fast, _)) = run_both(a, b);
            assert_eq!(legacy, 1);
            assert_eq!(fast, 1);
        }

        // Transposition: only require a small number of operations from each.
        let ((_, legacy_ops), (_, fast_ops)) = run_both("abc", "acb");
        assert!(legacy_ops.len() <= 2);
        assert!(fast_ops.len() <= 2);

        // Longer strings: both stay within 3 edits.
        let ((legacy, _), (fast, _)) = run_both("programming", "programmer");
        assert!(legacy <= 3);
        assert!(fast <= 3);
    }

    #[test]
    fn test_early_termination() {
        let tight = ErrorModel::default().with_max_distance(1);

        // "cat" and "dog" are more than one edit apart: an empty result or a
        // placeholder substitution are both acceptable.
        let ops = tight.min_edit_operations("cat", "dog");
        if let Some(first) = ops.first() {
            assert!(matches!(first, EditOp::Substitute(_, _)) || ops.len() > 1);
        }

        let loose = ErrorModel::default().with_max_distance(3);

        // Three edits apart: within the limit, so operations are produced.
        assert!(!loose.min_edit_operations("kitten", "sitting").is_empty());

        // Either a placeholder substitution or a non-empty list of operations.
        let ops = loose.min_edit_operations("algorithm", "logarithm");
        match ops.as_slice() {
            [single] => assert!(matches!(single, EditOp::Substitute(_, _))),
            other => assert!(!other.is_empty()),
        }
    }
}