Skip to main content

tensorlogic_sklears_kernels/
string_kernel.rs

1//! String kernels for text similarity.
2//!
3//! These kernels measure similarity between text sequences using
4//! substring matching, n-grams, and subsequence features.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13/// N-gram string kernel configuration
14#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
15pub struct NGramKernelConfig {
16    /// N-gram size
17    pub n: usize,
18    /// Whether to normalize by string length
19    pub normalize: bool,
20}
21
22impl NGramKernelConfig {
23    /// Create configuration with n-gram size
24    pub fn new(n: usize) -> Result<Self> {
25        if n == 0 {
26            return Err(KernelError::InvalidParameter {
27                parameter: "n".to_string(),
28                value: n.to_string(),
29                reason: "n-gram size must be positive".to_string(),
30            });
31        }
32
33        Ok(Self { n, normalize: true })
34    }
35
36    /// Set normalization flag
37    pub fn with_normalize(mut self, normalize: bool) -> Self {
38        self.normalize = normalize;
39        self
40    }
41}
42
43/// N-gram string kernel
44///
45/// Measures similarity by counting common n-grams.
46///
47/// # Example
48///
49/// ```rust
50/// use tensorlogic_sklears_kernels::{NGramKernel, NGramKernelConfig};
51///
52/// let config = NGramKernelConfig::new(2).unwrap(); // bigrams
53/// let kernel = NGramKernel::new(config);
54///
55/// let text1 = "hello world";
56/// let text2 = "hello there";
57///
58/// let sim = kernel.compute_strings(text1, text2).unwrap();
59/// println!("Similarity: {}", sim);
60/// ```
61pub struct NGramKernel {
62    config: NGramKernelConfig,
63}
64
65impl NGramKernel {
66    /// Create a new n-gram kernel
67    pub fn new(config: NGramKernelConfig) -> Self {
68        Self { config }
69    }
70
71    /// Extract n-grams from text
72    fn extract_ngrams(&self, text: &str) -> HashMap<String, usize> {
73        let mut ngrams = HashMap::new();
74        let chars: Vec<char> = text.chars().collect();
75
76        if chars.len() < self.config.n {
77            return ngrams;
78        }
79
80        for i in 0..=(chars.len() - self.config.n) {
81            let ngram: String = chars[i..i + self.config.n].iter().collect();
82            *ngrams.entry(ngram).or_insert(0) += 1;
83        }
84
85        ngrams
86    }
87
88    /// Compute similarity between two text strings
89    pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
90        let ngrams1 = self.extract_ngrams(text1);
91        let ngrams2 = self.extract_ngrams(text2);
92
93        // Compute intersection
94        let mut similarity = 0.0;
95        for (ngram, count1) in &ngrams1 {
96            if let Some(count2) = ngrams2.get(ngram) {
97                similarity += (*count1).min(*count2) as f64;
98            }
99        }
100
101        if self.config.normalize {
102            let total1: usize = ngrams1.values().sum();
103            let total2: usize = ngrams2.values().sum();
104            let normalizer = ((total1 * total2) as f64).sqrt();
105
106            if normalizer > 0.0 {
107                similarity /= normalizer;
108            }
109        }
110
111        Ok(similarity)
112    }
113}
114
115impl Kernel for NGramKernel {
116    fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
117        // Placeholder - use compute_strings for string data
118        Ok(0.0)
119    }
120
121    fn name(&self) -> &str {
122        "NGram"
123    }
124}
125
/// Subsequence string kernel configuration.
///
/// Both fields are public and the type is deserializable, so the range checks
/// in [`SubsequenceKernelConfig::with_max_length`] and
/// [`SubsequenceKernelConfig::with_decay`] only apply when those builders are
/// used.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SubsequenceKernelConfig {
    /// Maximum subsequence length: lengths `1..=max_length` contribute to the
    /// similarity, further capped by the shorter input's character count.
    pub max_length: usize,
    /// Decay factor for longer subsequences: the length-`k` term is weighted
    /// by `decay^k`.
    pub decay: f64,
}
134
135impl SubsequenceKernelConfig {
136    /// Create default configuration
137    pub fn new() -> Self {
138        Self {
139            max_length: 3,
140            decay: 0.5,
141        }
142    }
143
144    /// Set maximum length
145    pub fn with_max_length(mut self, length: usize) -> Result<Self> {
146        if length == 0 {
147            return Err(KernelError::InvalidParameter {
148                parameter: "max_length".to_string(),
149                value: length.to_string(),
150                reason: "max_length must be positive".to_string(),
151            });
152        }
153        self.max_length = length;
154        Ok(self)
155    }
156
157    /// Set decay factor
158    pub fn with_decay(mut self, decay: f64) -> Result<Self> {
159        if !(0.0..=1.0).contains(&decay) {
160            return Err(KernelError::InvalidParameter {
161                parameter: "decay".to_string(),
162                value: decay.to_string(),
163                reason: "decay must be in [0, 1]".to_string(),
164            });
165        }
166        self.decay = decay;
167        Ok(self)
168    }
169}
170
171impl Default for SubsequenceKernelConfig {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177/// Subsequence string kernel
178///
179/// Measures similarity by counting common non-contiguous subsequences.
180pub struct SubsequenceKernel {
181    config: SubsequenceKernelConfig,
182}
183
184impl SubsequenceKernel {
185    /// Create a new subsequence kernel
186    pub fn new(config: SubsequenceKernelConfig) -> Self {
187        Self { config }
188    }
189
190    /// Compute similarity between two text strings
191    pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
192        let chars1: Vec<char> = text1.chars().collect();
193        let chars2: Vec<char> = text2.chars().collect();
194
195        let mut similarity = 0.0;
196
197        // Use dynamic programming to count common subsequences
198        for length in 1..=self.config.max_length.min(chars1.len()).min(chars2.len()) {
199            let count = self.count_common_subsequences(&chars1, &chars2, length);
200            similarity += count as f64 * self.config.decay.powi(length as i32);
201        }
202
203        Ok(similarity)
204    }
205
206    /// Count common subsequences of given length
207    fn count_common_subsequences(&self, s1: &[char], s2: &[char], length: usize) -> usize {
208        if length > s1.len() || length > s2.len() {
209            return 0;
210        }
211
212        // Simplified counting - exact match of subsequences
213        let subseqs1 = self.extract_subsequences(s1, length);
214        let subseqs2 = self.extract_subsequences(s2, length);
215
216        let mut count = 0;
217        for subseq in &subseqs1 {
218            if subseqs2.contains(subseq) {
219                count += 1;
220            }
221        }
222
223        count
224    }
225
226    /// Extract all subsequences of given length
227    fn extract_subsequences(&self, chars: &[char], length: usize) -> Vec<Vec<char>> {
228        let mut subsequences = Vec::new();
229        self.generate_subsequences(chars, length, 0, Vec::new(), &mut subsequences);
230        subsequences
231    }
232
233    /// Generate subsequences recursively
234    #[allow(clippy::only_used_in_recursion)]
235    fn generate_subsequences(
236        &self,
237        chars: &[char],
238        remaining: usize,
239        start: usize,
240        current: Vec<char>,
241        result: &mut Vec<Vec<char>>,
242    ) {
243        if remaining == 0 {
244            result.push(current);
245            return;
246        }
247
248        for i in start..chars.len() {
249            let mut new_current = current.clone();
250            new_current.push(chars[i]);
251            self.generate_subsequences(chars, remaining - 1, i + 1, new_current, result);
252        }
253    }
254}
255
256impl Kernel for SubsequenceKernel {
257    fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
258        // Placeholder - use compute_strings for string data
259        Ok(0.0)
260    }
261
262    fn name(&self) -> &str {
263        "Subsequence"
264    }
265}
266
/// Edit distance kernel (exponential of negative edit distance).
///
/// K(s1, s2) = exp(-gamma * edit_distance(s1, s2))
///
/// `edit_distance` is the Levenshtein distance counted in `char`s, so
/// identical strings score `1.0` and the score decays towards `0.0` as the
/// strings diverge.
#[derive(Clone, Debug)]
pub struct EditDistanceKernel {
    /// Bandwidth parameter: larger values make the similarity fall off faster
    /// with each edit.
    gamma: f64,
}
274
275impl EditDistanceKernel {
276    /// Create a new edit distance kernel
277    pub fn new(gamma: f64) -> Result<Self> {
278        if gamma <= 0.0 {
279            return Err(KernelError::InvalidParameter {
280                parameter: "gamma".to_string(),
281                value: gamma.to_string(),
282                reason: "gamma must be positive".to_string(),
283            });
284        }
285
286        Ok(Self { gamma })
287    }
288
289    /// Compute Levenshtein edit distance
290    #[allow(clippy::needless_range_loop)]
291    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
292        let chars1: Vec<char> = s1.chars().collect();
293        let chars2: Vec<char> = s2.chars().collect();
294
295        let m = chars1.len();
296        let n = chars2.len();
297
298        let mut dp = vec![vec![0; n + 1]; m + 1];
299
300        // Initialize
301        for i in 0..=m {
302            dp[i][0] = i;
303        }
304        for j in 0..=n {
305            dp[0][j] = j;
306        }
307
308        // Fill DP table
309        for i in 1..=m {
310            for j in 1..=n {
311                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
312
313                dp[i][j] = (dp[i - 1][j] + 1) // deletion
314                    .min(dp[i][j - 1] + 1) // insertion
315                    .min(dp[i - 1][j - 1] + cost); // substitution
316            }
317        }
318
319        dp[m][n]
320    }
321
322    /// Compute similarity between two text strings
323    pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
324        let distance = self.edit_distance(text1, text2);
325        let similarity = (-self.gamma * distance as f64).exp();
326        Ok(similarity)
327    }
328}
329
330impl Kernel for EditDistanceKernel {
331    fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
332        // Placeholder - use compute_strings for string data
333        Ok(0.0)
334    }
335
336    fn name(&self) -> &str {
337        "EditDistance"
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons of kernel values.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_ngram_kernel() {
        let config = NGramKernelConfig::new(2).unwrap();
        let kernel = NGramKernel::new(config);

        // Bigrams he, el, ll, lo vs ha, al, ll, lo: 2 shared out of 4 each.
        let sim = kernel.compute_strings("hello", "hallo").unwrap();
        assert!((sim - 0.5).abs() < EPS);
    }

    #[test]
    fn test_ngram_identical_strings() {
        let config = NGramKernelConfig::new(2).unwrap();
        let kernel = NGramKernel::new(config);

        let text = "test";
        let sim = kernel.compute_strings(text, text).unwrap();

        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_ngram_different_strings() {
        let config = NGramKernelConfig::new(2).unwrap();
        let kernel = NGramKernel::new(config);

        let sim = kernel.compute_strings("abc", "xyz").unwrap();
        assert_eq!(sim, 0.0); // no shared bigrams at all
    }

    #[test]
    fn test_ngram_symmetric() {
        let kernel = NGramKernel::new(NGramKernelConfig::new(2).unwrap());

        let forward = kernel.compute_strings("hello world", "hello there").unwrap();
        let backward = kernel.compute_strings("hello there", "hello world").unwrap();
        assert!((forward - backward).abs() < EPS);
    }

    #[test]
    fn test_ngram_unnormalized_counts() {
        let config = NGramKernelConfig::new(2).unwrap().with_normalize(false);
        let kernel = NGramKernel::new(config);

        // "aaa" holds bigram "aa" twice, "aa" once: the overlap is min(2, 1).
        assert_eq!(kernel.compute_strings("aaa", "aa").unwrap(), 1.0);
    }

    #[test]
    fn test_ngram_shorter_than_n() {
        let kernel = NGramKernel::new(NGramKernelConfig::new(3).unwrap());
        assert_eq!(kernel.compute_strings("ab", "ab").unwrap(), 0.0);
    }

    #[test]
    fn test_ngram_config_invalid_n() {
        let result = NGramKernelConfig::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_subsequence_kernel() {
        let config = SubsequenceKernelConfig::new();
        let kernel = SubsequenceKernel::new(config);

        // Shared: a, c (length 1, weight 0.5) and "ac" (length 2, weight 0.25).
        let sim = kernel.compute_strings("abc", "aec").unwrap();
        assert!((sim - 1.25).abs() < EPS);
    }

    #[test]
    fn test_subsequence_identical() {
        let config = SubsequenceKernelConfig::new();
        let kernel = SubsequenceKernel::new(config);

        // "test": 4 chars, 6 pairs, 4 triples -> 4*0.5 + 6*0.25 + 4*0.125.
        let sim = kernel.compute_strings("test", "test").unwrap();
        assert!((sim - 4.0).abs() < EPS);
    }

    #[test]
    fn test_subsequence_config() {
        let config = SubsequenceKernelConfig::new()
            .with_max_length(5)
            .unwrap()
            .with_decay(0.7)
            .unwrap();

        assert_eq!(config.max_length, 5);
        assert!((config.decay - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_subsequence_default_matches_new() {
        assert_eq!(SubsequenceKernelConfig::default(), SubsequenceKernelConfig::new());
    }

    #[test]
    fn test_subsequence_invalid_config() {
        let result = SubsequenceKernelConfig::new().with_max_length(0);
        assert!(result.is_err());

        let result = SubsequenceKernelConfig::new().with_decay(1.5);
        assert!(result.is_err());

        let result = SubsequenceKernelConfig::new().with_decay(f64::NAN);
        assert!(result.is_err());
    }

    #[test]
    fn test_edit_distance_kernel() {
        let kernel = EditDistanceKernel::new(0.1).unwrap();

        // Levenshtein("kitten", "sitting") = 3.
        let sim = kernel.compute_strings("kitten", "sitting").unwrap();
        assert!((sim - (-0.1f64 * 3.0).exp()).abs() < EPS);
    }

    #[test]
    fn test_edit_distance_identical() {
        let kernel = EditDistanceKernel::new(0.1).unwrap();

        let text = "test";
        let sim = kernel.compute_strings(text, text).unwrap();

        assert!((sim - 1.0).abs() < 1e-10); // exp(-0 * 0.1) = 1.0
    }

    #[test]
    fn test_edit_distance_computation() {
        let kernel = EditDistanceKernel::new(1.0).unwrap();

        assert_eq!(kernel.edit_distance("", ""), 0);
        assert_eq!(kernel.edit_distance("a", ""), 1);
        assert_eq!(kernel.edit_distance("", "a"), 1);
        assert_eq!(kernel.edit_distance("abc", "abc"), 0);
        assert_eq!(kernel.edit_distance("abc", "abd"), 1);
        assert_eq!(kernel.edit_distance("kitten", "sitting"), 3);
        assert_eq!(kernel.edit_distance("sitting", "kitten"), 3);
        assert_eq!(kernel.edit_distance("flaw", "lawn"), 2);
        // Distances are counted in chars, not bytes.
        assert_eq!(kernel.edit_distance("naïve", "naive"), 1);
    }

    #[test]
    fn test_edit_distance_invalid_gamma() {
        let result = EditDistanceKernel::new(-0.1);
        assert!(result.is_err());

        let result = EditDistanceKernel::new(0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_kernel_trait() {
        let kernel = NGramKernel::new(NGramKernelConfig::new(2).unwrap());
        assert_eq!(kernel.name(), "NGram");
        assert_eq!(kernel.compute(&[1.0], &[2.0]).unwrap(), 0.0);

        let kernel = SubsequenceKernel::new(SubsequenceKernelConfig::new());
        assert_eq!(kernel.name(), "Subsequence");
        assert_eq!(kernel.compute(&[1.0], &[2.0]).unwrap(), 0.0);

        let kernel = EditDistanceKernel::new(0.1).unwrap();
        assert_eq!(kernel.name(), "EditDistance");
        assert_eq!(kernel.compute(&[1.0], &[2.0]).unwrap(), 0.0);
    }
}