Skip to main content

sklears_kernel_approximation/
nlp_kernels.rs

1//! Natural Language Processing kernel approximations
2//!
3//! This module provides kernel approximation methods specifically designed for NLP tasks,
4//! including text kernels, semantic features, syntactic approximations, and word embedding kernels.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::RngExt;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use serde::{Deserialize, Serialize};
12use std::collections::hash_map::DefaultHasher;
13use std::collections::{HashMap, HashSet};
14use std::hash::{Hash, Hasher};
15
16use sklears_core::error::Result;
17use sklears_core::traits::{Fit, Transform};
18
/// Text kernel approximation using bag-of-words and n-gram features.
///
/// Documents are turned into (optionally TF-IDF weighted) n-gram count
/// vectors and then randomly projected down to `n_components` dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for the text kernel approximation; builder-style setters
/// live on the `impl` block.
pub struct TextKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Upper bound on vocabulary size (also the hashing modulus).
    pub max_features: usize,
    /// Inclusive `(min_n, max_n)` range of n-gram sizes to extract.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency for a term to be kept.
    pub min_df: usize,
    /// Maximum document-frequency ratio (0.0..=1.0) for a term to be kept.
    pub max_df: f64,
    /// Weight term counts by inverse document frequency.
    pub use_tf_idf: bool,
    /// Request feature hashing.
    /// NOTE(review): `hash_feature` exists but is never called by
    /// `fit`/`transform` in this file — confirm intended wiring.
    pub use_hashing: bool,
    /// Apply sublinear scaling `1 + ln(tf)` to term frequencies.
    pub sublinear_tf: bool,
}
40
41impl TextKernelApproximation {
42    pub fn new(n_components: usize) -> Self {
43        Self {
44            n_components,
45            max_features: 10000,
46            ngram_range: (1, 1),
47            min_df: 1,
48            max_df: 1.0,
49            use_tf_idf: true,
50            use_hashing: false,
51            sublinear_tf: false,
52        }
53    }
54
55    pub fn max_features(mut self, max_features: usize) -> Self {
56        self.max_features = max_features;
57        self
58    }
59
60    pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
61        self.ngram_range = ngram_range;
62        self
63    }
64
65    pub fn min_df(mut self, min_df: usize) -> Self {
66        self.min_df = min_df;
67        self
68    }
69
70    pub fn max_df(mut self, max_df: f64) -> Self {
71        self.max_df = max_df;
72        self
73    }
74
75    pub fn use_tf_idf(mut self, use_tf_idf: bool) -> Self {
76        self.use_tf_idf = use_tf_idf;
77        self
78    }
79
80    pub fn use_hashing(mut self, use_hashing: bool) -> Self {
81        self.use_hashing = use_hashing;
82        self
83    }
84
85    pub fn sublinear_tf(mut self, sublinear_tf: bool) -> Self {
86        self.sublinear_tf = sublinear_tf;
87        self
88    }
89
90    fn tokenize(&self, text: &str) -> Vec<String> {
91        text.to_lowercase()
92            .split_whitespace()
93            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
94            .filter(|s| !s.is_empty())
95            .map(|s| s.to_string())
96            .collect()
97    }
98
99    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
100        let mut ngrams = Vec::new();
101
102        for n in self.ngram_range.0..=self.ngram_range.1 {
103            for i in 0..=tokens.len().saturating_sub(n) {
104                if i + n <= tokens.len() {
105                    let ngram = tokens[i..i + n].join(" ");
106                    ngrams.push(ngram);
107                }
108            }
109        }
110
111        ngrams
112    }
113
114    fn hash_feature(&self, feature: &str) -> usize {
115        let mut hasher = DefaultHasher::new();
116        feature.hash(&mut hasher);
117        hasher.finish() as usize % self.max_features
118    }
119}
120
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fitted state produced by fitting `TextKernelApproximation`: the learned
/// vocabulary, IDF weights, and random projection, plus the original
/// configuration so `transform` can replay tokenization and weighting.
pub struct FittedTextKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Upper bound on vocabulary size used during fitting.
    pub max_features: usize,
    /// Inclusive `(min_n, max_n)` n-gram size range.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency applied during fitting.
    pub min_df: usize,
    /// Maximum document-frequency ratio applied during fitting.
    pub max_df: f64,
    /// Whether IDF weighting was enabled.
    pub use_tf_idf: bool,
    /// Whether feature hashing was requested at configuration time.
    pub use_hashing: bool,
    /// Whether sublinear TF scaling is applied in `transform`.
    pub sublinear_tf: bool,
    /// Term -> column index in the term-frequency matrix.
    pub vocabulary: HashMap<String, usize>,
    /// Per-term IDF weights (all ones when TF-IDF is off).
    pub idf_values: Array1<f64>,
    /// `(n_components, vocabulary)` Gaussian projection matrix.
    pub random_weights: Array2<f64>,
}
147
148impl Fit<Vec<String>, ()> for TextKernelApproximation {
149    type Fitted = FittedTextKernelApproximation;
150
151    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
152        let mut vocabulary = HashMap::new();
153        let mut document_frequencies = HashMap::new();
154        let n_documents = documents.len();
155
156        // Build vocabulary and compute document frequencies
157        for doc in documents {
158            let tokens = self.tokenize(doc);
159            let ngrams = self.extract_ngrams(&tokens);
160            let unique_ngrams: HashSet<String> = ngrams.into_iter().collect();
161
162            for ngram in unique_ngrams {
163                *document_frequencies.entry(ngram.clone()).or_insert(0) += 1;
164                if !vocabulary.contains_key(&ngram) {
165                    vocabulary.insert(ngram, vocabulary.len());
166                }
167            }
168        }
169
170        // Filter vocabulary based on min_df and max_df
171        let mut filtered_vocabulary = HashMap::new();
172        for (term, &df) in &document_frequencies {
173            let df_ratio = df as f64 / n_documents as f64;
174            if df >= self.min_df && df_ratio <= self.max_df {
175                filtered_vocabulary.insert(term.clone(), filtered_vocabulary.len());
176            }
177        }
178
179        // Limit vocabulary size
180        if filtered_vocabulary.len() > self.max_features {
181            let mut sorted_vocab: Vec<_> = filtered_vocabulary.iter().collect();
182            sorted_vocab.sort_by(|a, b| {
183                document_frequencies[a.0]
184                    .cmp(&document_frequencies[b.0])
185                    .reverse()
186            });
187
188            let mut new_vocabulary = HashMap::new();
189            for (term, _) in sorted_vocab.iter().take(self.max_features) {
190                new_vocabulary.insert(term.to_string(), new_vocabulary.len());
191            }
192            filtered_vocabulary = new_vocabulary;
193        }
194
195        // Compute IDF values
196        let vocab_size = filtered_vocabulary.len();
197        let mut idf_values = Array1::zeros(vocab_size);
198
199        if self.use_tf_idf {
200            for (term, &idx) in &filtered_vocabulary {
201                let df = document_frequencies.get(term).unwrap_or(&0);
202                idf_values[idx] = (n_documents as f64 / (*df as f64 + 1.0)).ln() + 1.0;
203            }
204        } else {
205            idf_values.fill(1.0);
206        }
207
208        // Generate random weights for kernel approximation
209        let mut rng = RealStdRng::from_seed(thread_rng().random());
210        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
211
212        let mut random_weights = Array2::zeros((self.n_components, vocab_size));
213        for i in 0..self.n_components {
214            for j in 0..vocab_size {
215                random_weights[[i, j]] = rng.sample(normal);
216            }
217        }
218
219        Ok(FittedTextKernelApproximation {
220            n_components: self.n_components,
221            max_features: self.max_features,
222            ngram_range: self.ngram_range,
223            min_df: self.min_df,
224            max_df: self.max_df,
225            use_tf_idf: self.use_tf_idf,
226            use_hashing: self.use_hashing,
227            sublinear_tf: self.sublinear_tf,
228            vocabulary: filtered_vocabulary,
229            idf_values,
230            random_weights,
231        })
232    }
233}
234
235impl FittedTextKernelApproximation {
236    fn tokenize(&self, text: &str) -> Vec<String> {
237        text.to_lowercase()
238            .split_whitespace()
239            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
240            .filter(|s| !s.is_empty())
241            .map(|s| s.to_string())
242            .collect()
243    }
244
245    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
246        let mut ngrams = Vec::new();
247
248        for n in self.ngram_range.0..=self.ngram_range.1 {
249            for i in 0..=tokens.len().saturating_sub(n) {
250                if i + n <= tokens.len() {
251                    let ngram = tokens[i..i + n].join(" ");
252                    ngrams.push(ngram);
253                }
254            }
255        }
256
257        ngrams
258    }
259}
260
261impl Transform<Vec<String>, Array2<f64>> for FittedTextKernelApproximation {
262    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
263        let n_documents = documents.len();
264        let vocab_size = self.vocabulary.len();
265
266        // Convert documents to TF-IDF vectors
267        let mut tf_idf_matrix = Array2::zeros((n_documents, vocab_size));
268
269        for (doc_idx, doc) in documents.iter().enumerate() {
270            let tokens = self.tokenize(doc);
271            let ngrams = self.extract_ngrams(&tokens);
272
273            // Count term frequencies
274            let mut term_counts = HashMap::new();
275            for ngram in ngrams {
276                *term_counts.entry(ngram).or_insert(0) += 1;
277            }
278
279            // Compute TF-IDF
280            for (term, &count) in &term_counts {
281                if let Some(&vocab_idx) = self.vocabulary.get(term) {
282                    let tf = if self.sublinear_tf {
283                        1.0 + (count as f64).ln()
284                    } else {
285                        count as f64
286                    };
287
288                    let tf_idf = tf * self.idf_values[vocab_idx];
289                    tf_idf_matrix[[doc_idx, vocab_idx]] = tf_idf;
290                }
291            }
292        }
293
294        // Apply random projection for kernel approximation
295        let mut result = Array2::zeros((n_documents, self.n_components));
296
297        for i in 0..n_documents {
298            for j in 0..self.n_components {
299                let mut dot_product = 0.0;
300                for k in 0..vocab_size {
301                    dot_product += tf_idf_matrix[[i, k]] * self.random_weights[[j, k]];
302                }
303                result[[i, j]] = dot_product;
304            }
305        }
306
307        Ok(result)
308    }
309}
310
/// Semantic kernel approximation using word embeddings and similarity measures.
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for the semantic kernel approximation; builder-style
/// setters live on the `impl` block.
pub struct SemanticKernelApproximation {
    /// Output dimensionality of the projection.
    pub n_components: usize,
    /// Dimensionality of the word embeddings generated at fit time.
    pub embedding_dim: usize,
    /// Pairwise similarity measure (see `compute_similarity`).
    pub similarity_measure: SimilarityMeasure,
    /// How token embeddings are pooled into one document vector.
    pub aggregation_method: AggregationMethod,
    /// Attention flag. NOTE(review): attention pooling is actually selected
    /// via `aggregation_method::AttentionWeighted`; this flag is not consulted
    /// by the code in this file — confirm intent.
    pub use_attention: bool,
}
326
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Vector similarity measures. Distance-based variants are negated in
/// `compute_similarity` so that larger values always mean "more similar".
pub enum SimilarityMeasure {
    /// Cosine similarity (defined as 0 for zero-norm inputs).
    Cosine,
    /// Negated Euclidean (L2) distance.
    Euclidean,
    /// Negated Manhattan (L1) distance.
    Manhattan,
    /// Raw dot product.
    Dot,
}
339
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Strategies for pooling per-token embeddings into a single document vector.
pub enum AggregationMethod {
    /// Arithmetic mean over tokens.
    Mean,
    /// Per-dimension maximum over tokens.
    Max,
    /// Per-dimension sum over tokens.
    Sum,
    /// Softmax-normalized weighting by each token's embedding L2 norm.
    AttentionWeighted,
}
352
impl SemanticKernelApproximation {
    /// Create a semantic kernel approximation producing `n_components`
    /// features from `embedding_dim`-dimensional word vectors, defaulting to
    /// cosine similarity and mean pooling.
    pub fn new(n_components: usize, embedding_dim: usize) -> Self {
        Self {
            n_components,
            embedding_dim,
            similarity_measure: SimilarityMeasure::Cosine,
            aggregation_method: AggregationMethod::Mean,
            use_attention: false,
        }
    }

