1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::RngExt;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use serde::{Deserialize, Serialize};
12use std::collections::hash_map::DefaultHasher;
13use std::collections::{HashMap, HashSet};
14use std::hash::{Hash, Hasher};
15
16use sklears_core::error::Result;
17use sklears_core::traits::{Fit, Transform};
18
/// Configuration for a randomized approximation of a text kernel:
/// documents are mapped to n-gram TF-IDF vectors and then projected with
/// Gaussian random weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum number of vocabulary terms to retain.
    pub max_features: usize,
    /// Inclusive (min_n, max_n) range of n-gram lengths.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency (absolute count) for a term to be kept.
    pub min_df: usize,
    /// Maximum document-frequency ratio (0.0..=1.0) for a term to be kept.
    pub max_df: f64,
    /// If true, weight term frequencies by smoothed inverse document frequency.
    pub use_tf_idf: bool,
    /// If true, use the hashing trick instead of an explicit vocabulary.
    /// NOTE(review): not consulted by the visible fit/transform code — confirm.
    pub use_hashing: bool,
    /// If true, use 1 + ln(tf) instead of raw term counts.
    pub sublinear_tf: bool,
}
40
41impl TextKernelApproximation {
42 pub fn new(n_components: usize) -> Self {
43 Self {
44 n_components,
45 max_features: 10000,
46 ngram_range: (1, 1),
47 min_df: 1,
48 max_df: 1.0,
49 use_tf_idf: true,
50 use_hashing: false,
51 sublinear_tf: false,
52 }
53 }
54
55 pub fn max_features(mut self, max_features: usize) -> Self {
56 self.max_features = max_features;
57 self
58 }
59
60 pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
61 self.ngram_range = ngram_range;
62 self
63 }
64
65 pub fn min_df(mut self, min_df: usize) -> Self {
66 self.min_df = min_df;
67 self
68 }
69
70 pub fn max_df(mut self, max_df: f64) -> Self {
71 self.max_df = max_df;
72 self
73 }
74
75 pub fn use_tf_idf(mut self, use_tf_idf: bool) -> Self {
76 self.use_tf_idf = use_tf_idf;
77 self
78 }
79
80 pub fn use_hashing(mut self, use_hashing: bool) -> Self {
81 self.use_hashing = use_hashing;
82 self
83 }
84
85 pub fn sublinear_tf(mut self, sublinear_tf: bool) -> Self {
86 self.sublinear_tf = sublinear_tf;
87 self
88 }
89
90 fn tokenize(&self, text: &str) -> Vec<String> {
91 text.to_lowercase()
92 .split_whitespace()
93 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
94 .filter(|s| !s.is_empty())
95 .map(|s| s.to_string())
96 .collect()
97 }
98
99 fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
100 let mut ngrams = Vec::new();
101
102 for n in self.ngram_range.0..=self.ngram_range.1 {
103 for i in 0..=tokens.len().saturating_sub(n) {
104 if i + n <= tokens.len() {
105 let ngram = tokens[i..i + n].join(" ");
106 ngrams.push(ngram);
107 }
108 }
109 }
110
111 ngrams
112 }
113
114 fn hash_feature(&self, feature: &str) -> usize {
115 let mut hasher = DefaultHasher::new();
116 feature.hash(&mut hasher);
117 hasher.finish() as usize % self.max_features
118 }
119}
120
/// Fitted state of [`TextKernelApproximation`]: the learned vocabulary,
/// IDF weights, and the random projection matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedTextKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum vocabulary size used during fitting.
    pub max_features: usize,
    /// Inclusive (min_n, max_n) n-gram length range.
    pub ngram_range: (usize, usize),
    /// Minimum document frequency used for vocabulary filtering.
    pub min_df: usize,
    /// Maximum document-frequency ratio used for vocabulary filtering.
    pub max_df: f64,
    /// Whether IDF weighting was applied during fitting.
    pub use_tf_idf: bool,
    /// Whether the hashing trick was requested (carried over from config).
    pub use_hashing: bool,
    /// Whether sublinear term-frequency scaling is applied at transform time.
    pub sublinear_tf: bool,
    /// Term -> column index of the learned vocabulary.
    pub vocabulary: HashMap<String, usize>,
    /// Per-term IDF weights, indexed by vocabulary column (all 1.0 when
    /// `use_tf_idf` is false).
    pub idf_values: Array1<f64>,
    /// Gaussian projection matrix of shape (n_components, vocabulary size).
    pub random_weights: Array2<f64>,
}
147
148impl Fit<Vec<String>, ()> for TextKernelApproximation {
149 type Fitted = FittedTextKernelApproximation;
150
151 fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
152 let mut vocabulary = HashMap::new();
153 let mut document_frequencies = HashMap::new();
154 let n_documents = documents.len();
155
156 for doc in documents {
158 let tokens = self.tokenize(doc);
159 let ngrams = self.extract_ngrams(&tokens);
160 let unique_ngrams: HashSet<String> = ngrams.into_iter().collect();
161
162 for ngram in unique_ngrams {
163 *document_frequencies.entry(ngram.clone()).or_insert(0) += 1;
164 if !vocabulary.contains_key(&ngram) {
165 vocabulary.insert(ngram, vocabulary.len());
166 }
167 }
168 }
169
170 let mut filtered_vocabulary = HashMap::new();
172 for (term, &df) in &document_frequencies {
173 let df_ratio = df as f64 / n_documents as f64;
174 if df >= self.min_df && df_ratio <= self.max_df {
175 filtered_vocabulary.insert(term.clone(), filtered_vocabulary.len());
176 }
177 }
178
179 if filtered_vocabulary.len() > self.max_features {
181 let mut sorted_vocab: Vec<_> = filtered_vocabulary.iter().collect();
182 sorted_vocab.sort_by(|a, b| {
183 document_frequencies[a.0]
184 .cmp(&document_frequencies[b.0])
185 .reverse()
186 });
187
188 let mut new_vocabulary = HashMap::new();
189 for (term, _) in sorted_vocab.iter().take(self.max_features) {
190 new_vocabulary.insert(term.to_string(), new_vocabulary.len());
191 }
192 filtered_vocabulary = new_vocabulary;
193 }
194
195 let vocab_size = filtered_vocabulary.len();
197 let mut idf_values = Array1::zeros(vocab_size);
198
199 if self.use_tf_idf {
200 for (term, &idx) in &filtered_vocabulary {
201 let df = document_frequencies.get(term).unwrap_or(&0);
202 idf_values[idx] = (n_documents as f64 / (*df as f64 + 1.0)).ln() + 1.0;
203 }
204 } else {
205 idf_values.fill(1.0);
206 }
207
208 let mut rng = RealStdRng::from_seed(thread_rng().random());
210 let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
211
212 let mut random_weights = Array2::zeros((self.n_components, vocab_size));
213 for i in 0..self.n_components {
214 for j in 0..vocab_size {
215 random_weights[[i, j]] = rng.sample(normal);
216 }
217 }
218
219 Ok(FittedTextKernelApproximation {
220 n_components: self.n_components,
221 max_features: self.max_features,
222 ngram_range: self.ngram_range,
223 min_df: self.min_df,
224 max_df: self.max_df,
225 use_tf_idf: self.use_tf_idf,
226 use_hashing: self.use_hashing,
227 sublinear_tf: self.sublinear_tf,
228 vocabulary: filtered_vocabulary,
229 idf_values,
230 random_weights,
231 })
232 }
233}
234
235impl FittedTextKernelApproximation {
236 fn tokenize(&self, text: &str) -> Vec<String> {
237 text.to_lowercase()
238 .split_whitespace()
239 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
240 .filter(|s| !s.is_empty())
241 .map(|s| s.to_string())
242 .collect()
243 }
244
245 fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
246 let mut ngrams = Vec::new();
247
248 for n in self.ngram_range.0..=self.ngram_range.1 {
249 for i in 0..=tokens.len().saturating_sub(n) {
250 if i + n <= tokens.len() {
251 let ngram = tokens[i..i + n].join(" ");
252 ngrams.push(ngram);
253 }
254 }
255 }
256
257 ngrams
258 }
259}
260
261impl Transform<Vec<String>, Array2<f64>> for FittedTextKernelApproximation {
262 fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
263 let n_documents = documents.len();
264 let vocab_size = self.vocabulary.len();
265
266 let mut tf_idf_matrix = Array2::zeros((n_documents, vocab_size));
268
269 for (doc_idx, doc) in documents.iter().enumerate() {
270 let tokens = self.tokenize(doc);
271 let ngrams = self.extract_ngrams(&tokens);
272
273 let mut term_counts = HashMap::new();
275 for ngram in ngrams {
276 *term_counts.entry(ngram).or_insert(0) += 1;
277 }
278
279 for (term, &count) in &term_counts {
281 if let Some(&vocab_idx) = self.vocabulary.get(term) {
282 let tf = if self.sublinear_tf {
283 1.0 + (count as f64).ln()
284 } else {
285 count as f64
286 };
287
288 let tf_idf = tf * self.idf_values[vocab_idx];
289 tf_idf_matrix[[doc_idx, vocab_idx]] = tf_idf;
290 }
291 }
292 }
293
294 let mut result = Array2::zeros((n_documents, self.n_components));
296
297 for i in 0..n_documents {
298 for j in 0..self.n_components {
299 let mut dot_product = 0.0;
300 for k in 0..vocab_size {
301 dot_product += tf_idf_matrix[[i, k]] * self.random_weights[[j, k]];
302 }
303 result[[i, j]] = dot_product;
304 }
305 }
306
307 Ok(result)
308 }
309}
310
/// Configuration for a semantic kernel approximation built on random word
/// embeddings that are aggregated per document and randomly projected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticKernelApproximation {
    /// Output dimensionality of the projection.
    pub n_components: usize,
    /// Dimensionality of the per-word embeddings.
    pub embedding_dim: usize,
    /// Similarity measure used by `compute_similarity`.
    pub similarity_measure: SimilarityMeasure,
    /// How token embeddings are collapsed into one document embedding.
    pub aggregation_method: AggregationMethod,
    /// Attention flag.
    /// NOTE(review): stored but not read by the visible fit/transform code;
    /// attention is actually selected via `AggregationMethod::AttentionWeighted`.
    pub use_attention: bool,
}
326
/// Similarity measures between embedding vectors; distance-based variants
/// are negated so that larger values always mean "more similar".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SimilarityMeasure {
    /// Cosine similarity (0.0 when either vector has zero norm).
    Cosine,
    /// Negated Euclidean (L2) distance.
    Euclidean,
    /// Negated Manhattan (L1) distance.
    Manhattan,
    /// Plain dot product.
    Dot,
}
339
/// Strategies for collapsing per-token embeddings into one document vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Element-wise mean over tokens.
    Mean,
    /// Element-wise maximum over tokens.
    Max,
    /// Element-wise sum over tokens.
    Sum,
    /// Softmax-weighted sum, where each token's score is its L2 norm.
    AttentionWeighted,
}
352
353impl SemanticKernelApproximation {
354 pub fn new(n_components: usize, embedding_dim: usize) -> Self {
355 Self {
356 n_components,
357 embedding_dim,
358 similarity_measure: SimilarityMeasure::Cosine,
359 aggregation_method: AggregationMethod::Mean,
360 use_attention: false,
361 }
362 }
363
364 pub fn similarity_measure(mut self, measure: SimilarityMeasure) -> Self {
365 self.similarity_measure = measure;
366 self
367 }
368
369 pub fn aggregation_method(mut self, method: AggregationMethod) -> Self {
370 self.aggregation_method = method;
371 self
372 }
373
374 pub fn use_attention(mut self, use_attention: bool) -> Self {
375 self.use_attention = use_attention;
376 self
377 }
378
379 fn compute_similarity(&self, vec1: &ArrayView1<f64>, vec2: &ArrayView1<f64>) -> f64 {
380 match self.similarity_measure {
381 SimilarityMeasure::Cosine => {
382 let dot = vec1.dot(vec2);
383 let norm1 = vec1.dot(vec1).sqrt();
384 let norm2 = vec2.dot(vec2).sqrt();
385 if norm1 > 0.0 && norm2 > 0.0 {
386 dot / (norm1 * norm2)
387 } else {
388 0.0
389 }
390 }
391 SimilarityMeasure::Euclidean => {
392 let diff = vec1 - vec2;
393 -diff.dot(&diff).sqrt()
394 }
395 SimilarityMeasure::Manhattan => {
396 let diff = vec1 - vec2;
397 -diff.mapv(|x| x.abs()).sum()
398 }
399 SimilarityMeasure::Dot => vec1.dot(vec2),
400 }
401 }
402
403 fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
404 match self.aggregation_method {
405 AggregationMethod::Mean => embeddings
406 .mean_axis(Axis(0))
407 .expect("operation should succeed"),
408 AggregationMethod::Max => {
409 let mut result = Array1::zeros(embeddings.ncols());
410 for i in 0..embeddings.ncols() {
411 let col = embeddings.column(i);
412 result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
413 }
414 result
415 }
416 AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
417 AggregationMethod::AttentionWeighted => {
418 let n_tokens = embeddings.nrows();
420 let mut attention_weights = Array1::zeros(n_tokens);
421
422 for i in 0..n_tokens {
423 let token_embedding = embeddings.row(i);
424 attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
425 }
426
427 let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
429 attention_weights.mapv_inplace(|x| (x - max_weight).exp());
430 let sum_weights = attention_weights.sum();
431 if sum_weights > 0.0 {
432 attention_weights /= sum_weights;
433 }
434
435 let mut result = Array1::zeros(embeddings.ncols());
437 for i in 0..n_tokens {
438 let token_embedding = embeddings.row(i);
439 for j in 0..embeddings.ncols() {
440 result[j] += attention_weights[i] * token_embedding[j];
441 }
442 }
443 result
444 }
445 }
446 }
447}
448
/// Fitted state of [`SemanticKernelApproximation`]: random word embeddings
/// and the projection matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSemanticKernelApproximation {
    /// Output dimensionality of the projection.
    pub n_components: usize,
    /// Dimensionality of the word embeddings.
    pub embedding_dim: usize,
    /// Similarity measure carried over from the configuration.
    pub similarity_measure: SimilarityMeasure,
    /// Aggregation method applied at transform time.
    pub aggregation_method: AggregationMethod,
    /// Attention flag carried over from the configuration (not read by the
    /// visible transform code).
    pub use_attention: bool,
    /// Word -> embedding vector (randomly initialized at fit time).
    pub word_embeddings: HashMap<String, Array1<f64>>,
    /// Projection matrix of shape (n_components, embedding_dim).
    pub projection_matrix: Array2<f64>,
}
467
468impl Fit<Vec<String>, ()> for SemanticKernelApproximation {
469 type Fitted = FittedSemanticKernelApproximation;
470
471 fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
472 let mut rng = RealStdRng::from_seed(thread_rng().random());
473 let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
474
475 let mut word_embeddings = HashMap::new();
477 let mut vocabulary = HashSet::new();
478
479 for doc in documents {
480 let tokens: Vec<String> = doc
481 .to_lowercase()
482 .split_whitespace()
483 .map(|s| s.to_string())
484 .collect();
485
486 for token in tokens {
487 vocabulary.insert(token);
488 }
489 }
490
491 for word in vocabulary {
492 let embedding = Array1::from_vec(
493 (0..self.embedding_dim)
494 .map(|_| rng.sample(normal))
495 .collect(),
496 );
497 word_embeddings.insert(word, embedding);
498 }
499
500 let mut projection_matrix = Array2::zeros((self.n_components, self.embedding_dim));
502 for i in 0..self.n_components {
503 for j in 0..self.embedding_dim {
504 projection_matrix[[i, j]] = rng.sample(normal);
505 }
506 }
507
508 Ok(FittedSemanticKernelApproximation {
509 n_components: self.n_components,
510 embedding_dim: self.embedding_dim,
511 similarity_measure: self.similarity_measure,
512 aggregation_method: self.aggregation_method,
513 use_attention: self.use_attention,
514 word_embeddings,
515 projection_matrix,
516 })
517 }
518}
519
520impl FittedSemanticKernelApproximation {
521 fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
522 match self.aggregation_method {
523 AggregationMethod::Mean => embeddings
524 .mean_axis(Axis(0))
525 .expect("operation should succeed"),
526 AggregationMethod::Max => {
527 let mut result = Array1::zeros(embeddings.ncols());
528 for i in 0..embeddings.ncols() {
529 let col = embeddings.column(i);
530 result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
531 }
532 result
533 }
534 AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
535 AggregationMethod::AttentionWeighted => {
536 let n_tokens = embeddings.nrows();
538 let mut attention_weights = Array1::zeros(n_tokens);
539
540 for i in 0..n_tokens {
541 let token_embedding = embeddings.row(i);
542 attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
543 }
544
545 let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
547 attention_weights.mapv_inplace(|x| (x - max_weight).exp());
548 let sum_weights = attention_weights.sum();
549 if sum_weights > 0.0 {
550 attention_weights /= sum_weights;
551 }
552
553 let mut result = Array1::zeros(embeddings.ncols());
555 for i in 0..n_tokens {
556 let token_embedding = embeddings.row(i);
557 for j in 0..embeddings.ncols() {
558 result[j] += attention_weights[i] * token_embedding[j];
559 }
560 }
561 result
562 }
563 }
564 }
565}
566
567impl Transform<Vec<String>, Array2<f64>> for FittedSemanticKernelApproximation {
568 fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
569 let n_documents = documents.len();
570 let mut result = Array2::zeros((n_documents, self.n_components));
571
572 for (doc_idx, doc) in documents.iter().enumerate() {
573 let tokens: Vec<String> = doc
574 .to_lowercase()
575 .split_whitespace()
576 .map(|s| s.to_string())
577 .collect();
578
579 let mut token_embeddings = Vec::new();
581 for token in tokens {
582 if let Some(embedding) = self.word_embeddings.get(&token) {
583 token_embeddings.push(embedding.clone());
584 }
585 }
586
587 if !token_embeddings.is_empty() {
588 let embeddings_matrix = Array2::from_shape_vec(
590 (token_embeddings.len(), self.embedding_dim),
591 token_embeddings
592 .iter()
593 .flat_map(|e| e.iter().cloned())
594 .collect(),
595 )?;
596
597 let doc_embedding = self.aggregate_embeddings(&embeddings_matrix);
599
600 for i in 0..self.n_components {
602 let projected = self.projection_matrix.row(i).dot(&doc_embedding);
603 result[[doc_idx, i]] = projected.tanh(); }
605 }
606 }
607
608 Ok(result)
609 }
610}
611
/// Configuration for a syntactic kernel approximation using heuristic
/// POS tags, adjacent-token pairs, and word n-grams as features.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyntacticKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum parse-tree depth.
    /// NOTE(review): stored but not used by the visible feature extraction.
    pub max_tree_depth: usize,
    /// Include heuristic part-of-speech tag features.
    pub use_pos_tags: bool,
    /// Include adjacent-token "dependency" pair features.
    pub use_dependencies: bool,
    /// Tree kernel variant.
    /// NOTE(review): stored but not used by the visible feature extraction.
    pub tree_kernel_type: TreeKernelType,
}
627
/// Tree kernel variants for syntactic similarity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeKernelType {
    /// Subset tree kernel.
    Subset,
    /// Subsequence tree kernel.
    Subsequence,
    /// Partial tree kernel.
    Partial,
}
638
639impl SyntacticKernelApproximation {
640 pub fn new(n_components: usize) -> Self {
641 Self {
642 n_components,
643 max_tree_depth: 10,
644 use_pos_tags: true,
645 use_dependencies: true,
646 tree_kernel_type: TreeKernelType::Subset,
647 }
648 }
649
650 pub fn max_tree_depth(mut self, depth: usize) -> Self {
651 self.max_tree_depth = depth;
652 self
653 }
654
655 pub fn use_pos_tags(mut self, use_pos: bool) -> Self {
656 self.use_pos_tags = use_pos;
657 self
658 }
659
660 pub fn use_dependencies(mut self, use_deps: bool) -> Self {
661 self.use_dependencies = use_deps;
662 self
663 }
664
665 pub fn tree_kernel_type(mut self, kernel_type: TreeKernelType) -> Self {
666 self.tree_kernel_type = kernel_type;
667 self
668 }
669
670 fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
671 let mut features = Vec::new();
672
673 let tokens: Vec<&str> = text.split_whitespace().collect();
675
676 if self.use_pos_tags {
678 for token in &tokens {
679 let pos_tag = self.simple_pos_tag(token);
680 features.push(format!("POS_{}", pos_tag));
681 }
682 }
683
684 if self.use_dependencies {
686 for i in 0..tokens.len() {
687 if i > 0 {
688 features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
689 }
690 }
691 }
692
693 for n in 1..=3 {
695 for i in 0..=tokens.len().saturating_sub(n) {
696 if i + n <= tokens.len() {
697 let ngram = tokens[i..i + n].join("_");
698 features.push(format!("NGRAM_{}", ngram));
699 }
700 }
701 }
702
703 features
704 }
705
706 fn simple_pos_tag(&self, token: &str) -> String {
707 let token_lower = token.to_lowercase();
709
710 if token_lower.ends_with("ing") {
711 "VBG".to_string()
712 } else if token_lower.ends_with("ed") {
713 "VBD".to_string()
714 } else if token_lower.ends_with("ly") {
715 "RB".to_string()
716 } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
717 "NNS".to_string()
718 } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
719 "NNP".to_string()
720 } else if token.chars().all(|c| c.is_alphabetic()) {
721 "NN".to_string()
722 } else if token.chars().all(|c| c.is_numeric()) {
723 "CD".to_string()
724 } else {
725 "UNK".to_string()
726 }
727 }
728}
729
/// Fitted state of [`SyntacticKernelApproximation`]: the feature
/// vocabulary and random projection weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSyntacticKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Maximum parse-tree depth carried over from the configuration.
    pub max_tree_depth: usize,
    /// Whether POS-tag features are extracted at transform time.
    pub use_pos_tags: bool,
    /// Whether adjacent-token dependency features are extracted.
    pub use_dependencies: bool,
    /// Tree kernel variant carried over from the configuration.
    pub tree_kernel_type: TreeKernelType,
    /// Feature string -> column index.
    pub feature_vocabulary: HashMap<String, usize>,
    /// Gaussian projection matrix of shape (n_components, vocabulary size).
    pub random_weights: Array2<f64>,
}
748
749impl Fit<Vec<String>, ()> for SyntacticKernelApproximation {
750 type Fitted = FittedSyntacticKernelApproximation;
751
752 fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
753 let mut feature_vocabulary = HashMap::new();
754
755 for doc in documents {
757 let features = self.extract_syntactic_features(doc);
758 for feature in features {
759 if !feature_vocabulary.contains_key(&feature) {
760 feature_vocabulary.insert(feature, feature_vocabulary.len());
761 }
762 }
763 }
764
765 let mut rng = RealStdRng::from_seed(thread_rng().random());
767 let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
768
769 let vocab_size = feature_vocabulary.len();
770 let mut random_weights = Array2::zeros((self.n_components, vocab_size));
771
772 for i in 0..self.n_components {
773 for j in 0..vocab_size {
774 random_weights[[i, j]] = rng.sample(normal);
775 }
776 }
777
778 Ok(FittedSyntacticKernelApproximation {
779 n_components: self.n_components,
780 max_tree_depth: self.max_tree_depth,
781 use_pos_tags: self.use_pos_tags,
782 use_dependencies: self.use_dependencies,
783 tree_kernel_type: self.tree_kernel_type,
784 feature_vocabulary,
785 random_weights,
786 })
787 }
788}
789
790impl FittedSyntacticKernelApproximation {
791 fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
792 let mut features = Vec::new();
793
794 let tokens: Vec<&str> = text.split_whitespace().collect();
796
797 if self.use_pos_tags {
799 for token in &tokens {
800 let pos_tag = self.simple_pos_tag(token);
801 features.push(format!("POS_{}", pos_tag));
802 }
803 }
804
805 if self.use_dependencies {
807 for i in 0..tokens.len() {
808 if i > 0 {
809 features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
810 }
811 }
812 }
813
814 for n in 1..=3 {
816 for i in 0..=tokens.len().saturating_sub(n) {
817 if i + n <= tokens.len() {
818 let ngram = tokens[i..i + n].join("_");
819 features.push(format!("NGRAM_{}", ngram));
820 }
821 }
822 }
823
824 features
825 }
826
827 fn simple_pos_tag(&self, token: &str) -> String {
828 let token_lower = token.to_lowercase();
830
831 if token_lower.ends_with("ing") {
832 "VBG".to_string()
833 } else if token_lower.ends_with("ed") {
834 "VBD".to_string()
835 } else if token_lower.ends_with("ly") {
836 "RB".to_string()
837 } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
838 "NNS".to_string()
839 } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
840 "NNP".to_string()
841 } else if token.chars().all(|c| c.is_alphabetic()) {
842 "NN".to_string()
843 } else if token.chars().all(|c| c.is_numeric()) {
844 "CD".to_string()
845 } else {
846 "UNK".to_string()
847 }
848 }
849}
850
851impl Transform<Vec<String>, Array2<f64>> for FittedSyntacticKernelApproximation {
852 fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
853 let n_documents = documents.len();
854 let vocab_size = self.feature_vocabulary.len();
855
856 let mut feature_matrix = Array2::zeros((n_documents, vocab_size));
858
859 for (doc_idx, doc) in documents.iter().enumerate() {
860 let features = self.extract_syntactic_features(doc);
861 let mut feature_counts = HashMap::new();
862
863 for feature in features {
864 *feature_counts.entry(feature).or_insert(0) += 1;
865 }
866
867 for (feature, count) in feature_counts {
868 if let Some(&vocab_idx) = self.feature_vocabulary.get(&feature) {
869 feature_matrix[[doc_idx, vocab_idx]] = count as f64;
870 }
871 }
872 }
873
874 let mut result = Array2::zeros((n_documents, self.n_components));
876
877 for i in 0..n_documents {
878 for j in 0..self.n_components {
879 let mut dot_product = 0.0;
880 for k in 0..vocab_size {
881 dot_product += feature_matrix[[i, k]] * self.random_weights[[j, k]];
882 }
883 result[[i, j]] = dot_product.tanh();
884 }
885 }
886
887 Ok(result)
888 }
889}
890
/// Configuration for a document-level kernel approximation built from
/// readability, stylometric, and hash-based pseudo-topic features.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Include hash-derived pseudo-topic features.
    pub use_topic_features: bool,
    /// Include readability statistics (sentence/word lengths and counts).
    pub use_readability_features: bool,
    /// Include stylometric ratios (punctuation, case, digits, type-token).
    pub use_stylometric_features: bool,
    /// Number of pseudo-topic features.
    pub n_topics: usize,
}
906
907impl DocumentKernelApproximation {
908 pub fn new(n_components: usize) -> Self {
909 Self {
910 n_components,
911 use_topic_features: true,
912 use_readability_features: true,
913 use_stylometric_features: true,
914 n_topics: 10,
915 }
916 }
917
918 pub fn use_topic_features(mut self, use_topics: bool) -> Self {
919 self.use_topic_features = use_topics;
920 self
921 }
922
923 pub fn use_readability_features(mut self, use_readability: bool) -> Self {
924 self.use_readability_features = use_readability;
925 self
926 }
927
928 pub fn use_stylometric_features(mut self, use_stylometric: bool) -> Self {
929 self.use_stylometric_features = use_stylometric;
930 self
931 }
932
933 pub fn n_topics(mut self, n_topics: usize) -> Self {
934 self.n_topics = n_topics;
935 self
936 }
937
938 fn extract_document_features(&self, text: &str) -> Vec<f64> {
939 let mut features = Vec::new();
940
941 let sentences: Vec<&str> = text.split(&['.', '!', '?'][..]).collect();
942 let words: Vec<&str> = text.split_whitespace().collect();
943 let characters: Vec<char> = text.chars().collect();
944
945 if self.use_readability_features {
946 let avg_sentence_length = if !sentences.is_empty() {
948 words.len() as f64 / sentences.len() as f64
949 } else {
950 0.0
951 };
952
953 let avg_word_length = if !words.is_empty() {
954 characters.len() as f64 / words.len() as f64
955 } else {
956 0.0
957 };
958
959 features.push(avg_sentence_length);
960 features.push(avg_word_length);
961 features.push(sentences.len() as f64);
962 features.push(words.len() as f64);
963 }
964
965 if self.use_stylometric_features {
966 let punctuation_count = characters
968 .iter()
969 .filter(|c| c.is_ascii_punctuation())
970 .count();
971 let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
972 let digit_count = characters.iter().filter(|c| c.is_numeric()).count();
973
974 features.push(punctuation_count as f64 / characters.len() as f64);
975 features.push(uppercase_count as f64 / characters.len() as f64);
976 features.push(digit_count as f64 / characters.len() as f64);
977
978 let unique_words: HashSet<&str> = words.iter().cloned().collect();
980 let ttr = if !words.is_empty() {
981 unique_words.len() as f64 / words.len() as f64
982 } else {
983 0.0
984 };
985 features.push(ttr);
986 }
987
988 if self.use_topic_features {
989 let mut topic_features = vec![0.0; self.n_topics];
991 let mut hasher = DefaultHasher::new();
992 text.hash(&mut hasher);
993 let hash = hasher.finish();
994
995 for i in 0..self.n_topics {
996 topic_features[i] = ((hash + i as u64) % 1000) as f64 / 1000.0;
997 }
998
999 features.extend(topic_features);
1000 }
1001
1002 features
1003 }
1004}
1005
/// Fitted state of [`DocumentKernelApproximation`]: the feature
/// dimensionality and random projection weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedDocumentKernelApproximation {
    /// Output dimensionality of the random projection.
    pub n_components: usize,
    /// Whether pseudo-topic features are extracted at transform time.
    pub use_topic_features: bool,
    /// Whether readability features are extracted at transform time.
    pub use_readability_features: bool,
    /// Whether stylometric features are extracted at transform time.
    pub use_stylometric_features: bool,
    /// Number of pseudo-topic features.
    pub n_topics: usize,
    /// Total length of the extracted feature vector.
    pub feature_dim: usize,
    /// Gaussian projection matrix of shape (n_components, feature_dim).
    pub random_weights: Array2<f64>,
}
1024
1025impl Fit<Vec<String>, ()> for DocumentKernelApproximation {
1026 type Fitted = FittedDocumentKernelApproximation;
1027
1028 fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
1029 let sample_features = self.extract_document_features(&documents[0]);
1031 let feature_dim = sample_features.len();
1032
1033 let mut rng = RealStdRng::from_seed(thread_rng().random());
1035 let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
1036
1037 let mut random_weights = Array2::zeros((self.n_components, feature_dim));
1038 for i in 0..self.n_components {
1039 for j in 0..feature_dim {
1040 random_weights[[i, j]] = rng.sample(normal);
1041 }
1042 }
1043
1044 Ok(FittedDocumentKernelApproximation {
1045 n_components: self.n_components,
1046 use_topic_features: self.use_topic_features,
1047 use_readability_features: self.use_readability_features,
1048 use_stylometric_features: self.use_stylometric_features,
1049 n_topics: self.n_topics,
1050 feature_dim,
1051 random_weights,
1052 })
1053 }
1054}
1055
1056impl FittedDocumentKernelApproximation {
1057 fn extract_document_features(&self, text: &str) -> Vec<f64> {
1058 let mut features = Vec::new();
1059
1060 let sentences: Vec<&str> = text.split(&['.', '!', '?'][..]).collect();
1061 let words: Vec<&str> = text.split_whitespace().collect();
1062 let characters: Vec<char> = text.chars().collect();
1063
1064 if self.use_readability_features {
1065 let avg_sentence_length = if !sentences.is_empty() {
1067 words.len() as f64 / sentences.len() as f64
1068 } else {
1069 0.0
1070 };
1071
1072 let avg_word_length = if !words.is_empty() {
1073 characters.len() as f64 / words.len() as f64
1074 } else {
1075 0.0
1076 };
1077
1078 features.push(avg_sentence_length);
1079 features.push(avg_word_length);
1080 features.push(sentences.len() as f64);
1081 features.push(words.len() as f64);
1082 }
1083
1084 if self.use_stylometric_features {
1085 let punctuation_count = characters
1087 .iter()
1088 .filter(|c| c.is_ascii_punctuation())
1089 .count();
1090 let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
1091 let digit_count = characters.iter().filter(|c| c.is_numeric()).count();
1092
1093 features.push(punctuation_count as f64 / characters.len() as f64);
1094 features.push(uppercase_count as f64 / characters.len() as f64);
1095 features.push(digit_count as f64 / characters.len() as f64);
1096
1097 let unique_words: HashSet<&str> = words.iter().cloned().collect();
1099 let ttr = if !words.is_empty() {
1100 unique_words.len() as f64 / words.len() as f64
1101 } else {
1102 0.0
1103 };
1104 features.push(ttr);
1105 }
1106
1107 if self.use_topic_features {
1108 let mut topic_features = vec![0.0; self.n_topics];
1110 let mut hasher = DefaultHasher::new();
1111 text.hash(&mut hasher);
1112 let hash = hasher.finish();
1113
1114 for i in 0..self.n_topics {
1115 topic_features[i] = ((hash + i as u64) % 1000) as f64 / 1000.0;
1116 }
1117
1118 features.extend(topic_features);
1119 }
1120
1121 features
1122 }
1123}
1124
1125impl Transform<Vec<String>, Array2<f64>> for FittedDocumentKernelApproximation {
1126 fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
1127 let n_documents = documents.len();
1128 let mut result = Array2::zeros((n_documents, self.n_components));
1129
1130 for (doc_idx, doc) in documents.iter().enumerate() {
1131 let features = self.extract_document_features(doc);
1132 let feature_array = Array1::from_vec(features);
1133
1134 for i in 0..self.n_components {
1135 let projected = self.random_weights.row(i).dot(&feature_array);
1136 result[[doc_idx, i]] = projected.tanh();
1137 }
1138 }
1139
1140 Ok(result)
1141 }
1142}
1143
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: fitting a TF-IDF text kernel on a tiny corpus and
    // transforming it yields a (n_documents, n_components) matrix.
    #[test]
    fn test_text_kernel_approximation() {
        let docs = vec![
            "This is a test document".to_string(),
            "Another test document here".to_string(),
            "Third document for testing".to_string(),
        ];

        let text_kernel = TextKernelApproximation::new(50);
        let fitted = text_kernel
            .fit(&docs, &())
            .expect("operation should succeed");
        let transformed = fitted.transform(&docs).expect("operation should succeed");

        assert_eq!(transformed.shape()[0], 3);
        assert_eq!(transformed.shape()[1], 50);
    }

    // Smoke test for the random-embedding semantic kernel: output shape
    // must be (n_documents, n_components).
    #[test]
    fn test_semantic_kernel_approximation() {
        let docs = vec![
            "Semantic similarity test".to_string(),
            "Another semantic test".to_string(),
        ];

        let semantic_kernel = SemanticKernelApproximation::new(30, 100);
        let fitted = semantic_kernel
            .fit(&docs, &())
            .expect("operation should succeed");
        let transformed = fitted.transform(&docs).expect("operation should succeed");

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 30);
    }

    // Smoke test for the heuristic syntactic-feature kernel.
    #[test]
    fn test_syntactic_kernel_approximation() {
        let docs = vec![
            "The cat sat on the mat".to_string(),
            "Dogs are running quickly".to_string(),
        ];

        let syntactic_kernel = SyntacticKernelApproximation::new(40);
        let fitted = syntactic_kernel
            .fit(&docs, &())
            .expect("operation should succeed");
        let transformed = fitted.transform(&docs).expect("operation should succeed");

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 40);
    }

    // Smoke test for the document-level feature kernel (readability,
    // stylometric, and pseudo-topic features).
    #[test]
    fn test_document_kernel_approximation() {
        let docs = vec![
            "This is a long document with multiple sentences. It contains various features."
                .to_string(),
            "Short doc.".to_string(),
        ];

        let doc_kernel = DocumentKernelApproximation::new(25);
        let fitted = doc_kernel
            .fit(&docs, &())
            .expect("operation should succeed");
        let transformed = fitted.transform(&docs).expect("operation should succeed");

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 25);
    }
}