tensorlogic_sklears_kernels/
string_kernel.rs1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
15pub struct NGramKernelConfig {
16 pub n: usize,
18 pub normalize: bool,
20}
21
22impl NGramKernelConfig {
23 pub fn new(n: usize) -> Result<Self> {
25 if n == 0 {
26 return Err(KernelError::InvalidParameter {
27 parameter: "n".to_string(),
28 value: n.to_string(),
29 reason: "n-gram size must be positive".to_string(),
30 });
31 }
32
33 Ok(Self { n, normalize: true })
34 }
35
36 pub fn with_normalize(mut self, normalize: bool) -> Self {
38 self.normalize = normalize;
39 self
40 }
41}
42
43pub struct NGramKernel {
62 config: NGramKernelConfig,
63}
64
65impl NGramKernel {
66 pub fn new(config: NGramKernelConfig) -> Self {
68 Self { config }
69 }
70
71 fn extract_ngrams(&self, text: &str) -> HashMap<String, usize> {
73 let mut ngrams = HashMap::new();
74 let chars: Vec<char> = text.chars().collect();
75
76 if chars.len() < self.config.n {
77 return ngrams;
78 }
79
80 for i in 0..=(chars.len() - self.config.n) {
81 let ngram: String = chars[i..i + self.config.n].iter().collect();
82 *ngrams.entry(ngram).or_insert(0) += 1;
83 }
84
85 ngrams
86 }
87
88 pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
90 let ngrams1 = self.extract_ngrams(text1);
91 let ngrams2 = self.extract_ngrams(text2);
92
93 let mut similarity = 0.0;
95 for (ngram, count1) in &ngrams1 {
96 if let Some(count2) = ngrams2.get(ngram) {
97 similarity += (*count1).min(*count2) as f64;
98 }
99 }
100
101 if self.config.normalize {
102 let total1: usize = ngrams1.values().sum();
103 let total2: usize = ngrams2.values().sum();
104 let normalizer = ((total1 * total2) as f64).sqrt();
105
106 if normalizer > 0.0 {
107 similarity /= normalizer;
108 }
109 }
110
111 Ok(similarity)
112 }
113}
114
115impl Kernel for NGramKernel {
116 fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
117 Ok(0.0)
119 }
120
121 fn name(&self) -> &str {
122 "NGram"
123 }
124}
125
126#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
128pub struct SubsequenceKernelConfig {
129 pub max_length: usize,
131 pub decay: f64,
133}
134
135impl SubsequenceKernelConfig {
136 pub fn new() -> Self {
138 Self {
139 max_length: 3,
140 decay: 0.5,
141 }
142 }
143
144 pub fn with_max_length(mut self, length: usize) -> Result<Self> {
146 if length == 0 {
147 return Err(KernelError::InvalidParameter {
148 parameter: "max_length".to_string(),
149 value: length.to_string(),
150 reason: "max_length must be positive".to_string(),
151 });
152 }
153 self.max_length = length;
154 Ok(self)
155 }
156
157 pub fn with_decay(mut self, decay: f64) -> Result<Self> {
159 if !(0.0..=1.0).contains(&decay) {
160 return Err(KernelError::InvalidParameter {
161 parameter: "decay".to_string(),
162 value: decay.to_string(),
163 reason: "decay must be in [0, 1]".to_string(),
164 });
165 }
166 self.decay = decay;
167 Ok(self)
168 }
169}
170
171impl Default for SubsequenceKernelConfig {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177pub struct SubsequenceKernel {
181 config: SubsequenceKernelConfig,
182}
183
184impl SubsequenceKernel {
185 pub fn new(config: SubsequenceKernelConfig) -> Self {
187 Self { config }
188 }
189
190 pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
192 let chars1: Vec<char> = text1.chars().collect();
193 let chars2: Vec<char> = text2.chars().collect();
194
195 let mut similarity = 0.0;
196
197 for length in 1..=self.config.max_length.min(chars1.len()).min(chars2.len()) {
199 let count = self.count_common_subsequences(&chars1, &chars2, length);
200 similarity += count as f64 * self.config.decay.powi(length as i32);
201 }
202
203 Ok(similarity)
204 }
205
206 fn count_common_subsequences(&self, s1: &[char], s2: &[char], length: usize) -> usize {
208 if length > s1.len() || length > s2.len() {
209 return 0;
210 }
211
212 let subseqs1 = self.extract_subsequences(s1, length);
214 let subseqs2 = self.extract_subsequences(s2, length);
215
216 let mut count = 0;
217 for subseq in &subseqs1 {
218 if subseqs2.contains(subseq) {
219 count += 1;
220 }
221 }
222
223 count
224 }
225
226 fn extract_subsequences(&self, chars: &[char], length: usize) -> Vec<Vec<char>> {
228 let mut subsequences = Vec::new();
229 self.generate_subsequences(chars, length, 0, Vec::new(), &mut subsequences);
230 subsequences
231 }
232
233 #[allow(clippy::only_used_in_recursion)]
235 fn generate_subsequences(
236 &self,
237 chars: &[char],
238 remaining: usize,
239 start: usize,
240 current: Vec<char>,
241 result: &mut Vec<Vec<char>>,
242 ) {
243 if remaining == 0 {
244 result.push(current);
245 return;
246 }
247
248 for i in start..chars.len() {
249 let mut new_current = current.clone();
250 new_current.push(chars[i]);
251 self.generate_subsequences(chars, remaining - 1, i + 1, new_current, result);
252 }
253 }
254}
255
256impl Kernel for SubsequenceKernel {
257 fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
258 Ok(0.0)
260 }
261
262 fn name(&self) -> &str {
263 "Subsequence"
264 }
265}
266
267pub struct EditDistanceKernel {
271 gamma: f64,
273}
274
275impl EditDistanceKernel {
276 pub fn new(gamma: f64) -> Result<Self> {
278 if gamma <= 0.0 {
279 return Err(KernelError::InvalidParameter {
280 parameter: "gamma".to_string(),
281 value: gamma.to_string(),
282 reason: "gamma must be positive".to_string(),
283 });
284 }
285
286 Ok(Self { gamma })
287 }
288
289 #[allow(clippy::needless_range_loop)]
291 fn edit_distance(&self, s1: &str, s2: &str) -> usize {
292 let chars1: Vec<char> = s1.chars().collect();
293 let chars2: Vec<char> = s2.chars().collect();
294
295 let m = chars1.len();
296 let n = chars2.len();
297
298 let mut dp = vec![vec![0; n + 1]; m + 1];
299
300 for i in 0..=m {
302 dp[i][0] = i;
303 }
304 for j in 0..=n {
305 dp[0][j] = j;
306 }
307
308 for i in 1..=m {
310 for j in 1..=n {
311 let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
312
313 dp[i][j] = (dp[i - 1][j] + 1) .min(dp[i][j - 1] + 1) .min(dp[i - 1][j - 1] + cost); }
317 }
318
319 dp[m][n]
320 }
321
322 pub fn compute_strings(&self, text1: &str, text2: &str) -> Result<f64> {
324 let distance = self.edit_distance(text1, text2);
325 let similarity = (-self.gamma * distance as f64).exp();
326 Ok(similarity)
327 }
328}
329
330impl Kernel for EditDistanceKernel {
331 fn compute(&self, _x: &[f64], _y: &[f64]) -> Result<f64> {
332 Ok(0.0)
334 }
335
336 fn name(&self) -> &str {
337 "EditDistance"
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point similarity comparisons.
    const EPS: f64 = 1e-10;

    /// Builds a normalized `NGramKernel` for n-grams of length `n`.
    fn ngram_kernel(n: usize) -> NGramKernel {
        NGramKernel::new(NGramKernelConfig::new(n).unwrap())
    }

    #[test]
    fn test_ngram_kernel() {
        // Bigrams {he, el, ll, lo} vs {ha, al, ll, lo}: 2 shared, 4 each -> 2 / sqrt(4 * 4).
        let sim = ngram_kernel(2).compute_strings("hello", "hallo").unwrap();
        assert!((sim - 0.5).abs() < EPS);
    }

    #[test]
    fn test_ngram_identical_strings() {
        let sim = ngram_kernel(2).compute_strings("test", "test").unwrap();
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_ngram_different_strings() {
        let sim = ngram_kernel(2).compute_strings("abc", "xyz").unwrap();
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_ngram_symmetric() {
        let kernel = ngram_kernel(2);
        let forward = kernel.compute_strings("abcab", "bca").unwrap();
        let backward = kernel.compute_strings("bca", "abcab").unwrap();
        assert!((forward - backward).abs() < EPS);
        // "abcab" has 4 bigrams, "bca" has 2; they share "bc" and "ca".
        assert!((forward - 2.0 / 8.0_f64.sqrt()).abs() < EPS);
    }

    #[test]
    fn test_ngram_unnormalized_counts() {
        let config = NGramKernelConfig::new(2).unwrap().with_normalize(false);
        let kernel = NGramKernel::new(config);
        // "aaaa" contains three "aa" bigrams and "aa" one; the overlap is min(3, 1).
        let sim = kernel.compute_strings("aaaa", "aa").unwrap();
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_ngram_text_shorter_than_n() {
        let sim = ngram_kernel(3).compute_strings("ab", "ab").unwrap();
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_ngram_config_invalid_n() {
        assert!(NGramKernelConfig::new(0).is_err());
    }

    #[test]
    fn test_subsequence_kernel() {
        let kernel = SubsequenceKernel::new(SubsequenceKernelConfig::new());
        // Shared: "a" and "c" (length 1, weight 0.5 each) plus "ac" (length 2, weight 0.25).
        let sim = kernel.compute_strings("abc", "aec").unwrap();
        assert!((sim - 1.25).abs() < EPS);
    }

    #[test]
    fn test_subsequence_identical() {
        let kernel = SubsequenceKernel::new(SubsequenceKernelConfig::new());
        let sim = kernel.compute_strings("test", "test").unwrap();
        assert!(sim > 0.0);
    }

    #[test]
    fn test_subsequence_empty_input() {
        let kernel = SubsequenceKernel::new(SubsequenceKernelConfig::new());
        assert_eq!(kernel.compute_strings("", "abc").unwrap(), 0.0);
    }

    #[test]
    fn test_subsequence_config() {
        let config = SubsequenceKernelConfig::new()
            .with_max_length(5)
            .unwrap()
            .with_decay(0.7)
            .unwrap();

        assert_eq!(config.max_length, 5);
        assert!((config.decay - 0.7).abs() < EPS);
    }

    #[test]
    fn test_subsequence_invalid_config() {
        assert!(SubsequenceKernelConfig::new().with_max_length(0).is_err());
        assert!(SubsequenceKernelConfig::new().with_decay(1.5).is_err());
    }

    #[test]
    fn test_edit_distance_kernel() {
        let kernel = EditDistanceKernel::new(0.1).unwrap();
        // Levenshtein("kitten", "sitting") = 3, so the similarity is exp(-0.1 * 3).
        let sim = kernel.compute_strings("kitten", "sitting").unwrap();
        assert!((sim - (-0.3_f64).exp()).abs() < EPS);
    }

    #[test]
    fn test_edit_distance_identical() {
        let kernel = EditDistanceKernel::new(0.1).unwrap();
        let sim = kernel.compute_strings("test", "test").unwrap();
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn test_edit_distance_computation() {
        let kernel = EditDistanceKernel::new(1.0).unwrap();

        assert_eq!(kernel.edit_distance("", ""), 0);
        assert_eq!(kernel.edit_distance("a", ""), 1);
        assert_eq!(kernel.edit_distance("", "a"), 1);
        assert_eq!(kernel.edit_distance("abc", "abc"), 0);
        assert_eq!(kernel.edit_distance("abc", "abd"), 1);
        assert_eq!(kernel.edit_distance("kitten", "sitting"), 3);
    }

    #[test]
    fn test_edit_distance_symmetric_and_char_based() {
        let kernel = EditDistanceKernel::new(1.0).unwrap();

        assert_eq!(kernel.edit_distance("flaw", "lawn"), 2);
        assert_eq!(kernel.edit_distance("lawn", "flaw"), 2);
        // Multi-byte chars count as one edit, not one per UTF-8 byte.
        assert_eq!(kernel.edit_distance("héllo", "hello"), 1);
    }

    #[test]
    fn test_edit_distance_invalid_gamma() {
        assert!(EditDistanceKernel::new(-0.1).is_err());
        assert!(EditDistanceKernel::new(0.0).is_err());
    }

    #[test]
    fn test_kernel_trait() {
        let kernel = NGramKernel::new(NGramKernelConfig::new(2).unwrap());
        assert_eq!(kernel.name(), "NGram");

        let kernel = SubsequenceKernel::new(SubsequenceKernelConfig::new());
        assert_eq!(kernel.name(), "Subsequence");

        let kernel = EditDistanceKernel::new(0.1).unwrap();
        assert_eq!(kernel.name(), "EditDistance");
    }
}