    /// Set the pairwise similarity measure (builder style).
    pub fn similarity_measure(mut self, measure: SimilarityMeasure) -> Self {
        self.similarity_measure = measure;
        self
    }

    /// Set how token embeddings are aggregated into a document vector.
    pub fn aggregation_method(mut self, method: AggregationMethod) -> Self {
        self.aggregation_method = method;
        self
    }

    /// Set the attention flag.
    ///
    /// NOTE(review): pooling is selected via `aggregation_method`, so this
    /// flag is currently not consulted anywhere in this file.
    pub fn use_attention(mut self, use_attention: bool) -> Self {
        self.use_attention = use_attention;
        self
    }

    /// Similarity between two equal-length vectors under the configured
    /// measure. Euclidean and Manhattan distances are negated so that larger
    /// return values always mean "more similar".
    ///
    /// NOTE(review): not referenced by `fit`/`transform` in this file.
    fn compute_similarity(&self, vec1: &ArrayView1<f64>, vec2: &ArrayView1<f64>) -> f64 {
        match self.similarity_measure {
            SimilarityMeasure::Cosine => {
                let dot = vec1.dot(vec2);
                let norm1 = vec1.dot(vec1).sqrt();
                let norm2 = vec2.dot(vec2).sqrt();
                // Cosine is undefined for zero-norm vectors; define it as 0.
                if norm1 > 0.0 && norm2 > 0.0 {
                    dot / (norm1 * norm2)
                } else {
                    0.0
                }
            }
            SimilarityMeasure::Euclidean => {
                let diff = vec1 - vec2;
                -diff.dot(&diff).sqrt()
            }
            SimilarityMeasure::Manhattan => {
                let diff = vec1 - vec2;
                -diff.mapv(|x| x.abs()).sum()
            }
            SimilarityMeasure::Dot => vec1.dot(vec2),
        }
    }

