use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::random::essentials::Normal as RandNormal;
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use sklears_core::error::Result;
use sklears_core::traits::{Fit, Transform};

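/// Bag-of-n-grams kernel approximation for text.
///
/// Documents are tokenized into word n-grams, weighted with (optionally
/// sublinear) TF-IDF, and projected onto `n_components` Gaussian random
/// directions, giving a fixed-size feature map that approximates the linear
/// (dot-product) kernel over the TF-IDF representation.
///
/// A rough usage sketch (the import path is an assumption; the call pattern
/// mirrors the unit tests at the bottom of this file):
///
/// ```ignore
/// use your_crate::TextKernelApproximation; // path assumed for illustration
/// use sklears_core::traits::{Fit, Transform};
///
/// let docs = vec!["first document".to_string(), "second document".to_string()];
/// let fitted = TextKernelApproximation::new(64)
///     .ngram_range((1, 2))
///     .max_features(5_000)
///     .fit(&docs, &())?;
/// let features = fitted.transform(&docs)?; // shape: (n_documents, 64)
/// ```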
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextKernelApproximation {
    pub n_components: usize,
    pub max_features: usize,
    pub ngram_range: (usize, usize),
    pub min_df: usize,
    pub max_df: f64,
    pub use_tf_idf: bool,
    pub use_hashing: bool,
    pub sublinear_tf: bool,
}

impl TextKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_features: 10000,
            ngram_range: (1, 1),
            min_df: 1,
            max_df: 1.0,
            use_tf_idf: true,
            use_hashing: false,
            sublinear_tf: false,
        }
    }

    pub fn max_features(mut self, max_features: usize) -> Self {
        self.max_features = max_features;
        self
    }

    pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
        self.ngram_range = ngram_range;
        self
    }

    pub fn min_df(mut self, min_df: usize) -> Self {
        self.min_df = min_df;
        self
    }

    pub fn max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    pub fn use_tf_idf(mut self, use_tf_idf: bool) -> Self {
        self.use_tf_idf = use_tf_idf;
        self
    }

    pub fn use_hashing(mut self, use_hashing: bool) -> Self {
        self.use_hashing = use_hashing;
        self
    }

    pub fn sublinear_tf(mut self, sublinear_tf: bool) -> Self {
        self.sublinear_tf = sublinear_tf;
        self
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect()
    }

    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
        let mut ngrams = Vec::new();

        for n in self.ngram_range.0..=self.ngram_range.1 {
            for i in 0..=tokens.len().saturating_sub(n) {
                if i + n <= tokens.len() {
                    let ngram = tokens[i..i + n].join(" ");
                    ngrams.push(ngram);
                }
            }
        }

        ngrams
    }

    fn hash_feature(&self, feature: &str) -> usize {
        let mut hasher = DefaultHasher::new();
        feature.hash(&mut hasher);
        hasher.finish() as usize % self.max_features
    }
}

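/// Fitted state of [`TextKernelApproximation`]: the filtered n-gram
/// vocabulary, per-term IDF weights, and the Gaussian random projection
/// matrix learned from the training documents.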
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedTextKernelApproximation {
    pub n_components: usize,
    pub max_features: usize,
    pub ngram_range: (usize, usize),
    pub min_df: usize,
    pub max_df: f64,
    pub use_tf_idf: bool,
    pub use_hashing: bool,
    pub sublinear_tf: bool,
    pub vocabulary: HashMap<String, usize>,
    pub idf_values: Array1<f64>,
    pub random_weights: Array2<f64>,
}

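// Fitting builds the n-gram vocabulary, filters it by document frequency
// (`min_df`, `max_df`, `max_features`), computes smoothed IDF weights, and
// draws an (n_components x vocab_size) Gaussian random projection matrix.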
impl Fit<Vec<String>, ()> for TextKernelApproximation {
    type Fitted = FittedTextKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut document_frequencies = HashMap::new();
        let n_documents = documents.len();

        for doc in documents {
            let tokens = self.tokenize(doc);
            let ngrams = self.extract_ngrams(&tokens);
            let unique_ngrams: HashSet<String> = ngrams.into_iter().collect();

            for ngram in unique_ngrams {
                *document_frequencies.entry(ngram).or_insert(0) += 1;
            }
        }

        let mut filtered_vocabulary = HashMap::new();
        for (term, &df) in &document_frequencies {
            let df_ratio = df as f64 / n_documents as f64;
            if df >= self.min_df && df_ratio <= self.max_df {
                filtered_vocabulary.insert(term.clone(), filtered_vocabulary.len());
            }
        }

        if filtered_vocabulary.len() > self.max_features {
            let mut sorted_vocab: Vec<_> = filtered_vocabulary.iter().collect();
            sorted_vocab.sort_by(|a, b| {
                document_frequencies[a.0]
                    .cmp(&document_frequencies[b.0])
                    .reverse()
                    .then_with(|| a.0.cmp(b.0))
            });

            let mut new_vocabulary = HashMap::new();
            for (term, _) in sorted_vocab.iter().take(self.max_features) {
                new_vocabulary.insert(term.to_string(), new_vocabulary.len());
            }
            filtered_vocabulary = new_vocabulary;
        }

        let vocab_size = filtered_vocabulary.len();
        let mut idf_values = Array1::zeros(vocab_size);

        if self.use_tf_idf {
            for (term, &idx) in &filtered_vocabulary {
                let df = document_frequencies.get(term).unwrap_or(&0);
                idf_values[idx] = (n_documents as f64 / (*df as f64 + 1.0)).ln() + 1.0;
            }
        } else {
            idf_values.fill(1.0);
        }

        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let mut random_weights = Array2::zeros((self.n_components, vocab_size));
        for i in 0..self.n_components {
            for j in 0..vocab_size {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedTextKernelApproximation {
            n_components: self.n_components,
            max_features: self.max_features,
            ngram_range: self.ngram_range,
            min_df: self.min_df,
            max_df: self.max_df,
            use_tf_idf: self.use_tf_idf,
            use_hashing: self.use_hashing,
            sublinear_tf: self.sublinear_tf,
            vocabulary: filtered_vocabulary,
            idf_values,
            random_weights,
        })
    }
}

impl FittedTextKernelApproximation {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .collect()
    }

    fn extract_ngrams(&self, tokens: &[String]) -> Vec<String> {
        let mut ngrams = Vec::new();

        for n in self.ngram_range.0..=self.ngram_range.1 {
            for i in 0..=tokens.len().saturating_sub(n) {
                if i + n <= tokens.len() {
                    let ngram = tokens[i..i + n].join(" ");
                    ngrams.push(ngram);
                }
            }
        }

        ngrams
    }
}

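// Transformation computes a TF-IDF matrix over the fitted vocabulary and
// multiplies it by the random projection matrix, yielding one
// `n_components`-dimensional row per document.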
impl Transform<Vec<String>, Array2<f64>> for FittedTextKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let vocab_size = self.vocabulary.len();

        let mut tf_idf_matrix = Array2::zeros((n_documents, vocab_size));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let tokens = self.tokenize(doc);
            let ngrams = self.extract_ngrams(&tokens);

            let mut term_counts = HashMap::new();
            for ngram in ngrams {
                *term_counts.entry(ngram).or_insert(0) += 1;
            }

            for (term, &count) in &term_counts {
                if let Some(&vocab_idx) = self.vocabulary.get(term) {
                    let tf = if self.sublinear_tf {
                        1.0 + (count as f64).ln()
                    } else {
                        count as f64
                    };

                    let tf_idf = tf * self.idf_values[vocab_idx];
                    tf_idf_matrix[[doc_idx, vocab_idx]] = tf_idf;
                }
            }
        }

        let mut result = Array2::zeros((n_documents, self.n_components));

        for i in 0..n_documents {
            for j in 0..self.n_components {
                let mut dot_product = 0.0;
                for k in 0..vocab_size {
                    dot_product += tf_idf_matrix[[i, k]] * self.random_weights[[j, k]];
                }
                result[[i, j]] = dot_product;
            }
        }

        Ok(result)
    }
}

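/// Kernel approximation based on (randomly initialized) word embeddings.
///
/// Each document is mapped to a document embedding by aggregating its token
/// vectors (mean, max, sum, or attention-weighted) and then projected onto
/// `n_components` random directions with a `tanh` nonlinearity.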
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticKernelApproximation {
    pub n_components: usize,
    pub embedding_dim: usize,
    pub similarity_measure: SimilarityMeasure,
    pub aggregation_method: AggregationMethod,
    pub use_attention: bool,
}

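/// Similarity measure used when comparing embedding vectors.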
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SimilarityMeasure {
    Cosine,
    Euclidean,
    Manhattan,
    Dot,
}

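/// Strategy for aggregating per-token embeddings into a document embedding.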
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    Mean,
    Max,
    Sum,
    AttentionWeighted,
}

impl SemanticKernelApproximation {
    pub fn new(n_components: usize, embedding_dim: usize) -> Self {
        Self {
            n_components,
            embedding_dim,
            similarity_measure: SimilarityMeasure::Cosine,
            aggregation_method: AggregationMethod::Mean,
            use_attention: false,
        }
    }

    pub fn similarity_measure(mut self, measure: SimilarityMeasure) -> Self {
        self.similarity_measure = measure;
        self
    }

    pub fn aggregation_method(mut self, method: AggregationMethod) -> Self {
        self.aggregation_method = method;
        self
    }

    pub fn use_attention(mut self, use_attention: bool) -> Self {
        self.use_attention = use_attention;
        self
    }

    fn compute_similarity(&self, vec1: &ArrayView1<f64>, vec2: &ArrayView1<f64>) -> f64 {
        match self.similarity_measure {
            SimilarityMeasure::Cosine => {
                let dot = vec1.dot(vec2);
                let norm1 = vec1.dot(vec1).sqrt();
                let norm2 = vec2.dot(vec2).sqrt();
                if norm1 > 0.0 && norm2 > 0.0 {
                    dot / (norm1 * norm2)
                } else {
                    0.0
                }
            }
            SimilarityMeasure::Euclidean => {
                let diff = vec1 - vec2;
                -diff.dot(&diff).sqrt()
            }
            SimilarityMeasure::Manhattan => {
                let diff = vec1 - vec2;
                -diff.mapv(|x| x.abs()).sum()
            }
            SimilarityMeasure::Dot => vec1.dot(vec2),
        }
    }

    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            AggregationMethod::Mean => embeddings.mean_axis(Axis(0)).unwrap(),
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}

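/// Fitted state of [`SemanticKernelApproximation`]: the per-word embedding
/// table and the random projection matrix.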
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSemanticKernelApproximation {
    pub n_components: usize,
    pub embedding_dim: usize,
    pub similarity_measure: SimilarityMeasure,
    pub aggregation_method: AggregationMethod,
    pub use_attention: bool,
    pub word_embeddings: HashMap<String, Array1<f64>>,
    pub projection_matrix: Array2<f64>,
}

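// Note: word embeddings are drawn at random here rather than loaded from a
// pretrained model, so the "semantic" structure is only as informative as that
// random initialization; the projection matrix is sampled the same way.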
impl Fit<Vec<String>, ()> for SemanticKernelApproximation {
    type Fitted = FittedSemanticKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let mut word_embeddings = HashMap::new();
        let mut vocabulary = HashSet::new();

        for doc in documents {
            let tokens: Vec<String> = doc
                .to_lowercase()
                .split_whitespace()
                .map(|s| s.to_string())
                .collect();

            for token in tokens {
                vocabulary.insert(token);
            }
        }

        for word in vocabulary {
            let embedding = Array1::from_vec(
                (0..self.embedding_dim)
                    .map(|_| rng.sample(normal))
                    .collect(),
            );
            word_embeddings.insert(word, embedding);
        }

        let mut projection_matrix = Array2::zeros((self.n_components, self.embedding_dim));
        for i in 0..self.n_components {
            for j in 0..self.embedding_dim {
                projection_matrix[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedSemanticKernelApproximation {
            n_components: self.n_components,
            embedding_dim: self.embedding_dim,
            similarity_measure: self.similarity_measure,
            aggregation_method: self.aggregation_method,
            use_attention: self.use_attention,
            word_embeddings,
            projection_matrix,
        })
    }
}

impl FittedSemanticKernelApproximation {
    fn aggregate_embeddings(&self, embeddings: &Array2<f64>) -> Array1<f64> {
        match self.aggregation_method {
            AggregationMethod::Mean => embeddings.mean_axis(Axis(0)).unwrap(),
            AggregationMethod::Max => {
                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..embeddings.ncols() {
                    let col = embeddings.column(i);
                    result[i] = col.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                }
                result
            }
            AggregationMethod::Sum => embeddings.sum_axis(Axis(0)),
            AggregationMethod::AttentionWeighted => {
                let n_tokens = embeddings.nrows();
                let mut attention_weights = Array1::zeros(n_tokens);

                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    attention_weights[i] = token_embedding.dot(&token_embedding).sqrt();
                }

                let max_weight = attention_weights.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                attention_weights.mapv_inplace(|x| (x - max_weight).exp());
                let sum_weights = attention_weights.sum();
                if sum_weights > 0.0 {
                    attention_weights /= sum_weights;
                }

                let mut result = Array1::zeros(embeddings.ncols());
                for i in 0..n_tokens {
                    let token_embedding = embeddings.row(i);
                    for j in 0..embeddings.ncols() {
                        result[j] += attention_weights[i] * token_embedding[j];
                    }
                }
                result
            }
        }
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedSemanticKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let mut result = Array2::zeros((n_documents, self.n_components));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let tokens: Vec<String> = doc
                .to_lowercase()
                .split_whitespace()
                .map(|s| s.to_string())
                .collect();

            let mut token_embeddings = Vec::new();
            for token in tokens {
                if let Some(embedding) = self.word_embeddings.get(&token) {
                    token_embeddings.push(embedding.clone());
                }
            }

            if !token_embeddings.is_empty() {
                let embeddings_matrix = Array2::from_shape_vec(
                    (token_embeddings.len(), self.embedding_dim),
                    token_embeddings
                        .iter()
                        .flat_map(|e| e.iter().cloned())
                        .collect(),
                )?;

                let doc_embedding = self.aggregate_embeddings(&embeddings_matrix);

                for i in 0..self.n_components {
                    let projected = self.projection_matrix.row(i).dot(&doc_embedding);
                    result[[doc_idx, i]] = projected.tanh();
                }
            }
        }

        Ok(result)
    }
}

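/// Kernel approximation over shallow syntactic features (heuristic POS tags,
/// adjacent-token pairs as proxy dependencies, and token n-grams), followed by
/// a random projection with a `tanh` nonlinearity.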
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyntacticKernelApproximation {
    pub n_components: usize,
    pub max_tree_depth: usize,
    pub use_pos_tags: bool,
    pub use_dependencies: bool,
    pub tree_kernel_type: TreeKernelType,
}

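/// Tree kernel variant the approximation is configured with.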
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeKernelType {
    Subset,
    Subsequence,
    Partial,
}

impl SyntacticKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_tree_depth: 10,
            use_pos_tags: true,
            use_dependencies: true,
            tree_kernel_type: TreeKernelType::Subset,
        }
    }

    pub fn max_tree_depth(mut self, depth: usize) -> Self {
        self.max_tree_depth = depth;
        self
    }

    pub fn use_pos_tags(mut self, use_pos: bool) -> Self {
        self.use_pos_tags = use_pos;
        self
    }

    pub fn use_dependencies(mut self, use_deps: bool) -> Self {
        self.use_dependencies = use_deps;
        self
    }

    pub fn tree_kernel_type(mut self, kernel_type: TreeKernelType) -> Self {
        self.tree_kernel_type = kernel_type;
        self
    }

    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
        let mut features = Vec::new();

        let tokens: Vec<&str> = text.split_whitespace().collect();

        if self.use_pos_tags {
            for token in &tokens {
                let pos_tag = self.simple_pos_tag(token);
                features.push(format!("POS_{}", pos_tag));
            }
        }

        if self.use_dependencies {
            for i in 0..tokens.len() {
                if i > 0 {
                    features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
                }
            }
        }

        for n in 1..=3 {
            for i in 0..=tokens.len().saturating_sub(n) {
                if i + n <= tokens.len() {
                    let ngram = tokens[i..i + n].join("_");
                    features.push(format!("NGRAM_{}", ngram));
                }
            }
        }

        features
    }

    fn simple_pos_tag(&self, token: &str) -> String {
        let token_lower = token.to_lowercase();

        if token_lower.ends_with("ing") {
            "VBG".to_string()
        } else if token_lower.ends_with("ed") {
            "VBD".to_string()
        } else if token_lower.ends_with("ly") {
            "RB".to_string()
        } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
            "NNS".to_string()
        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
            "NNP".to_string()
        } else if token.chars().all(|c| c.is_alphabetic()) {
            "NN".to_string()
        } else if token.chars().all(|c| c.is_numeric()) {
            "CD".to_string()
        } else {
            "UNK".to_string()
        }
    }
}

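/// Fitted state of [`SyntacticKernelApproximation`]: the syntactic feature
/// vocabulary and the random projection matrix.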
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedSyntacticKernelApproximation {
    pub n_components: usize,
    pub max_tree_depth: usize,
    pub use_pos_tags: bool,
    pub use_dependencies: bool,
    pub tree_kernel_type: TreeKernelType,
    pub feature_vocabulary: HashMap<String, usize>,
    pub random_weights: Array2<f64>,
}

impl Fit<Vec<String>, ()> for SyntacticKernelApproximation {
    type Fitted = FittedSyntacticKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let mut feature_vocabulary = HashMap::new();

        for doc in documents {
            let features = self.extract_syntactic_features(doc);
            for feature in features {
                if !feature_vocabulary.contains_key(&feature) {
                    feature_vocabulary.insert(feature, feature_vocabulary.len());
                }
            }
        }

        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let vocab_size = feature_vocabulary.len();
        let mut random_weights = Array2::zeros((self.n_components, vocab_size));

        for i in 0..self.n_components {
            for j in 0..vocab_size {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedSyntacticKernelApproximation {
            n_components: self.n_components,
            max_tree_depth: self.max_tree_depth,
            use_pos_tags: self.use_pos_tags,
            use_dependencies: self.use_dependencies,
            tree_kernel_type: self.tree_kernel_type,
            feature_vocabulary,
            random_weights,
        })
    }
}

impl FittedSyntacticKernelApproximation {
    fn extract_syntactic_features(&self, text: &str) -> Vec<String> {
        let mut features = Vec::new();

        let tokens: Vec<&str> = text.split_whitespace().collect();

        if self.use_pos_tags {
            for token in &tokens {
                let pos_tag = self.simple_pos_tag(token);
                features.push(format!("POS_{}", pos_tag));
            }
        }

        if self.use_dependencies {
            for i in 0..tokens.len() {
                if i > 0 {
                    features.push(format!("DEP_{}_{}", tokens[i - 1], tokens[i]));
                }
            }
        }

        for n in 1..=3 {
            for i in 0..=tokens.len().saturating_sub(n) {
                if i + n <= tokens.len() {
                    let ngram = tokens[i..i + n].join("_");
                    features.push(format!("NGRAM_{}", ngram));
                }
            }
        }

        features
    }

    fn simple_pos_tag(&self, token: &str) -> String {
        let token_lower = token.to_lowercase();

        if token_lower.ends_with("ing") {
            "VBG".to_string()
        } else if token_lower.ends_with("ed") {
            "VBD".to_string()
        } else if token_lower.ends_with("ly") {
            "RB".to_string()
        } else if token_lower.ends_with("s") && !token_lower.ends_with("ss") {
            "NNS".to_string()
        } else if token.chars().all(|c| c.is_alphabetic() && c.is_uppercase()) {
            "NNP".to_string()
        } else if token.chars().all(|c| c.is_alphabetic()) {
            "NN".to_string()
        } else if token.chars().all(|c| c.is_numeric()) {
            "CD".to_string()
        } else {
            "UNK".to_string()
        }
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedSyntacticKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let vocab_size = self.feature_vocabulary.len();

        let mut feature_matrix = Array2::zeros((n_documents, vocab_size));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let features = self.extract_syntactic_features(doc);
            let mut feature_counts = HashMap::new();

            for feature in features {
                *feature_counts.entry(feature).or_insert(0) += 1;
            }

            for (feature, count) in feature_counts {
                if let Some(&vocab_idx) = self.feature_vocabulary.get(&feature) {
                    feature_matrix[[doc_idx, vocab_idx]] = count as f64;
                }
            }
        }

        let mut result = Array2::zeros((n_documents, self.n_components));

        for i in 0..n_documents {
            for j in 0..self.n_components {
                let mut dot_product = 0.0;
                for k in 0..vocab_size {
                    dot_product += feature_matrix[[i, k]] * self.random_weights[[j, k]];
                }
                result[[i, j]] = dot_product.tanh();
            }
        }

        Ok(result)
    }
}

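/// Kernel approximation over document-level statistics: readability measures
/// (average sentence/word length, sentence and word counts), stylometric
/// ratios (punctuation, uppercase, digits, type-token ratio), and hash-based
/// pseudo-topic features, projected onto `n_components` random directions
/// with a `tanh` nonlinearity.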
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentKernelApproximation {
    pub n_components: usize,
    pub use_topic_features: bool,
    pub use_readability_features: bool,
    pub use_stylometric_features: bool,
    pub n_topics: usize,
}

impl DocumentKernelApproximation {
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            use_topic_features: true,
            use_readability_features: true,
            use_stylometric_features: true,
            n_topics: 10,
        }
    }

    pub fn use_topic_features(mut self, use_topics: bool) -> Self {
        self.use_topic_features = use_topics;
        self
    }

    pub fn use_readability_features(mut self, use_readability: bool) -> Self {
        self.use_readability_features = use_readability;
        self
    }

    pub fn use_stylometric_features(mut self, use_stylometric: bool) -> Self {
        self.use_stylometric_features = use_stylometric;
        self
    }

    pub fn n_topics(mut self, n_topics: usize) -> Self {
        self.n_topics = n_topics;
        self
    }

    fn extract_document_features(&self, text: &str) -> Vec<f64> {
        let mut features = Vec::new();

        let sentences: Vec<&str> = text
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();
        let words: Vec<&str> = text.split_whitespace().collect();
        let characters: Vec<char> = text.chars().collect();

        if self.use_readability_features {
            let avg_sentence_length = if !sentences.is_empty() {
                words.len() as f64 / sentences.len() as f64
            } else {
                0.0
            };

            let avg_word_length = if !words.is_empty() {
                characters.len() as f64 / words.len() as f64
            } else {
                0.0
            };

            features.push(avg_sentence_length);
            features.push(avg_word_length);
            features.push(sentences.len() as f64);
            features.push(words.len() as f64);
        }

        if self.use_stylometric_features {
            let n_chars = characters.len().max(1) as f64;
            let punctuation_count = characters
                .iter()
                .filter(|c| c.is_ascii_punctuation())
                .count();
            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();

            features.push(punctuation_count as f64 / n_chars);
            features.push(uppercase_count as f64 / n_chars);
            features.push(digit_count as f64 / n_chars);

            let unique_words: HashSet<&str> = words.iter().cloned().collect();
            let ttr = if !words.is_empty() {
                unique_words.len() as f64 / words.len() as f64
            } else {
                0.0
            };
            features.push(ttr);
        }

        if self.use_topic_features {
            let mut topic_features = vec![0.0; self.n_topics];
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            for i in 0..self.n_topics {
                topic_features[i] = (hash.wrapping_add(i as u64) % 1000) as f64 / 1000.0;
            }

            features.extend(topic_features);
        }

        features
    }
}

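/// Fitted state of [`DocumentKernelApproximation`]: the document-feature
/// dimensionality and the random projection matrix.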
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedDocumentKernelApproximation {
    pub n_components: usize,
    pub use_topic_features: bool,
    pub use_readability_features: bool,
    pub use_stylometric_features: bool,
    pub n_topics: usize,
    pub feature_dim: usize,
    pub random_weights: Array2<f64>,
}

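// The feature dimensionality is inferred from the first document, so fitting
// assumes a non-empty document collection (indexing `documents[0]` panics
// otherwise).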
impl Fit<Vec<String>, ()> for DocumentKernelApproximation {
    type Fitted = FittedDocumentKernelApproximation;

    fn fit(self, documents: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
        let sample_features = self.extract_document_features(&documents[0]);
        let feature_dim = sample_features.len();

        let mut rng = RealStdRng::from_seed(thread_rng().gen());
        let normal = RandNormal::new(0.0, 1.0).unwrap();

        let mut random_weights = Array2::zeros((self.n_components, feature_dim));
        for i in 0..self.n_components {
            for j in 0..feature_dim {
                random_weights[[i, j]] = rng.sample(normal);
            }
        }

        Ok(FittedDocumentKernelApproximation {
            n_components: self.n_components,
            use_topic_features: self.use_topic_features,
            use_readability_features: self.use_readability_features,
            use_stylometric_features: self.use_stylometric_features,
            n_topics: self.n_topics,
            feature_dim,
            random_weights,
        })
    }
}

impl FittedDocumentKernelApproximation {
    fn extract_document_features(&self, text: &str) -> Vec<f64> {
        let mut features = Vec::new();

        let sentences: Vec<&str> = text
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();
        let words: Vec<&str> = text.split_whitespace().collect();
        let characters: Vec<char> = text.chars().collect();

        if self.use_readability_features {
            let avg_sentence_length = if !sentences.is_empty() {
                words.len() as f64 / sentences.len() as f64
            } else {
                0.0
            };

            let avg_word_length = if !words.is_empty() {
                characters.len() as f64 / words.len() as f64
            } else {
                0.0
            };

            features.push(avg_sentence_length);
            features.push(avg_word_length);
            features.push(sentences.len() as f64);
            features.push(words.len() as f64);
        }

        if self.use_stylometric_features {
            let n_chars = characters.len().max(1) as f64;
            let punctuation_count = characters
                .iter()
                .filter(|c| c.is_ascii_punctuation())
                .count();
            let uppercase_count = characters.iter().filter(|c| c.is_uppercase()).count();
            let digit_count = characters.iter().filter(|c| c.is_numeric()).count();

            features.push(punctuation_count as f64 / n_chars);
            features.push(uppercase_count as f64 / n_chars);
            features.push(digit_count as f64 / n_chars);

            let unique_words: HashSet<&str> = words.iter().cloned().collect();
            let ttr = if !words.is_empty() {
                unique_words.len() as f64 / words.len() as f64
            } else {
                0.0
            };
            features.push(ttr);
        }

        if self.use_topic_features {
            let mut topic_features = vec![0.0; self.n_topics];
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            for i in 0..self.n_topics {
                topic_features[i] = (hash.wrapping_add(i as u64) % 1000) as f64 / 1000.0;
            }

            features.extend(topic_features);
        }

        features
    }
}

impl Transform<Vec<String>, Array2<f64>> for FittedDocumentKernelApproximation {
    fn transform(&self, documents: &Vec<String>) -> Result<Array2<f64>> {
        let n_documents = documents.len();
        let mut result = Array2::zeros((n_documents, self.n_components));

        for (doc_idx, doc) in documents.iter().enumerate() {
            let features = self.extract_document_features(doc);
            let feature_array = Array1::from_vec(features);

            for i in 0..self.n_components {
                let projected = self.random_weights.row(i).dot(&feature_array);
                result[[doc_idx, i]] = projected.tanh();
            }
        }

        Ok(result)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_kernel_approximation() {
        let docs = vec![
            "This is a test document".to_string(),
            "Another test document here".to_string(),
            "Third document for testing".to_string(),
        ];

        let text_kernel = TextKernelApproximation::new(50);
        let fitted = text_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 3);
        assert_eq!(transformed.shape()[1], 50);
    }

    #[test]
    fn test_semantic_kernel_approximation() {
        let docs = vec![
            "Semantic similarity test".to_string(),
            "Another semantic test".to_string(),
        ];

        let semantic_kernel = SemanticKernelApproximation::new(30, 100);
        let fitted = semantic_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 30);
    }

    #[test]
    fn test_syntactic_kernel_approximation() {
        let docs = vec![
            "The cat sat on the mat".to_string(),
            "Dogs are running quickly".to_string(),
        ];

        let syntactic_kernel = SyntacticKernelApproximation::new(40);
        let fitted = syntactic_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 40);
    }

    #[test]
    fn test_document_kernel_approximation() {
        let docs = vec![
            "This is a long document with multiple sentences. It contains various features."
                .to_string(),
            "Short doc.".to_string(),
        ];

        let doc_kernel = DocumentKernelApproximation::new(25);
        let fitted = doc_kernel.fit(&docs, &()).unwrap();
        let transformed = fitted.transform(&docs).unwrap();

        assert_eq!(transformed.shape()[0], 2);
        assert_eq!(transformed.shape()[1], 25);
    }
}