sklears_kernel_approximation/
nlp_kernels.rs

//! Natural Language Processing kernel approximations
//!
//! This module provides kernel approximation methods specifically designed for NLP tasks,
//! including text kernels, semantic features, syntactic approximations, and word embedding kernels.
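//!
//! A minimal end-to-end sketch (doctest ignored, since it depends on crate
//! wiring; `fit` and `transform` come from the `Fit`/`Transform` traits used
//! throughout this module):
//!
//! ```ignore
//! let docs = vec!["a first document".to_string(), "a second document".to_string()];
//! let fitted = TextKernelApproximation::new(64).ngram_range((1, 2)).fit(&docs, &())?;
//! let features = fitted.transform(&docs)?; // shape: (n_documents, 64)
//! ```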

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::random::essentials::Normal as RandNormal;
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use sklears_core::error::Result;
use sklears_core::traits::{Fit, Transform};

/// Text kernel approximation using bag-of-words and n-gram features.
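///
/// Documents are embedded as (optionally TF-IDF weighted) n-gram count vectors
/// and then passed through a Gaussian random projection. A minimal usage
/// sketch (doctest ignored; `fit`/`transform` come from the `Fit`/`Transform`
/// traits):
///
/// ```ignore
/// let tk = TextKernelApproximation::new(50).use_tf_idf(true).sublinear_tf(true);
/// let fitted = tk.fit(&docs, &())?;
/// let z = fitted.transform(&docs)?; // shape: (n_documents, 50)
/// ```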
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextKernelApproximation {
    /// Number of output components (dimensionality of the approximation).
    pub n_components: usize,
    /// Maximum vocabulary size.
    pub max_features: usize,
    /// Inclusive (min_n, max_n) range of n-gram sizes.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency for a term to be kept.
    pub min_df: usize,
    /// Maximum document-frequency ratio for a term to be kept.
    pub max_df: f64,
    /// Weight term counts by inverse document frequency.
    pub use_tf_idf: bool,
    /// Hash features instead of building an explicit vocabulary.
    pub use_hashing: bool,
    /// Use 1 + ln(tf) instead of raw term frequency.
    pub sublinear_tf: bool,
}

impl TextKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_features: 10000,
            ngram_range: (1, 1),
            min_df: 1,
            max_df: 1.0,
            use_tf_idf: true,
            use_hashing: false,
            sublinear_tf: false,
        }
    }

    pub fn max_features(mut self, max_features: usize) -> Self {
        self.max_features = max_features;
        self
    }

    pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
        self.ngram_range = ngram_range;
        self
    }

    pub fn min_df(mut self, min_df: usize) -> Self {
        self.min_df = min_df;
        self
    }

    pub fn max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    pub fn use_tf_idf(mut self, use_tf_idf: bool) -> Self {
        self.use_tf_idf = use_tf_idf;
        self
    }

    pub fn use_hashing(mut self, use_hashing: bool) -> Self {
        self.use_hashing = use_hashing;
        self
    }

    pub fn sublinear_tf(mut self, sublinear_tf: bool) -> Self {
        self.sublinear_tf = sublinear_tf;
        self
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect()
    }

    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
        let mut ngrams = Vec::new();
        let (min_n, max_n) = self.ngram_range;

        for n in min_n.max(1)..=max_n {
            // `windows(n)` yields nothing when there are fewer than `n` tokens.
            for window in tokens.windows(n) {
                ngrams.push(window.join(" "));
            }
        }

        ngrams
    }

    /// Hashing-trick bucket for a feature (used when `use_hashing` is set).
    fn hash_feature(&self, feature: &str) -> usize {
        let mut hasher = DefaultHasher::new();
        feature.hash(&mut hasher);
        hasher.finish() as usize % self.max_features
    }
}

/// Fitted state of [`TextKernelApproximation`], produced by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedTextKernelApproximation {
    /// Number of output components (dimensionality of the approximation).
    pub n_components: usize,
    /// Maximum vocabulary size.
    pub max_features: usize,
    /// Inclusive (min_n, max_n) range of n-gram sizes.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency for a term to be kept.
    pub min_df: usize,
    /// Maximum document-frequency ratio for a term to be kept.
    pub max_df: f64,
    /// Weight term counts by inverse document frequency.
    pub use_tf_idf: bool,
    /// Hash features instead of building an explicit vocabulary.
    pub use_hashing: bool,
    /// Use 1 + ln(tf) instead of raw term frequency.
    pub sublinear_tf: bool,
    /// Term -> column index mapping learned during `fit`.
    pub vocabulary: HashMap<String, usize>,
    /// Per-term inverse document frequencies (all ones when TF-IDF is off).
    pub idf_values: Array1<f64>,
    /// Gaussian random projection matrix, shape (n_components, vocabulary size).
    pub random_weights: Array2<f64>,
}

impl Fit<Vec<String>, ()> for TextKernelApproximation {
    type Fitted = FittedTextKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut document_frequencies = HashMap::new();
        let n_documents = documents.len();

        // Compute document frequencies (each term counted once per document)
        for doc in documents {
            let tokens = self.tokenize(doc);
            let ngrams = self.extract_ngrams(&tokens);
            let unique_ngrams: HashSet<String> = ngrams.into_iter().collect();

            for ngram in unique_ngrams {
                *document_frequencies.entry(ngram).or_insert(0) += 1;
            }
        }

        // Filter vocabulary based on min_df and max_df
        let mut filtered_vocabulary = HashMap::new();
        for (term, &df) in &document_frequencies {
            let df_ratio = df as f64 / n_documents as f64;
            if df >= self.min_df && df_ratio <= self.max_df {
                filtered_vocabulary.insert(term.clone(), filtered_vocabulary.len());
            }
        }

        // Limit vocabulary size, keeping the most frequent terms
        if filtered_vocabulary.len() > self.max_features {
            let mut sorted_vocab: Vec<_> = filtered_vocabulary.iter().collect();
            sorted_vocab.sort_by(|a, b| {
                document_frequencies[b.0]
                    .cmp(&document_frequencies[a.0])
                    .then_with(|| a.0.cmp(b.0)) // deterministic tie-break by term
            });

            let mut new_vocabulary = HashMap::new();
            for (term, _) in sorted_vocab.iter().take(self.max_features) {
                new_vocabulary.insert(term.to_string(), new_vocabulary.len());
            }
            filtered_vocabulary = new_vocabulary;
        }

        // Compute IDF values
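        // A smoothed IDF variant, ln(N / (df + 1)) + 1: the +1 in the denominator
        // guards against division by zero and the +1 outside keeps weights positive
        // (close to, though not identical to, scikit-learn's `smooth_idf`).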
        let vocab_size = filtered_vocabulary.len();
        let mut idf_values = Array1::zeros(vocab_size);

        if self.use_tf_idf {
            for (term, &idx) in &filtered_vocabulary {
                let df = document_frequencies.get(term).unwrap_or(&0);
                idf_values[idx] = (n_documents as f64 / (*df as f64 + 1.0)).ln() + 1.0;
            }
        } else {
            idf_values.fill(1.0);
        }

        // Generate random weights for kernel approximation
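        // Entries are i.i.d. N(0, 1); by the Johnson-Lindenstrauss argument such a
        // Gaussian projection approximately preserves inner products, so the
        // projected features approximate the linear kernel on TF-IDF vectors.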
        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let mut random_weights = Array2::zeros((self.n_components, vocab_size));
        for i in 0..self.n_components {
            for j in 0..vocab_size {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedTextKernelApproximation {
            n_components: self.n_components,
            max_features: self.max_features,
            ngram_range: self.ngram_range,
            min_df: self.min_df,
            max_df: self.max_df,
            use_tf_idf: self.use_tf_idf,
            use_hashing: self.use_hashing,
            sublinear_tf: self.sublinear_tf,
            vocabulary: filtered_vocabulary,
            idf_values,
            random_weights,
        })
    }
}

impl FittedTextKernelApproximation {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect()
    }

    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
        let mut ngrams = Vec::new();
        let (min_n, max_n) = self.ngram_range;

        for n in min_n.max(1)..=max_n {
            // `windows(n)` yields nothing when there are fewer than `n` tokens.
            for window in tokens.windows(n) {
                ngrams.push(window.join(" "));
            }
        }

        ngrams
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedTextKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let vocab_size = self.vocabulary.len();

        // Convert documents to TF-IDF vectors
        let mut tf_idf_matrix = Array2::zeros((n_documents, vocab_size));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            let ngrams = self.extract_ngrams(&tokens);

            // Count term frequencies
            let mut term_counts = HashMap::new();
            for ngram in ngrams {
                *term_counts.entry(ngram).or_insert(0) += 1;
            }

            // Compute TF-IDF
            for (term, &count) in &term_counts {
                if let Some(&vocab_idx) = self.vocabulary.get(term) {
                    let tf = if self.sublinear_tf {
                        1.0 + (count as f64).ln()
                    } else {
                        count as f64
                    };

                    tf_idf_matrix[[doc_idx, vocab_idx]] = tf * self.idf_values[vocab_idx];
                }
            }
        }

        // Apply the random projection for the kernel approximation:
        // result = tf_idf_matrix (n_docs x vocab) . W^T (vocab x n_components)
        Ok(tf_idf_matrix.dot(&self.random_weights.t()))
    }
}

/// Semantic kernel approximation using word embeddings and similarity measures.
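///
/// Tokens are mapped to embeddings (randomly initialized here; pre-trained in
/// practice), pooled into a document vector, and projected. A minimal usage
/// sketch (doctest ignored):
///
/// ```ignore
/// let sk = SemanticKernelApproximation::new(30, 100)
///     .aggregation_method(AggregationMethod::Mean);
/// let fitted = sk.fit(&docs, &())?;
/// let z = fitted.transform(&docs)?; // shape: (n_documents, 30)
/// ```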
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Dimensionality of the word embeddings.
    pub embedding_dim: usize,
    /// Similarity measure between embedding vectors.
    pub similarity_measure: SimilarityMeasure,
    /// How token embeddings are pooled into a document embedding.
    pub aggregation_method: AggregationMethod,
    /// Use attention-style weighting when aggregating.
    pub use_attention: bool,
}

/// Similarity measure between embedding vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SimilarityMeasure {
    /// Cosine similarity.
    Cosine,
    /// Negative Euclidean distance (larger means more similar).
    Euclidean,
    /// Negative Manhattan distance (larger means more similar).
    Manhattan,
    /// Dot product.
    Dot,
}

/// Method for pooling token embeddings into a document embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Element-wise mean.
    Mean,
    /// Element-wise maximum.
    Max,
    /// Element-wise sum.
    Sum,
    /// Softmax-weighted mean with norm-based scores.
    AttentionWeighted,
}

impl SemanticKernelApproximation {
    pub fn new(n_components: usize, embedding_dim: usize) -> Self {
        Self {
            n_components,
            embedding_dim,
            similarity_measure: SimilarityMeasure::Cosine,
            aggregation_method: AggregationMethod::Mean,
            use_attention: false,
        }
    }

    pub fn similarity_measure(mut self, measure: SimilarityMeasure) -> Self {
        self.similarity_measure = measure;
        self
    }

    pub fn aggregation_method(mut self, method: AggregationMethod) -> Self {
        self.aggregation_method = method;
        self
    }

    pub fn use_attention(mut self, use_attention: bool) -> Self {
        self.use_attention = use_attention;
        self
    }

    /// Similarity between two embedding vectors; distance-based measures are
    /// negated so that larger values always mean "more similar".
    fn compute_similarity(&self, vec1: &ArrayView1<f64>, vec2: &ArrayView1<f64>) -> f64 {
        match self.similarity_measure {
            SimilarityMeasure::Cosine => {
                let dot = vec1.dot(vec2);
                let norm1 = vec1.dot(vec1).sqrt();
                let norm2 = vec2.dot(vec2).sqrt();
                if norm1 > 0.0 && norm2 > 0.0 {
                    dot / (norm1 * norm2)
                } else {
                    0.0
                }
            }
            SimilarityMeasure::Euclidean => {
                let diff = vec1 - vec2;
                -diff.dot(&diff).sqrt()
            }
            SimilarityMeasure::Manhattan => {
                let diff = vec1 - vec2;
                -diff.mapv(|x| x.abs()).sum()
            }
            SimilarityMeasure::Dot => vec1.dot(vec2),
        }
    }
    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            AggregationMethod::Mean => embeddings.mean_axis(Axis(0)).unwrap(),
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                // Simplified attention mechanism
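                // Each token is scored by the L2 norm of its embedding and the
                // scores are softmax-normalized below; a cheap stand-in for a
                // learned attention module.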
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                // Softmax (max-shifted for numerical stability)
                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                // Weighted average
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}

/// Fitted state of [`SemanticKernelApproximation`], produced by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSemanticKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Dimensionality of the word embeddings.
    pub embedding_dim: usize,
    /// Similarity measure between embedding vectors.
    pub similarity_measure: SimilarityMeasure,
    /// How token embeddings are pooled into a document embedding.
    pub aggregation_method: AggregationMethod,
    /// Use attention-style weighting when aggregating.
    pub use_attention: bool,
    /// Word -> embedding table built during `fit`.
    pub word_embeddings: HashMap<String, Array1<f64>>,
    /// Random projection matrix, shape (n_components, embedding_dim).
    pub projection_matrix: Array2<f64>,
}

impl Fit<Vec<String>, ()> for SemanticKernelApproximation {
    type Fitted = FittedSemanticKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        // Generate random word embeddings (in practice, these would be pre-trained)
        let mut word_embeddings = HashMap::new();
        let mut vocabulary = HashSet::new();

        for doc in documents {
            let tokens: Vec<String> = doc
                .to_lowercase()
                .split_whitespace()
                .map(|s| s.to_string())
                .collect();

            for token in tokens {
                vocabulary.insert(token);
            }
        }

        for word in vocabulary {
            let embedding = Array1::from_vec(
                (0..self.embedding_dim)
                    .map(|_| rng.sample(normal))
                    .collect(),
            );
            word_embeddings.insert(word, embedding);
        }

        // Generate projection matrix for kernel approximation
        let mut projection_matrix = Array2::zeros((self.n_components, self.embedding_dim));
        for i in 0..self.n_components {
            for j in 0..self.embedding_dim {
                projection_matrix[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedSemanticKernelApproximation {
            n_components: self.n_components,
            embedding_dim: self.embedding_dim,
            similarity_measure: self.similarity_measure,
            aggregation_method: self.aggregation_method,
            use_attention: self.use_attention,
            word_embeddings,
            projection_matrix,
        })
    }
}

impl FittedSemanticKernelApproximation {
    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            AggregationMethod::Mean => embeddings.mean_axis(Axis(0)).unwrap(),
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                // Simplified attention mechanism: norm-based scores, as in
                // `SemanticKernelApproximation::aggregate_embeddings`
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                // Softmax (max-shifted for numerical stability)
                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                // Weighted average
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedSemanticKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let mut result = Array2::zeros((n_documents, self.n_components));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let tokens: Vec<String> = doc
                .to_lowercase()
                .split_whitespace()
                .map(|s| s.to_string())
                .collect();

            // Get embeddings for tokens (unknown tokens are skipped)
            let mut token_embeddings = Vec::new();
            for token in tokens {
                if let Some(embedding) = self.word_embeddings.get(&token) {
                    token_embeddings.push(embedding.clone());
                }
            }

            // Documents with no known tokens keep an all-zero output row
            if !token_embeddings.is_empty() {
                // Convert to Array2
                let embeddings_matrix = Array2::from_shape_vec(
                    (token_embeddings.len(), self.embedding_dim),
                    token_embeddings
                        .iter()
                        .flat_map(|e| e.iter().cloned())
                        .collect(),
                )?;

                // Aggregate embeddings
                let doc_embedding = self.aggregate_embeddings(&embeddings_matrix);

                // Apply projection for kernel approximation
                for i in 0..self.n_components {
                    let projected = self.projection_matrix.row(i).dot(&doc_embedding);
                    result[[doc_idx, i]] = projected.tanh(); // Non-linear activation
                }
            }
        }

        Ok(result)
    }
}

/// Syntactic kernel approximation using parse trees and grammatical features.
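///
/// Features are heuristic POS tags, adjacent-token dependency pairs, and token
/// n-grams, projected through Gaussian random weights with a tanh
/// non-linearity. A minimal usage sketch (doctest ignored):
///
/// ```ignore
/// let sy = SyntacticKernelApproximation::new(40).use_pos_tags(true);
/// let fitted = sy.fit(&docs, &())?;
/// let z = fitted.transform(&docs)?; // shape: (n_documents, 40)
/// ```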
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyntacticKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Maximum parse-tree depth considered.
    pub max_tree_depth: usize,
    /// Include heuristic part-of-speech tag features.
    pub use_pos_tags: bool,
    /// Include adjacent-token dependency features.
    pub use_dependencies: bool,
    /// Tree kernel variant to approximate.
    pub tree_kernel_type: TreeKernelType,
}

/// Tree kernel variant to approximate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeKernelType {
    /// Subset tree kernel.
    Subset,
    /// Subsequence kernel.
    Subsequence,
    /// Partial tree kernel.
    Partial,
}

impl SyntacticKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_tree_depth: 10,
            use_pos_tags: true,
            use_dependencies: true,
            tree_kernel_type: TreeKernelType::Subset,
        }
    }

    pub fn max_tree_depth(mut self, depth: usize) -> Self {
        self.max_tree_depth = depth;
        self
    }

    pub fn use_pos_tags(mut self, use_pos: bool) -> Self {
        self.use_pos_tags = use_pos;
        self
    }

    pub fn use_dependencies(mut self, use_deps: bool) -> Self {
        self.use_dependencies = use_deps;
        self
    }

    pub fn tree_kernel_type(mut self, kernel_type: TreeKernelType) -> Self {
        self.tree_kernel_type = kernel_type;
        self
    }

    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
        let mut features = Vec::new();

        // Simplified syntactic feature extraction
        let tokens: Vec<&str> = text.split_whitespace().collect();

        // Add POS tag features (simplified)
        if self.use_pos_tags {
            for token in &tokens {
                let pos_tag = self.simple_pos_tag(token);
                features.push(format!("POS_{}", pos_tag));
            }
        }

        // Add adjacent-token dependency features (simplified)
        if self.use_dependencies {
            for i in 1..tokens.len() {
                features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
            }
        }

        // Add n-gram features
        for n in 1..=3 {
            for window in tokens.windows(n) {
                features.push(format!("NGRAM_{}", window.join("_")));
            }
        }

        features
    }

    fn simple_pos_tag(&self, token: &str) -> String {
        // Very simplified POS tagging
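        // Suffix heuristics mapped onto Penn Treebank-style tags: VBG (gerund),
        // VBD (past tense), RB (adverb), NNS (plural noun), NNP (proper noun),
        // NN (noun), CD (cardinal number).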
        let token_lower = token.to_lowercase();

        if token_lower.ends_with("ing") {
            "VBG".to_string()
        } else if token_lower.ends_with("ed") {
            "VBD".to_string()
        } else if token_lower.ends_with("ly") {
            "RB".to_string()
        } else if token_lower.ends_with('s') && !token_lower.ends_with("ss") {
            "NNS".to_string()
        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
            "NNP".to_string()
        } else if token.chars().all(|c| c.is_alphabetic()) {
            "NN".to_string()
        } else if token.chars().all(|c| c.is_numeric()) {
            "CD".to_string()
        } else {
            "UNK".to_string()
        }
    }
}

/// Fitted state of [`SyntacticKernelApproximation`], produced by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSyntacticKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Maximum parse-tree depth considered.
    pub max_tree_depth: usize,
    /// Include heuristic part-of-speech tag features.
    pub use_pos_tags: bool,
    /// Include adjacent-token dependency features.
    pub use_dependencies: bool,
    /// Tree kernel variant to approximate.
    pub tree_kernel_type: TreeKernelType,
    /// Feature -> column index mapping learned during `fit`.
    pub feature_vocabulary: HashMap<String, usize>,
    /// Gaussian random projection matrix, shape (n_components, vocabulary size).
    pub random_weights: Array2<f64>,
}

impl Fit<Vec<String>, ()> for SyntacticKernelApproximation {
    type Fitted = FittedSyntacticKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut feature_vocabulary = HashMap::new();

        // Extract syntactic features from all documents
        for doc in documents {
            for feature in self.extract_syntactic_features(doc) {
                let next_index = feature_vocabulary.len();
                feature_vocabulary.entry(feature).or_insert(next_index);
            }
        }

        // Generate random weights
        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let vocab_size = feature_vocabulary.len();
        let mut random_weights = Array2::zeros((self.n_components, vocab_size));

        for i in 0..self.n_components {
            for j in 0..vocab_size {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedSyntacticKernelApproximation {
            n_components: self.n_components,
            max_tree_depth: self.max_tree_depth,
            use_pos_tags: self.use_pos_tags,
            use_dependencies: self.use_dependencies,
            tree_kernel_type: self.tree_kernel_type,
            feature_vocabulary,
            random_weights,
        })
    }
}

impl FittedSyntacticKernelApproximation {
    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
        let mut features = Vec::new();

        // Simplified syntactic feature extraction
        let tokens: Vec<&str> = text.split_whitespace().collect();

        // Add POS tag features (simplified)
        if self.use_pos_tags {
            for token in &tokens {
                let pos_tag = self.simple_pos_tag(token);
                features.push(format!("POS_{}", pos_tag));
            }
        }

        // Add adjacent-token dependency features (simplified)
        if self.use_dependencies {
            for i in 1..tokens.len() {
                features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
            }
        }

        // Add n-gram features
        for n in 1..=3 {
            for window in tokens.windows(n) {
                features.push(format!("NGRAM_{}", window.join("_")));
            }
        }

        features
    }

    fn simple_pos_tag(&self, token: &str) -> String {
        // Very simplified POS tagging; see `SyntacticKernelApproximation::simple_pos_tag`
        let token_lower = token.to_lowercase();

        if token_lower.ends_with("ing") {
            "VBG".to_string()
        } else if token_lower.ends_with("ed") {
            "VBD".to_string()
        } else if token_lower.ends_with("ly") {
            "RB".to_string()
        } else if token_lower.ends_with('s') && !token_lower.ends_with("ss") {
            "NNS".to_string()
        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
            "NNP".to_string()
        } else if token.chars().all(|c| c.is_alphabetic()) {
            "NN".to_string()
        } else if token.chars().all(|c| c.is_numeric()) {
            "CD".to_string()
        } else {
            "UNK".to_string()
        }
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedSyntacticKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let vocab_size = self.feature_vocabulary.len();

        // Convert documents to feature-count vectors
        let mut feature_matrix = Array2::zeros((n_documents, vocab_size));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let features = self.extract_syntactic_features(doc);
            let mut feature_counts = HashMap::new();

            for feature in features {
                *feature_counts.entry(feature).or_insert(0) += 1;
            }

            for (feature, count) in feature_counts {
                if let Some(&vocab_idx) = self.feature_vocabulary.get(&feature) {
                    feature_matrix[[doc_idx, vocab_idx]] = count as f64;
                }
            }
        }

        // Random projection followed by a tanh non-linearity:
        // result = tanh(feature_matrix . W^T)
        Ok(feature_matrix.dot(&self.random_weights.t()).mapv(f64::tanh))
    }
}

/// Document kernel approximation for document-level features.
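///
/// Combines readability, stylometric, and hash-based pseudo-topic features,
/// then applies a random projection with a tanh non-linearity. A minimal
/// usage sketch (doctest ignored):
///
/// ```ignore
/// let dk = DocumentKernelApproximation::new(25).n_topics(10);
/// let fitted = dk.fit(&docs, &())?;
/// let z = fitted.transform(&docs)?; // shape: (n_documents, 25)
/// ```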
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Include hash-based pseudo-topic features.
    pub use_topic_features: bool,
    /// Include readability features (sentence/word statistics).
    pub use_readability_features: bool,
    /// Include stylometric features (punctuation, casing, type-token ratio).
    pub use_stylometric_features: bool,
    /// Number of pseudo-topic features.
    pub n_topics: usize,
}

impl DocumentKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            use_topic_features: true,
            use_readability_features: true,
            use_stylometric_features: true,
            n_topics: 10,
        }
    }

    pub fn use_topic_features(mut self, use_topics: bool) -> Self {
        self.use_topic_features = use_topics;
        self
    }

    pub fn use_readability_features(mut self, use_readability: bool) -> Self {
        self.use_readability_features = use_readability;
        self
    }

    pub fn use_stylometric_features(mut self, use_stylometric: bool) -> Self {
        self.use_stylometric_features = use_stylometric;
        self
    }

    pub fn n_topics(mut self, n_topics: usize) -> Self {
        self.n_topics = n_topics;
        self
    }

    fn extract_document_features(&self, text: &str) -> Vec<f64> {
        let mut features = Vec::new();

        // Ignore empty fragments left by trailing terminators
        // (e.g. "Short doc." is one sentence, not two)
        let sentences: Vec<&str> = text
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();
        let words: Vec<&str> = text.split_whitespace().collect();
        let characters: Vec<char> = text.chars().collect();

        if self.use_readability_features {
            // Readability features
            let avg_sentence_length = if !sentences.is_empty() {
                words.len() as f64 / sentences.len() as f64
            } else {
                0.0
            };

            // Average over word characters only, so whitespace does not inflate the value
            let word_chars: usize = words.iter().map(|w| w.chars().count()).sum();
            let avg_word_length = if !words.is_empty() {
                word_chars as f64 / words.len() as f64
            } else {
                0.0
            };

            features.push(avg_sentence_length);
            features.push(avg_word_length);
            features.push(sentences.len() as f64);
            features.push(words.len() as f64);
        }

        if self.use_stylometric_features {
            // Stylometric features, as fractions of all characters (0.0 for empty text)
            let n_chars = characters.len().max(1) as f64;
            let punctuation_count = characters
                .iter()
                .filter(|c| c.is_ascii_punctuation())
                .count();
            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();

            features.push(punctuation_count as f64 / n_chars);
            features.push(uppercase_count as f64 / n_chars);
            features.push(digit_count as f64 / n_chars);

            // Type-token ratio
            let unique_words: HashSet<&str> = words.iter().cloned().collect();
            let ttr = if !words.is_empty() {
                unique_words.len() as f64 / words.len() as f64
            } else {
                0.0
            };
            features.push(ttr);
        }

        if self.use_topic_features {
            // Simplified topic features (in practice, use LDA or similar)
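            // A deterministic hash of the document seeds pseudo "topic" proportions
            // in [0, 1); this is a placeholder signal, not a real topic model.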
            let mut topic_features = vec![0.0; self.n_topics];
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            for i in 0..self.n_topics {
                // wrapping_add avoids overflow panics in debug builds for large hashes
                topic_features[i] = (hash.wrapping_add(i as u64) % 1000) as f64 / 1000.0;
            }

            features.extend(topic_features);
        }

        features
    }
}

/// Fitted state of [`DocumentKernelApproximation`], produced by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedDocumentKernelApproximation {
    /// Number of output components.
    pub n_components: usize,
    /// Include hash-based pseudo-topic features.
    pub use_topic_features: bool,
    /// Include readability features (sentence/word statistics).
    pub use_readability_features: bool,
    /// Include stylometric features (punctuation, casing, type-token ratio).
    pub use_stylometric_features: bool,
    /// Number of pseudo-topic features.
    pub n_topics: usize,
    /// Length of the raw document feature vector.
    pub feature_dim: usize,
    /// Gaussian random projection matrix, shape (n_components, feature_dim).
    pub random_weights: Array2<f64>,
}

impl Fit<Vec<String>, ()> for DocumentKernelApproximation {
    type Fitted = FittedDocumentKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        // Determine the feature dimension from the first document
        // (assumes `documents` is non-empty; the feature flags fix the width)
        let sample_features = self.extract_document_features(&documents[0]);
        let feature_dim = sample_features.len();

        // Generate random weights
        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let mut random_weights = Array2::zeros((self.n_components, feature_dim));
        for i in 0..self.n_components {
            for j in 0..feature_dim {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedDocumentKernelApproximation {
            n_components: self.n_components,
            use_topic_features: self.use_topic_features,
            use_readability_features: self.use_readability_features,
            use_stylometric_features: self.use_stylometric_features,
            n_topics: self.n_topics,
            feature_dim,
            random_weights,
        })
    }
}

impl FittedDocumentKernelApproximation {
    fn extract_document_features(&self, text: &str) -> Vec<f64> {
        // Mirrors `DocumentKernelApproximation::extract_document_features`
        let mut features = Vec::new();

        let sentences: Vec<&str> = text
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();
        let words: Vec<&str> = text.split_whitespace().collect();
        let characters: Vec<char> = text.chars().collect();

        if self.use_readability_features {
            // Readability features
            let avg_sentence_length = if !sentences.is_empty() {
                words.len() as f64 / sentences.len() as f64
            } else {
                0.0
            };

            let word_chars: usize = words.iter().map(|w| w.chars().count()).sum();
            let avg_word_length = if !words.is_empty() {
                word_chars as f64 / words.len() as f64
            } else {
                0.0
            };

            features.push(avg_sentence_length);
            features.push(avg_word_length);
            features.push(sentences.len() as f64);
            features.push(words.len() as f64);
        }

        if self.use_stylometric_features {
            // Stylometric features, as fractions of all characters (0.0 for empty text)
            let n_chars = characters.len().max(1) as f64;
            let punctuation_count = characters
                .iter()
                .filter(|c| c.is_ascii_punctuation())
                .count();
            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();

            features.push(punctuation_count as f64 / n_chars);
            features.push(uppercase_count as f64 / n_chars);
            features.push(digit_count as f64 / n_chars);

            // Type-token ratio
            let unique_words: HashSet<&str> = words.iter().cloned().collect();
            let ttr = if !words.is_empty() {
                unique_words.len() as f64 / words.len() as f64
            } else {
                0.0
            };
            features.push(ttr);
        }

        if self.use_topic_features {
            // Simplified topic features (in practice, use LDA or similar)
            let mut topic_features = vec![0.0; self.n_topics];
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            for i in 0..self.n_topics {
                topic_features[i] = (hash.wrapping_add(i as u64) % 1000) as f64 / 1000.0;
            }

            features.extend(topic_features);
        }

        features
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedDocumentKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let mut result = Array2::zeros((n_documents, self.n_components));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let features = self.extract_document_features(doc);
            let feature_array = Array1::from_vec(features);

            for i in 0..self.n_components {
                let projected = self.random_weights.row(i).dot(&feature_array);
                result[[doc_idx, i]] = projected.tanh();
            }
        }

        Ok(result)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_kernel_approximation() {
        let docs = vec![
            "This is a test document".to_string(),
            "Another test document here".to_string(),
            "Third document for testing".to_string(),
        ];

        let text_kernel = TextKernelApproximation::new(50);
        let fitted = text_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 3);
        assert_eq!(transformed.shape()[1], 50);
    }
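
    #[test]
    fn test_text_kernel_builder() {
        // Smoke test for the builder API: every setter should be reflected in
        // the configuration it returns.
        let tk = TextKernelApproximation::new(8)
            .max_features(100)
            .ngram_range((1, 2))
            .min_df(2)
            .max_df(0.9)
            .use_tf_idf(false)
            .use_hashing(true)
            .sublinear_tf(true);

        assert_eq!(tk.n_components, 8);
        assert_eq!(tk.max_features, 100);
        assert_eq!(tk.ngram_range, (1, 2));
        assert_eq!(tk.min_df, 2);
        assert!((tk.max_df - 0.9).abs() < 1e-12);
        assert!(!tk.use_tf_idf);
        assert!(tk.use_hashing);
        assert!(tk.sublinear_tf);
    }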

    #[test]
    fn test_semantic_kernel_approximation() {
        let docs = vec![
            "Semantic similarity test".to_string(),
            "Another semantic test".to_string(),
        ];

        let semantic_kernel = SemanticKernelApproximation::new(30, 100);
        let fitted = semantic_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 30);
    }

    #[test]
    fn test_syntactic_kernel_approximation() {
        let docs = vec![
            "The cat sat on the mat".to_string(),
            "Dogs are running quickly".to_string(),
        ];

        let syntactic_kernel = SyntacticKernelApproximation::new(40);
        let fitted = syntactic_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 40);
    }

    #[test]
    fn test_document_kernel_approximation() {
        let docs = vec![
            "This is a long document with multiple sentences. It contains various features."
                .to_string(),
            "Short doc.".to_string(),
        ];

        let doc_kernel = DocumentKernelApproximation::new(25);
        let fitted = doc_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 25);
    }
}