    /// Collapse a (tokens x dim) embedding matrix into a single `dim`-length
    /// vector according to `aggregation_method`.
    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            // Arithmetic mean over rows; `mean_axis` is None only for an
            // empty axis, i.e. a matrix with zero token rows.
            AggregationMethod::Mean => embeddings
                .mean_axis(Axis(0))
                .expect("operation should succeed"),
            // Per-dimension maximum over tokens.
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                // Attention score per token = L2 norm of its embedding.
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                // Numerically stable softmax (subtract the max before exp).
                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                // Convex combination of token embeddings by attention weight.
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}
448
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fitted state for the semantic kernel approximation: per-word embeddings
/// and the Gaussian projection matrix, plus the fitting configuration.
pub struct FittedSemanticKernelApproximation {
    /// Output dimensionality of the projection.
    pub n_components: usize,
    /// Dimensionality of each word embedding.
    pub embedding_dim: usize,
    /// Pairwise similarity measure chosen at configuration time.
    pub similarity_measure: SimilarityMeasure,
    /// Pooling strategy used by `transform`.
    pub aggregation_method: AggregationMethod,
    /// Attention flag carried over from the configuration (not consulted
    /// by the code in this file).
    pub use_attention: bool,
    /// Word -> randomly initialized embedding vector.
    pub word_embeddings: HashMap<String, Array1<f64>>,
    /// `(n_components, embedding_dim)` Gaussian projection matrix.
    pub projection_matrix: Array2<f64>,
}
467
468impl Fit<Vec<String>, ()> for SemanticKernelApproximation {
469    type Fitted = FittedSemanticKernelApproximation;
470
471    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
472        let mut rng = RealStdRng::from_seed(thread_rng().random());
473        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
474
475        // Generate random word embeddings (in practice, these would be pre-trained)
476        let mut word_embeddings = HashMap::new();
477        let mut vocabulary = HashSet::new();
478
479        for doc in documents {
480            let tokens: Vec<String> = doc
481                .to_lowercase()
482                .split_whitespace()
483                .map(|s| s.to_string())
484                .collect();
485
486            for token in tokens {
487                vocabulary.insert(token);
488            }
489        }
490
491        for word in vocabulary {
492            let embedding = Array1::from_vec(
493                (0..self.embedding_dim)
494                    .map(|_| rng.sample(normal))
495                    .collect(),
496            );
497            word_embeddings.insert(word, embedding);
498        }
499
500        // Generate projection matrix for kernel approximation
501        let mut projection_matrix = Array2::zeros((self.n_components, self.embedding_dim));
502        for i in 0..self.n_components {
503            for j in 0..self.embedding_dim {
504                projection_matrix[[i, j]] = rng.sample(normal);
505            }
506        }
507
508        Ok(FittedSemanticKernelApproximation {
509            n_components: self.n_components,
510            embedding_dim: self.embedding_dim,
511            similarity_measure: self.similarity_measure,
512            aggregation_method: self.aggregation_method,
513            use_attention: self.use_attention,
514            word_embeddings,
515            projection_matrix,
516        })
517    }
518}
519
impl FittedSemanticKernelApproximation {
    /// Collapse a (tokens x dim) embedding matrix into a single `dim`-length
    /// vector according to `aggregation_method` (mirrors the logic on the
    /// unfitted `SemanticKernelApproximation`).
    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            // Arithmetic mean over rows; `mean_axis` is None only for an
            // empty axis, i.e. a matrix with zero token rows.
            AggregationMethod::Mean => embeddings
                .mean_axis(Axis(0))
                .expect("operation should succeed"),
            // Per-dimension maximum over tokens.
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                // Attention score per token = L2 norm of its embedding.
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                // Numerically stable softmax (subtract the max before exp).
                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                // Convex combination of token embeddings by attention weight.
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}
566
567impl Transform<Vec<String>, Array2<f64>> for FittedSemanticKernelApproximation {
568    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
569        let n_documents = documents.len();
570        let mut result = Array2::zeros((n_documents, self.n_components));
571
572        for (doc_idx, doc) in documents.iter().enumerate() {
573            let tokens: Vec<String> = doc
574                .to_lowercase()
575                .split_whitespace()
576                .map(|s| s.to_string())
577                .collect();
578
579            // Get embeddings for tokens
580            let mut token_embeddings = Vec::new();
581            for token in tokens {
582                if let Some(embedding) = self.word_embeddings.get(&token) {
583                    token_embeddings.push(embedding.clone());
584                }
585            }
586
587            if !token_embeddings.is_empty() {
588                // Convert to Array2
589                let embeddings_matrix = Array2::from_shape_vec(
590                    (token_embeddings.len(), self.embedding_dim),
591                    token_embeddings
592                        .iter()
593                        .flat_map(|e| e.iter().cloned())
594                        .collect(),
595                )?;
596
597                // Aggregate embeddings
598                let doc_embedding = self.aggregate_embeddings(&embeddings_matrix);
599
600                // Apply projection for kernel approximation
601                for i in 0..self.n_components {
602                    let projected = self.projection_matrix.row(i).dot(&doc_embedding);
603                    result[[doc_idx, i]] = projected.tanh(); // Non-linear activation
604                }
605            }
606        }
607
608        Ok(result)
609    }
610}
611
/// Syntactic kernel approximation using parse trees and grammatical features.
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for the syntactic kernel approximation; builder-style
/// setters live on the `impl` block.
pub struct SyntacticKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum parse-tree depth. NOTE(review): not consulted by the
    /// simplified feature extractor in this file.
    pub max_tree_depth: usize,
    /// Emit heuristic POS-tag features.
    pub use_pos_tags: bool,
    /// Emit adjacent-token pseudo-dependency features.
    pub use_dependencies: bool,
    /// Tree kernel variant. NOTE(review): not consulted by the simplified
    /// feature extractor in this file.
    pub tree_kernel_type: TreeKernelType,
}
627
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Tree kernel variants (currently only stored; the simplified feature
/// extractor does not branch on them).
pub enum TreeKernelType {
    /// Subset tree kernel.
    Subset,
    /// Subsequence tree kernel.
    Subsequence,
    /// Partial tree kernel.
    Partial,
}
638
639impl SyntacticKernelApproximation {
640    pub fn new(n_components: usize) -> Self {
641        Self {
642            n_components,
643            max_tree_depth: 10,
644            use_pos_tags: true,
645            use_dependencies: true,
646            tree_kernel_type: TreeKernelType::Subset,
647        }
648    }
649
650    pub fn max_tree_depth(mut self, depth: usize) -> Self {
651        self.max_tree_depth = depth;
652        self
653    }
654
655    pub fn use_pos_tags(mut self, use_pos: bool) -> Self {
656        self.use_pos_tags = use_pos;
657        self
658    }
659
660    pub fn use_dependencies(mut self, use_deps: bool) -> Self {
661        self.use_dependencies = use_deps;
662        self
663    }
664
665    pub fn tree_kernel_type(mut self, kernel_type: TreeKernelType) -> Self {
666        self.tree_kernel_type = kernel_type;
667        self
668    }
669
670    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
671        let mut features = Vec::new();
672
673        // Simplified syntactic feature extraction
674        let tokens: Vec<&str> = text.split_whitespace().collect();
675
676        // Add POS tag features (simplified)
677        if self.use_pos_tags {
678            for token in &tokens {
679                let pos_tag = self.simple_pos_tag(token);
680                features.push(format!("POS_{}", pos_tag));
681            }
682        }
683
684        // Add dependency features (simplified)
685        if self.use_dependencies {
686            for i in 0..tokens.len() {
687                if i > 0 {
688                    features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
689                }
690            }
691        }
692
693        // Add n-gram features
694        for n in 1..=3 {
695            for i in 0..=tokens.len().saturating_sub(n) {
696                if i + n <= tokens.len() {
697                    let ngram = tokens[i..i + n].join("_");
698                    features.push(format!("NGRAM_{}", ngram));
699                }
700            }
701        }
702
703        features
704    }
705
706    fn simple_pos_tag(&self, token: &str) -> String {
707        // Very simplified POS tagging
708        let token_lower = token.to_lowercase();
709
710        if token_lower.ends_with("ing") {
711            "VBG".to_string()
712        } else if token_lower.ends_with("ed") {
713            "VBD".to_string()
714        } else if token_lower.ends_with("ly") {
715            "RB".to_string()
716        } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
717            "NNS".to_string()
718        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
719            "NNP".to_string()
720        } else if token.chars().all(|c| c.is_alphabetic()) {
721            "NN".to_string()
722        } else if token.chars().all(|c| c.is_numeric()) {
723            "CD".to_string()
724        } else {
725            "UNK".to_string()
726        }
727    }
728}
729
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fitted state for the syntactic kernel approximation: the feature
/// vocabulary and Gaussian random projection, plus the fitting configuration.
pub struct FittedSyntacticKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum parse-tree depth carried from the configuration (not consulted
    /// by the code in this file).
    pub max_tree_depth: usize,
    /// Whether POS-tag features are emitted.
    pub use_pos_tags: bool,
    /// Whether adjacent-token dependency features are emitted.
    pub use_dependencies: bool,
    /// Tree kernel variant carried from the configuration (not consulted by
    /// the code in this file).
    pub tree_kernel_type: TreeKernelType,
    /// Feature string -> column index.
    pub feature_vocabulary: HashMap<String, usize>,
    /// `(n_components, vocabulary)` Gaussian projection matrix.
    pub random_weights: Array2<f64>,
}
748
749impl Fit<Vec<String>, ()> for SyntacticKernelApproximation {
750    type Fitted = FittedSyntacticKernelApproximation;
751
752    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
753        let mut feature_vocabulary = HashMap::new();
754
755        // Extract syntactic features from all documents
756        for doc in documents {
757            let features = self.extract_syntactic_features(doc);
758            for feature in features {
759                if !feature_vocabulary.contains_key(&feature) {
760                    feature_vocabulary.insert(feature, feature_vocabulary.len());
761                }
762            }
763        }
764
765        // Generate random weights
766        let mut rng = RealStdRng::from_seed(thread_rng().random());
767        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
768
769        let vocab_size = feature_vocabulary.len();
770        let mut random_weights = Array2::zeros((self.n_components, vocab_size));
771
772        for i in 0..self.n_components {
773            for j in 0..vocab_size {
774                random_weights[[i, j]] = rng.sample(normal);
775            }
776        }
777
778        Ok(FittedSyntacticKernelApproximation {
779            n_components: self.n_components,
780            max_tree_depth: self.max_tree_depth,
781            use_pos_tags: self.use_pos_tags,
782            use_dependencies: self.use_dependencies,
783            tree_kernel_type: self.tree_kernel_type,
784            feature_vocabulary,
785            random_weights,
786        })
787    }
788}
789
790impl FittedSyntacticKernelApproximation {
791    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
792        let mut features = Vec::new();
793
794        // Simplified syntactic feature extraction
795        let tokens: Vec<&str> = text.split_whitespace().collect();
796
797        // Add POS tag features (simplified)
798        if self.use_pos_tags {
799            for token in &tokens {
800                let pos_tag = self.simple_pos_tag(token);
801                features.push(format!("POS_{}", pos_tag));
802            }
803        }
804
805        // Add dependency features (simplified)
806        if self.use_dependencies {
807            for i in 0..tokens.len() {
808                if i > 0 {
809                    features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
810                }
811            }
812        }
813
814        // Add n-gram features
815        for n in 1..=3 {
816            for i in 0..=tokens.len().saturating_sub(n) {
817                if i + n <= tokens.len() {
818                    let ngram = tokens[i..i + n].join("_");
819                    features.push(format!("NGRAM_{}", ngram));
820                }
821            }
822        }
823
824        features
825    }
826
827    fn simple_pos_tag(&self, token: &str) -> String {
828        // Very simplified POS tagging
829        let token_lower = token.to_lowercase();
830
831        if token_lower.ends_with("ing") {
832            "VBG".to_string()
833        } else if token_lower.ends_with("ed") {
834            "VBD".to_string()
835        } else if token_lower.ends_with("ly") {
836            "RB".to_string()
837        } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
838            "NNS".to_string()
839        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
840            "NNP".to_string()
841        } else if token.chars().all(|c| c.is_alphabetic()) {
842            "NN".to_string()
843        } else if token.chars().all(|c| c.is_numeric()) {
844            "CD".to_string()
845        } else {
846            "UNK".to_string()
847        }
848    }
849}
850
851impl Transform<Vec<String>, Array2<f64>> for FittedSyntacticKernelApproximation {
852    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
853        let n_documents = documents.len();
854        let vocab_size = self.feature_vocabulary.len();
855
856        // Convert documents to feature vectors
857        let mut feature_matrix = Array2::zeros((n_documents, vocab_size));
858
859        for (doc_idx, doc) in documents.iter().enumerate() {
860            let features = self.extract_syntactic_features(doc);
861            let mut feature_counts = HashMap::new();
862
863            for feature in features {
864                *feature_counts.entry(feature).or_insert(0) += 1;
865            }
866
867            for (feature, count) in feature_counts {
868                if let Some(&vocab_idx) = self.feature_vocabulary.get(&feature) {
869                    feature_matrix[[doc_idx, vocab_idx]] = count as f64;
870                }
871            }
872        }
873
874        // Apply random projection
875        let mut result = Array2::zeros((n_documents, self.n_components));
876
877        for i in 0..n_documents {
878            for j in 0..self.n_components {
879                let mut dot_product = 0.0;
880                for k in 0..vocab_size {
881                    dot_product += feature_matrix[[i, k]] * self.random_weights[[j, k]];
882                }
883                result[[i, j]] = dot_product.tanh();
884            }
885        }
886
887        Ok(result)
888    }
889}
890
/// Document kernel approximation for document-level features.
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for document-level kernel features (topic, readability,
/// and stylometric statistics).
pub struct DocumentKernelApproximation {
    /// Output dimensionality.
    pub n_components: usize,
    /// Include topic-distribution features.
    pub use_topic_features: bool,
    /// Include readability features (sentence/word-length statistics).
    pub use_readability_features: bool,
    /// Include stylometric features.
    pub use_stylometric_features: bool,
    /// Number of topics used for the topic features.
    pub n_topics: usize,
}
906
907impl DocumentKernelApproximation {
908    pub fn new(n_components: usize) -> Self {
909        Self {
910            n_components,
911            use_topic_features: true,
912            use_readability_features: true,
913            use_stylometric_features: true,
914            n_topics: 10,
915        }
916    }
917
918    pub fn use_topic_features(mut self, use_topics: bool) -> Self {
919        self.use_topic_features = use_topics;
920        self
921    }
922
923    pub fn use_readability_features(mut self, use_readability: bool) -> Self {
924        self.use_readability_features = use_readability;
925        self
926    }
927
928    pub fn use_stylometric_features(mut self, use_stylometric: bool) -> Self {
929        self.use_stylometric_features = use_stylometric;
930        self
931    }
932
933    pub fn n_topics(mut self, n_topics: usize) -> Self {
934        self.n_topics = n_topics;
935        self
936    }
937
938    fn extract_document_features(&self, text: &str) -> Vec<f64> {
939        let mut features = Vec::new();
940
941        let sentences: Vec<&str> = text.split(&['.', '!', '?'][..]).collect();
942        let words: Vec<&str> = text.split_whitespace().collect();
943        let characters: Vec<char> = text.chars().collect();
944
945        if self.use_readability_features {
946            // Readability features
947            let avg_sentence_length = if !sentences.is_empty() {
948                words.len() as f64 / sentences.len() as f64
949            } else {
950                0.0
951            };
952
953            let avg_word_length = if !words.is_empty() {
954                characters.len() as f64 / words.len() as f64
955            } else {
956                0.0
957            };
958
959            features.push(avg_sentence_length);
960            features.push(avg_word_length);
961            features.push(sentences.len() as f64);
962            features.push(words.len() as f64);
963        }
964
965        if self.use_stylometric_features {
966            // Stylometric features
967            let punctuation_count = characters
968                .iter()
969                .filter(|c| c.is_ascii_punctuation())
970                .count();
971            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
972            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();
973
974            features.push(punctuation_count as f64 / characters.len() as f64);
975            features.push(uppercase_count as f64 / characters.len() as f64);
976            features.push(digit_count as f64 / characters.len() as f64);
977
978            // Type-token ratio
979            let unique_words: HashSet<&str> = words.iter().cloned().collect();
980            let ttr = if !words.is_empty() {
981                unique_words.len() as f64 / words.len() as f64
982            } else {
983                0.0
984            };
985            features.push(ttr);
986        }
987
988        if self.use_topic_features {
989            // Simplified topic features (in practice, use LDA or similar)
990            let mut topic_features = vec![0.0; self.n_topics];
991            let mut hasher = DefaultHasher::new();
992            text.hash(&mut hasher);
993            let hash = hasher.finish();
994
995            for i in 0..self.n_topics {
996                topic_features[i] = ((hash + i as u64) % 1000) as f64 / 1000.0;
997            }
998
999            features.extend(topic_features);
1000        }
1001
1002        features
1003    }
1004}
1005
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fitted state of `DocumentKernelApproximation`: the configuration flags
/// plus the Gaussian random projection weights sampled during `fit`.
pub struct FittedDocumentKernelApproximation {
    /// Dimensionality of the output embedding space.
    pub n_components: usize,
    /// Whether hashed pseudo-topic features are included.
    pub use_topic_features: bool,
    /// Whether readability features (sentence/word lengths and counts) are included.
    pub use_readability_features: bool,
    /// Whether stylometric features (punctuation/case/digit ratios, type-token ratio) are included.
    pub use_stylometric_features: bool,
    /// Number of pseudo-topic features generated when topic features are enabled.
    pub n_topics: usize,
    /// Length of the raw document feature vector determined at fit time.
    pub feature_dim: usize,
    /// Random projection matrix of shape (n_components, feature_dim).
    pub random_weights: Array2<f64>,
}
1024
1025impl Fit<Vec<String>, ()> for DocumentKernelApproximation {
1026    type Fitted = FittedDocumentKernelApproximation;
1027
1028    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
1029        // Determine feature dimension
1030        let sample_features = self.extract_document_features(&documents[0]);
1031        let feature_dim = sample_features.len();
1032
1033        // Generate random weights
1034        let mut rng = RealStdRng::from_seed(thread_rng().random());
1035        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
1036
1037        let mut random_weights = Array2::zeros((self.n_components, feature_dim));
1038        for i in 0..self.n_components {
1039            for j in 0..feature_dim {
1040                random_weights[[i, j]] = rng.sample(normal);
1041            }
1042        }
1043
1044        Ok(FittedDocumentKernelApproximation {
1045            n_components: self.n_components,
1046            use_topic_features: self.use_topic_features,
1047            use_readability_features: self.use_readability_features,
1048            use_stylometric_features: self.use_stylometric_features,
1049            n_topics: self.n_topics,
1050            feature_dim,
1051            random_weights,
1052        })
1053    }
1054}
1055
1056impl FittedDocumentKernelApproximation {
1057    fn extract_document_features(&self, text: &str) -> Vec<f64> {
1058        let mut features = Vec::new();
1059
1060        let sentences: Vec<&str> = text.split(&['.', '!', '?'][..]).collect();
1061        let words: Vec<&str> = text.split_whitespace().collect();
1062        let characters: Vec<char> = text.chars().collect();
1063
1064        if self.use_readability_features {
1065            // Readability features
1066            let avg_sentence_length = if !sentences.is_empty() {
1067                words.len() as f64 / sentences.len() as f64
1068            } else {
1069                0.0
1070            };
1071
1072            let avg_word_length = if !words.is_empty() {
1073                characters.len() as f64 / words.len() as f64
1074            } else {
1075                0.0
1076            };
1077
1078            features.push(avg_sentence_length);
1079            features.push(avg_word_length);
1080            features.push(sentences.len() as f64);
1081            features.push(words.len() as f64);
1082        }
1083
1084        if self.use_stylometric_features {
1085            // Stylometric features
1086            let punctuation_count = characters
1087                .iter()
1088                .filter(|c| c.is_ascii_punctuation())
1089                .count();
1090            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
1091            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();
1092
1093            features.push(punctuation_count as f64 / characters.len() as f64);
1094            features.push(uppercase_count as f64 / characters.len() as f64);
1095            features.push(digit_count as f64 / characters.len() as f64);
1096
1097            // Type-token ratio
1098            let unique_words: HashSet<&str> = words.iter().cloned().collect();
1099            let ttr = if !words.is_empty() {
1100                unique_words.len() as f64 / words.len() as f64
1101            } else {
1102                0.0
1103            };
1104            features.push(ttr);
1105        }
1106
1107        if self.use_topic_features {
1108            // Simplified topic features (in practice, use LDA or similar)
1109            let mut topic_features = vec![0.0; self.n_topics];
1110            let mut hasher = DefaultHasher::new();
1111            text.hash(&mut hasher);
1112            let hash = hasher.finish();
1113
1114            for i in 0..self.n_topics {
1115                topic_features[i] = ((hash + i as u64) % 1000) as f64 / 1000.0;
1116            }
1117
1118            features.extend(topic_features);
1119        }
1120
1121        features
1122    }
1123}
1124
1125impl Transform<Vec<String>, Array2<f64>> for FittedDocumentKernelApproximation {
1126    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
1127        let n_documents = documents.len();
1128        let mut result = Array2::zeros((n_documents, self.n_components));
1129
1130        for (doc_idx, doc) in documents.iter().enumerate() {
1131            let features = self.extract_document_features(doc);
1132            let feature_array = Array1::from_vec(features);
1133
1134            for i in 0..self.n_components {
1135                let projected = self.random_weights.row(i).dot(&feature_array);
1136                result[[doc_idx, i]] = projected.tanh();
1137            }
1138        }
1139
1140        Ok(result)
1141    }
1142}
1143
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Fit/transform round trip should produce one row per document and
    /// `n_components` columns.
    #[test]
    fn test_text_kernel_approximation() {
        let corpus: Vec<String> = [
            "This is a test document",
            "Another test document here",
            "Third document for testing",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let fitted = TextKernelApproximation::new(50)
            .fit(&corpus, &())
            .expect("operation should succeed");
        let embedded = fitted.transform(&corpus).expect("operation should succeed");

        assert_eq!(embedded.nrows(), 3);
        assert_eq!(embedded.ncols(), 50);
    }

    #[test]
    fn test_semantic_kernel_approximation() {
        let corpus: Vec<String> = ["Semantic similarity test", "Another semantic test"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let fitted = SemanticKernelApproximation::new(30, 100)
            .fit(&corpus, &())
            .expect("operation should succeed");
        let embedded = fitted.transform(&corpus).expect("operation should succeed");

        assert_eq!(embedded.nrows(), 2);
        assert_eq!(embedded.ncols(), 30);
    }

    #[test]
    fn test_syntactic_kernel_approximation() {
        let corpus: Vec<String> = ["The cat sat on the mat", "Dogs are running quickly"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let fitted = SyntacticKernelApproximation::new(40)
            .fit(&corpus, &())
            .expect("operation should succeed");
        let embedded = fitted.transform(&corpus).expect("operation should succeed");

        assert_eq!(embedded.nrows(), 2);
        assert_eq!(embedded.ncols(), 40);
    }

    #[test]
    fn test_document_kernel_approximation() {
        let corpus: Vec<String> = [
            "This is a long document with multiple sentences. It contains various features.",
            "Short doc.",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        let fitted = DocumentKernelApproximation::new(25)
            .fit(&corpus, &())
            .expect("operation should succeed");
        let embedded = fitted.transform(&corpus).expect("operation should succeed");

        assert_eq!(embedded.nrows(), 2);
        assert_eq!(embedded.ncols(), 25);
    }
}