use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;

use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};

use crate::error::{ClusteringError, Result};
use crate::vq::euclidean_distance;
use statrs::statistics::Statistics;

/// Text representation formats accepted by the semantic clustering algorithms.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub enum TextRepresentation {
    /// TF-IDF weighted document vectors.
    TfIdf {
        /// Document vectors (n_documents x n_terms).
        vectors: Array2<f64>,
        /// Mapping from term to column index.
        vocabulary: HashMap<String, usize>,
    },
    /// Dense word-embedding vectors (e.g. averaged word vectors per document).
    WordEmbeddings {
        /// Document vectors (n_documents x embedding_dim).
        vectors: Array2<f64>,
        /// Dimensionality of the embedding space.
        embedding_dim: usize,
    },
    /// Contextual embeddings produced by a pretrained language model.
    ContextualEmbeddings {
        /// Document vectors (n_documents x embedding_dim).
        vectors: Array2<f64>,
        /// Name of the model that produced the embeddings.
        model_name: String,
    },
    /// Raw document-term count matrix.
    DocumentTerm {
        /// Count matrix (n_documents x n_terms).
        matrix: Array2<f64>,
        /// Vocabulary, indexed by column.
        vocabulary: Vec<String>,
    },
}

/// Similarity and distance measures for comparing text vectors.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub enum SemanticSimilarity {
    /// Cosine similarity (used as the distance 1 - similarity).
    Cosine,
    /// Jaccard similarity over non-zero features (used as 1 - similarity).
    Jaccard,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Pearson correlation (used as the distance 1 - |r|).
    Pearson,
    /// Jensen-Shannon distance between normalized distributions.
    JensenShannon,
    /// Hellinger distance between normalized distributions.
    Hellinger,
    /// Bhattacharyya distance between normalized distributions.
    Bhattacharyya,
}

/// Configuration shared by the semantic clustering algorithms.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub struct SemanticClusteringConfig {
    /// Similarity metric used to compare documents.
    pub similarity_metric: SemanticSimilarity,
    /// Number of clusters (None falls back to a default).
    pub n_clusters: Option<usize>,
    /// Similarity threshold for threshold-based methods.
    pub similarity_threshold: f64,
    /// Maximum number of iterations.
    pub max_iterations: usize,
    /// Convergence tolerance on the change in inertia.
    pub tolerance: f64,
    /// Whether to apply dimensionality reduction before clustering.
    pub use_dimension_reduction: bool,
    /// Target dimensionality when reduction is enabled.
    pub target_dimensions: Option<usize>,
    /// Text preprocessing options.
    pub preprocessing: TextPreprocessing,
}

impl Default for SemanticClusteringConfig {
    fn default() -> Self {
        Self {
            similarity_metric: SemanticSimilarity::Cosine,
            n_clusters: Some(10),
            similarity_threshold: 0.5,
            max_iterations: 100,
            tolerance: 1e-4,
            use_dimension_reduction: false,
            target_dimensions: None,
            preprocessing: TextPreprocessing::default(),
        }
    }
}

/// Preprocessing options applied to text vectors before clustering.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub struct TextPreprocessing {
    /// L2-normalize each document vector.
    pub normalize_vectors: bool,
    /// Drop features with (near-)zero variance.
    pub remove_zero_variance: bool,
    /// Apply TF-IDF weighting to the input matrix.
    pub apply_tfidf: bool,
    /// Minimum document frequency for a term to be kept.
    pub min_df: f64,
    /// Maximum document frequency for a term to be kept.
    pub max_df: f64,
    /// Optional cap on the number of features.
    pub max_features: Option<usize>,
}

impl Default for TextPreprocessing {
    fn default() -> Self {
        Self {
            normalize_vectors: true,
            remove_zero_variance: true,
            apply_tfidf: true,
            min_df: 0.01,
            max_df: 0.95,
            max_features: None,
        }
    }
}

/// K-means clustering over semantic text representations.
pub struct SemanticKMeans {
    config: SemanticClusteringConfig,
    centroids: Option<Array2<f64>>,
    labels: Option<Array1<usize>>,
    inertia: Option<f64>,
    n_iterations: Option<usize>,
}

impl SemanticKMeans {
    /// Creates a new, unfitted clusterer with the given configuration.
    pub fn new(config: SemanticClusteringConfig) -> Self {
        Self {
            config,
            centroids: None,
            labels: None,
            inertia: None,
            n_iterations: None,
        }
    }

    /// Fits the model to a text representation.
    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
        let vectors = self.extract_vectors(text_repr)?;
        let preprocessed = self.preprocess_vectors(vectors)?;

        let n_clusters = self.config.n_clusters.unwrap_or(10);
        self.fit_kmeans(preprocessed.view(), n_clusters)?;

        Ok(())
    }

    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
        match text_repr {
            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::WordEmbeddings { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::ContextualEmbeddings { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
        }
    }

    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
        let mut processed = vectors;

        if self.config.preprocessing.apply_tfidf {
            processed = self.apply_tfidf_weighting(processed)?;
        }

        if self.config.preprocessing.remove_zero_variance {
            processed = self.remove_zero_variance_features(processed)?;
        }

        if self.config.preprocessing.normalize_vectors {
            processed = self.normalize_vectors(processed)?;
        }

        if self.config.use_dimension_reduction {
            if let Some(target_dim) = self.config.target_dimensions {
                processed = self.reduce_dimensions(processed, target_dim)?;
            }
        }

        Ok(processed)
    }

    fn apply_tfidf_weighting(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
        let (n_docs, n_terms) = matrix.dim();
        let mut tfidf_matrix = matrix.clone();

        // Document frequency of each term.
        let mut df = Array1::zeros(n_terms);
        for term_idx in 0..n_terms {
            let mut doc_count = 0;
            for doc_idx in 0..n_docs {
                if matrix[[doc_idx, term_idx]] > 0.0 {
                    doc_count += 1;
                }
            }
            df[term_idx] = doc_count as f64;
        }

        // TF-IDF weight: tf * ln(n_docs / df).
        for doc_idx in 0..n_docs {
            for term_idx in 0..n_terms {
                let tf = matrix[[doc_idx, term_idx]];
                if tf > 0.0 && df[term_idx] > 0.0 {
                    let idf = (n_docs as f64 / df[term_idx]).ln();
                    tfidf_matrix[[doc_idx, term_idx]] = tf * idf;
                }
            }
        }

        Ok(tfidf_matrix)
    }

    fn remove_zero_variance_features(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
        let mut feature_mask = Vec::new();

        for col_idx in 0..matrix.ncols() {
            let column = matrix.column(col_idx);
            let mean = column.sum() / column.len() as f64;
            let variance =
                column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;

            // Keep only features whose variance is not numerically zero.
            feature_mask.push(variance > 1e-10);
        }

        let valid_features: Vec<usize> = feature_mask
            .iter()
            .enumerate()
            .filter_map(|(i, &keep)| if keep { Some(i) } else { None })
            .collect();

        if valid_features.is_empty() {
            return Err(ClusteringError::InvalidInput(
                "All features have zero variance".to_string(),
            ));
        }

        let filtered = matrix.select(Axis(1), &valid_features);
        Ok(filtered)
    }

    fn normalize_vectors(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
        let mut normalized = matrix.clone();

        for mut row in normalized.rows_mut() {
            let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
            if norm > 1e-10 {
                row.mapv_inplace(|x| x / norm);
            }
        }

        Ok(normalized)
    }

    fn reduce_dimensions(&self, matrix: Array2<f64>, target_dim: usize) -> Result<Array2<f64>> {
        let (_n_samples, n_features) = matrix.dim();

        if target_dim >= n_features {
            return Ok(matrix);
        }

        // Simple truncation to the first `target_dim` features; a proper
        // projection (e.g. truncated SVD) could be substituted here.
        let reduced = matrix.slice(s![.., 0..target_dim]).to_owned();
        Ok(reduced)
    }

    fn fit_kmeans(&mut self, data: ArrayView2<f64>, k: usize) -> Result<()> {
        let (n_samples, n_features) = data.dim();

        if k > n_samples {
            return Err(ClusteringError::InvalidInput(
                "Number of clusters cannot exceed number of samples".to_string(),
            ));
        }

        let mut centroids = Array2::zeros((k, n_features));
        self.initialize_centroids_plus_plus(&mut centroids, data)?;

        let mut labels = Array1::zeros(n_samples);
        let mut prev_inertia = f64::INFINITY;
        let mut n_iter = 0;

        for iter in 0..self.config.max_iterations {
            n_iter = iter + 1;

            // Assignment step: attach each sample to its closest centroid.
            let mut inertia = 0.0;
            for (i, sample) in data.rows().into_iter().enumerate() {
                let (best_cluster, distance) =
                    self.find_closest_centroid(sample, centroids.view())?;
                labels[i] = best_cluster;
                inertia += distance;
            }

            // Record the current inertia before checking convergence, so the
            // reported value always reflects the final assignment.
            let converged = (prev_inertia - inertia).abs() < self.config.tolerance;
            prev_inertia = inertia;
            if converged {
                break;
            }

            // Update step: recompute centroids from the current assignment.
            self.update_centroids(&mut centroids, data, labels.view())?;
        }

        self.centroids = Some(centroids);
        self.labels = Some(labels);
        self.inertia = Some(prev_inertia);
        self.n_iterations = Some(n_iter);

        Ok(())
    }

    fn initialize_centroids_plus_plus(
        &self,
        centroids: &mut Array2<f64>,
        data: ArrayView2<f64>,
    ) -> Result<()> {
        let n_samples = data.nrows();
        let k = centroids.nrows();

        // Deterministic k-means++-style seeding: the first centroid is the
        // first sample; each subsequent centroid is picked at the midpoint of
        // the cumulative squared-distance distribution.
        centroids.row_mut(0).assign(&data.row(0));

        for i in 1..k {
            let mut distances = Array1::zeros(n_samples);
            let mut total_distance = 0.0;

            for (j, sample) in data.rows().into_iter().enumerate() {
                let mut min_dist = f64::INFINITY;
                for centroid_idx in 0..i {
                    let dist = self.compute_distance(sample, centroids.row(centroid_idx))?;
                    if dist < min_dist {
                        min_dist = dist;
                    }
                }
                distances[j] = min_dist * min_dist;
                total_distance += distances[j];
            }

            if total_distance > 0.0 {
                let target = total_distance * 0.5;
                let mut cumsum = 0.0;
                for (j, &dist) in distances.iter().enumerate() {
                    cumsum += dist;
                    if cumsum >= target {
                        centroids.row_mut(i).assign(&data.row(j));
                        break;
                    }
                }
            } else if i < n_samples {
                centroids.row_mut(i).assign(&data.row(i));
            }
        }

        Ok(())
    }

    fn find_closest_centroid(
        &self,
        sample: ArrayView1<f64>,
        centroids: ArrayView2<f64>,
    ) -> Result<(usize, f64)> {
        let mut min_distance = f64::INFINITY;
        let mut best_cluster = 0;

        for (i, centroid) in centroids.rows().into_iter().enumerate() {
            let distance = self.compute_distance(sample, centroid)?;
            if distance < min_distance {
                min_distance = distance;
                best_cluster = i;
            }
        }

        Ok((best_cluster, min_distance))
    }

    fn compute_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        match self.config.similarity_metric {
            SemanticSimilarity::Cosine => {
                let similarity = self.cosine_similarity(a, b)?;
                Ok(1.0 - similarity)
            }
            SemanticSimilarity::Euclidean => Ok(euclidean_distance(a, b)),
            SemanticSimilarity::Manhattan => Ok(self.manhattan_distance(a, b)?),
            SemanticSimilarity::Jaccard => {
                let similarity = self.jaccard_similarity(a, b)?;
                Ok(1.0 - similarity)
            }
            SemanticSimilarity::Pearson => {
                let correlation = self.pearson_correlation(a, b)?;
                Ok(1.0 - correlation.abs())
            }
            SemanticSimilarity::JensenShannon => self.jensen_shannon_distance(a, b),
            SemanticSimilarity::Hellinger => self.hellinger_distance(a, b),
            SemanticSimilarity::Bhattacharyya => self.bhattacharyya_distance(a, b),
        }
    }

    fn cosine_similarity(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        let dot_product = a.dot(&b);
        let norm_a = (a.dot(&a)).sqrt();
        let norm_b = (b.dot(&b)).sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            Ok(0.0)
        } else {
            Ok(dot_product / (norm_a * norm_b))
        }
    }

    fn manhattan_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).abs()).sum())
    }

    fn jaccard_similarity(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        let threshold = 1e-10;
        let mut intersection = 0.0;
        let mut union = 0.0;

        for (&x, &y) in a.iter().zip(b.iter()) {
            let x_present = x > threshold;
            let y_present = y > threshold;

            if x_present && y_present {
                intersection += 1.0;
            }
            if x_present || y_present {
                union += 1.0;
            }
        }

        if union == 0.0 {
            // Two all-zero vectors are treated as identical.
            Ok(1.0)
        } else {
            Ok(intersection / union)
        }
    }

    fn pearson_correlation(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        let n = a.len() as f64;
        if n < 2.0 {
            return Ok(0.0);
        }

        let mean_a = a.sum() / n;
        let mean_b = b.sum() / n;

        let mut numerator = 0.0;
        let mut sum_sq_a = 0.0;
        let mut sum_sq_b = 0.0;

        for (&x, &y) in a.iter().zip(b.iter()) {
            let diff_a = x - mean_a;
            let diff_b = y - mean_b;
            numerator += diff_a * diff_b;
            sum_sq_a += diff_a * diff_a;
            sum_sq_b += diff_b * diff_b;
        }

        let denominator = (sum_sq_a * sum_sq_b).sqrt();
        if denominator == 0.0 {
            Ok(0.0)
        } else {
            Ok(numerator / denominator)
        }
    }

    fn jensen_shannon_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        // Normalize both vectors into probability distributions.
        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();

        if sum_a == 0.0 || sum_b == 0.0 {
            return Ok(1.0);
        }

        let p: Vec<f64> = a.iter().map(|&x| x.max(0.0) / sum_a).collect();
        let q: Vec<f64> = b.iter().map(|&x| x.max(0.0) / sum_b).collect();

        // Mixture distribution m = (p + q) / 2.
        let m: Vec<f64> = p
            .iter()
            .zip(q.iter())
            .map(|(&x, &y)| (x + y) / 2.0)
            .collect();

        let kl_pm = self.kl_divergence(&p, &m);
        let kl_qm = self.kl_divergence(&q, &m);

        // The Jensen-Shannon distance is the square root of the JS divergence.
        let js = (kl_pm + kl_qm) / 2.0;
        Ok(js.sqrt())
    }

    fn kl_divergence(&self, p: &[f64], q: &[f64]) -> f64 {
        let mut kl = 0.0;
        for (&pi, &qi) in p.iter().zip(q.iter()) {
            if pi > 1e-10 && qi > 1e-10 {
                kl += pi * (pi / qi).ln();
            }
        }
        kl
    }

    fn hellinger_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();

        if sum_a == 0.0 || sum_b == 0.0 {
            return Ok(1.0);
        }

        // Bhattacharyya coefficient of the normalized distributions.
        let mut sum_sqrt_products = 0.0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let p = x.max(0.0) / sum_a;
            let q = y.max(0.0) / sum_b;
            sum_sqrt_products += (p * q).sqrt();
        }

        // Clamp at zero to guard against tiny negative values from rounding.
        let hellinger = (1.0 - sum_sqrt_products).max(0.0).sqrt();
        Ok(hellinger)
    }

    fn bhattacharyya_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();

        if sum_a == 0.0 || sum_b == 0.0 {
            return Ok(f64::INFINITY);
        }

        let mut bc = 0.0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let p = x.max(0.0) / sum_a;
            let q = y.max(0.0) / sum_b;
            bc += (p * q).sqrt();
        }

        if bc <= 0.0 {
            Ok(f64::INFINITY)
        } else {
            Ok(-bc.ln())
        }
    }

    fn update_centroids(
        &self,
        centroids: &mut Array2<f64>,
        data: ArrayView2<f64>,
        labels: ArrayView1<usize>,
    ) -> Result<()> {
        let k = centroids.nrows();
        let n_features = centroids.ncols();

        centroids.fill(0.0);
        let mut cluster_sizes = vec![0; k];

        // Accumulate samples into their assigned centroids.
        for (i, &label) in labels.iter().enumerate() {
            if label < k {
                for j in 0..n_features {
                    centroids[[label, j]] += data[[i, j]];
                }
                cluster_sizes[label] += 1;
            }
        }

        // Average; empty clusters keep a zero centroid.
        for i in 0..k {
            if cluster_sizes[i] > 0 {
                for j in 0..n_features {
                    centroids[[i, j]] /= cluster_sizes[i] as f64;
                }
            }
        }

        Ok(())
    }

    /// Predicts cluster labels for documents using the fitted centroids.
    pub fn predict(&self, text_repr: &TextRepresentation) -> Result<Array1<usize>> {
        let vectors = self.extract_vectors(text_repr)?;
        let preprocessed = self.preprocess_vectors(vectors)?;

        if let Some(ref centroids) = self.centroids {
            let mut labels = Array1::zeros(preprocessed.nrows());

            for (i, sample) in preprocessed.rows().into_iter().enumerate() {
                let (best_cluster, _distance) =
                    self.find_closest_centroid(sample, centroids.view())?;
                labels[i] = best_cluster;
            }

            Ok(labels)
        } else {
            Err(ClusteringError::InvalidInput(
                "Model has not been fitted yet".to_string(),
            ))
        }
    }

    /// Returns the fitted cluster centers, if any.
    pub fn cluster_centers(&self) -> Option<&Array2<f64>> {
        self.centroids.as_ref()
    }

    /// Returns the final inertia (sum of distances to assigned centroids).
    pub fn inertia(&self) -> Option<f64> {
        self.inertia
    }

    /// Returns the number of iterations run during fitting.
    pub fn n_iterations(&self) -> Option<usize> {
        self.n_iterations
    }
}

/// Agglomerative (single-linkage) hierarchical clustering over text representations.
pub struct SemanticHierarchical {
    config: SemanticClusteringConfig,
    linkage_matrix: Option<Array2<f64>>,
    n_clusters: Option<usize>,
}

impl SemanticHierarchical {
    /// Creates a new, unfitted hierarchical clusterer.
    pub fn new(config: SemanticClusteringConfig) -> Self {
        Self {
            config,
            linkage_matrix: None,
            n_clusters: None,
        }
    }

    /// Fits the hierarchy to a text representation.
    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
        let vectors = self.extract_vectors(text_repr)?;
        let preprocessed = self.preprocess_vectors(vectors)?;

        self.fit_hierarchical(preprocessed.view())?;
        Ok(())
    }

    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
        match text_repr {
            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::WordEmbeddings { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::ContextualEmbeddings { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
        }
    }

    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
        if self.config.preprocessing.normalize_vectors {
            let mut normalized = vectors.clone();
            for mut row in normalized.rows_mut() {
                let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
                if norm > 1e-10 {
                    row.mapv_inplace(|x| x / norm);
                }
            }
            Ok(normalized)
        } else {
            Ok(vectors)
        }
    }

    fn fit_hierarchical(&mut self, data: ArrayView2<f64>) -> Result<()> {
        let n_samples = data.nrows();

        if n_samples < 2 {
            return Err(ClusteringError::InvalidInput(
                "Need at least 2 samples for hierarchical clustering".to_string(),
            ));
        }

        // Pairwise distance matrix.
        let mut distance_matrix = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let distance = self.compute_distance(data.row(i), data.row(j))?;
                distance_matrix[[i, j]] = distance;
                distance_matrix[[j, i]] = distance;
            }
        }

        // Start with every sample in its own cluster.
        let mut clusters: Vec<HashSet<usize>> = (0..n_samples)
            .map(|i| {
                let mut set = HashSet::new();
                set.insert(i);
                set
            })
            .collect();

        let mut linkage_steps = Vec::new();

        while clusters.len() > 1 {
            let mut min_distance = f64::INFINITY;
            let mut merge_i = 0;
            let mut merge_j = 1;

            // Find the closest pair of clusters (single linkage).
            for i in 0..clusters.len() {
                for j in i + 1..clusters.len() {
                    let distance =
                        self.cluster_distance(&clusters[i], &clusters[j], &distance_matrix);
                    if distance < min_distance {
                        min_distance = distance;
                        merge_i = i;
                        merge_j = j;
                    }
                }
            }

            // Record the merge: [cluster i, cluster j, distance, merged size].
            linkage_steps.push([
                merge_i as f64,
                merge_j as f64,
                min_distance,
                (clusters[merge_i].len() + clusters[merge_j].len()) as f64,
            ]);

            // Merge cluster j into cluster i.
            let cluster_j = clusters.remove(merge_j);
            clusters[merge_i].extend(cluster_j);
        }

        let linkage_matrix = Array2::from_shape_vec(
            (linkage_steps.len(), 4),
            linkage_steps.into_iter().flatten().collect(),
        )
        .unwrap();

        self.linkage_matrix = Some(linkage_matrix);
        Ok(())
    }

    /// Single-linkage distance: the minimum pairwise distance between two clusters.
    fn cluster_distance(
        &self,
        cluster_a: &HashSet<usize>,
        cluster_b: &HashSet<usize>,
        distance_matrix: &Array2<f64>,
    ) -> f64 {
        let mut min_distance = f64::INFINITY;

        for &i in cluster_a {
            for &j in cluster_b {
                let distance = distance_matrix[[i, j]];
                if distance < min_distance {
                    min_distance = distance;
                }
            }
        }

        min_distance
    }

    fn compute_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
        match self.config.similarity_metric {
            SemanticSimilarity::Cosine => {
                let dot_product = a.dot(&b);
                let norm_a = (a.dot(&a)).sqrt();
                let norm_b = (b.dot(&b)).sqrt();

                if norm_a == 0.0 || norm_b == 0.0 {
                    Ok(1.0)
                } else {
                    let similarity = dot_product / (norm_a * norm_b);
                    Ok(1.0 - similarity)
                }
            }
            SemanticSimilarity::Euclidean => Ok(euclidean_distance(a, b)),
            // Other metrics fall back to Euclidean distance here.
            _ => Ok(euclidean_distance(a, b)),
        }
    }

    /// Returns the linkage matrix produced by `fit`, if any.
    pub fn linkage_matrix(&self) -> Option<&Array2<f64>> {
        self.linkage_matrix.as_ref()
    }
}

/// Topic-model-style clustering via a simple iterative factorization of the
/// document-term matrix.
pub struct TopicBasedClustering {
    config: SemanticClusteringConfig,
    topics: Option<Array2<f64>>,
    document_topic_distributions: Option<Array2<f64>>,
    n_topics: usize,
}

impl TopicBasedClustering {
    /// Creates a new topic-based clusterer with `n_topics` topics.
    pub fn new(config: SemanticClusteringConfig, n_topics: usize) -> Self {
        Self {
            config,
            topics: None,
            document_topic_distributions: None,
            n_topics,
        }
    }

    /// Fits topic distributions to a text representation.
    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
        let vectors = self.extract_vectors(text_repr)?;
        let preprocessed = self.preprocess_vectors(vectors)?;

        self.fit_topics(preprocessed.view())?;
        Ok(())
    }

    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
        match text_repr {
            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
            _ => Err(ClusteringError::InvalidInput(
                "Topic modeling requires TF-IDF or document-term matrix".to_string(),
            )),
        }
    }

    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
        // Clamp negatives and normalize each document to a probability distribution.
        let mut processed = vectors.mapv(|x| x.max(0.0));

        for mut row in processed.rows_mut() {
            let sum: f64 = row.sum();
            if sum > 1e-10 {
                row.mapv_inplace(|x| x / sum);
            }
        }

        Ok(processed)
    }

    fn fit_topics(&mut self, data: ArrayView2<f64>) -> Result<()> {
        let (n_docs, n_terms) = data.dim();

        // Initialize with uniform topic-term and document-topic distributions.
        let mut topics = Array2::from_elem((self.n_topics, n_terms), 1.0 / n_terms as f64);
        let mut doc_topics =
            Array2::from_elem((n_docs, self.n_topics), 1.0 / self.n_topics as f64);

        for _iter in 0..self.config.max_iterations {
            // Update document-topic weights from the current topics.
            for doc_idx in 0..n_docs {
                for topic_idx in 0..self.n_topics {
                    let mut numerator = 0.0;
                    let mut denominator = 0.0;

                    for term_idx in 0..n_terms {
                        let observed = data[[doc_idx, term_idx]];
                        let expected = topics[[topic_idx, term_idx]];

                        if expected > 1e-10 {
                            numerator += observed * expected;
                            denominator += expected;
                        }
                    }

                    if denominator > 1e-10 {
                        doc_topics[[doc_idx, topic_idx]] = numerator / denominator;
                    }
                }

                // Renormalize the document's topic distribution.
                let sum: f64 = doc_topics.row(doc_idx).sum();
                if sum > 1e-10 {
                    for topic_idx in 0..self.n_topics {
                        doc_topics[[doc_idx, topic_idx]] /= sum;
                    }
                }
            }

            // Update topic-term weights from the document-topic weights.
            for topic_idx in 0..self.n_topics {
                for term_idx in 0..n_terms {
                    let mut numerator = 0.0;
                    let mut denominator = 0.0;

                    for doc_idx in 0..n_docs {
                        let observed = data[[doc_idx, term_idx]];
                        let doc_topic_weight = doc_topics[[doc_idx, topic_idx]];

                        numerator += observed * doc_topic_weight;
                        denominator += doc_topic_weight;
                    }

                    if denominator > 1e-10 {
                        topics[[topic_idx, term_idx]] = numerator / denominator;
                    }
                }

                // Renormalize the topic's term distribution.
                let sum: f64 = topics.row(topic_idx).sum();
                if sum > 1e-10 {
                    for term_idx in 0..n_terms {
                        topics[[topic_idx, term_idx]] /= sum;
                    }
                }
            }
        }

        self.topics = Some(topics);
        self.document_topic_distributions = Some(doc_topics);
        Ok(())
    }

    /// Assigns each fitted document to its highest-probability topic.
    ///
    /// The `text_repr` argument is currently unused; predictions are derived
    /// from the document-topic distributions computed during `fit`.
    pub fn predict(&self, text_repr: &TextRepresentation) -> Result<Array1<usize>> {
        if let Some(ref doc_topics) = self.document_topic_distributions {
            let mut labels = Array1::zeros(doc_topics.nrows());

            for (doc_idx, doc_topic_dist) in doc_topics.rows().into_iter().enumerate() {
                let mut max_prob = 0.0;
                let mut best_topic = 0;

                for (topic_idx, &prob) in doc_topic_dist.iter().enumerate() {
                    if prob > max_prob {
                        max_prob = prob;
                        best_topic = topic_idx;
                    }
                }

                labels[doc_idx] = best_topic;
            }

            Ok(labels)
        } else {
            Err(ClusteringError::InvalidInput(
                "Model has not been fitted yet".to_string(),
            ))
        }
    }

    /// Returns the fitted topic-term distributions, if any.
    pub fn topics(&self) -> Option<&Array2<f64>> {
        self.topics.as_ref()
    }

    /// Returns the fitted document-topic distributions, if any.
    pub fn document_topic_distributions(&self) -> Option<&Array2<f64>> {
        self.document_topic_distributions.as_ref()
    }
}

/// Convenience function: runs semantic k-means and returns (centers, labels).
#[allow(dead_code)]
pub fn semantic_kmeans(
    text_repr: &TextRepresentation,
    n_clusters: usize,
    similarity_metric: SemanticSimilarity,
) -> Result<(Array2<f64>, Array1<usize>)> {
    let config = SemanticClusteringConfig {
        n_clusters: Some(n_clusters),
        similarity_metric,
        ..Default::default()
    };

    let mut clusterer = SemanticKMeans::new(config);
    clusterer.fit(text_repr)?;

    let centers = clusterer
        .cluster_centers()
        .ok_or_else(|| {
            ClusteringError::ComputationError("Failed to get cluster centers".to_string())
        })?
        .clone();

    let labels = clusterer.predict(text_repr)?;

    Ok((centers, labels))
}

/// Convenience function: runs semantic hierarchical clustering and returns the linkage matrix.
#[allow(dead_code)]
pub fn semantic_hierarchical(
    text_repr: &TextRepresentation,
    similarity_metric: SemanticSimilarity,
) -> Result<Array2<f64>> {
    let config = SemanticClusteringConfig {
        similarity_metric,
        ..Default::default()
    };

    let mut clusterer = SemanticHierarchical::new(config);
    clusterer.fit(text_repr)?;

    clusterer
        .linkage_matrix()
        .ok_or_else(|| {
            ClusteringError::ComputationError("Failed to get linkage matrix".to_string())
        })
        .cloned()
}

/// Convenience function: runs topic-based clustering and returns
/// (topic-term distributions, document-topic distributions, labels).
#[allow(dead_code)]
pub fn topic_clustering(
    text_repr: &TextRepresentation,
    n_topics: usize,
) -> Result<(Array2<f64>, Array2<f64>, Array1<usize>)> {
    let config = SemanticClusteringConfig::default();
    let mut clusterer = TopicBasedClustering::new(config, n_topics);

    clusterer.fit(text_repr)?;

    let topics = clusterer
        .topics()
        .ok_or_else(|| ClusteringError::ComputationError("Failed to get topics".to_string()))?
        .clone();

    let doc_topics = clusterer
        .document_topic_distributions()
        .ok_or_else(|| {
            ClusteringError::ComputationError(
                "Failed to get document-topic distributions".to_string(),
            )
        })?
        .clone();

    let labels = clusterer.predict(text_repr)?;

    Ok((topics, doc_topics, labels))
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_semantic_kmeans_basic() {
        let vectors = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 1.0, 0.0, 0.1, 0.9],
        )
        .unwrap();

        let text_repr = TextRepresentation::TfIdf {
            vectors,
            vocabulary: HashMap::new(),
        };

        let result = semantic_kmeans(&text_repr, 2, SemanticSimilarity::Cosine);
        assert!(result.is_ok());

        let (centers, labels) = result.unwrap();
        assert_eq!(centers.nrows(), 2);
        assert_eq!(labels.len(), 4);
    }
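
    // A minimal sketch (added): `predict` should fail before `fit`, since no
    // centroids have been computed yet.
    #[test]
    fn test_predict_before_fit_errors() {
        let vectors = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let text_repr = TextRepresentation::TfIdf {
            vectors,
            vocabulary: HashMap::new(),
        };

        let clusterer = SemanticKMeans::new(SemanticClusteringConfig::default());
        assert!(clusterer.predict(&text_repr).is_err());
    }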

    #[test]
    fn test_similarity_metrics() {
        let a = scirs2_core::ndarray::Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let b = scirs2_core::ndarray::Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let config = SemanticClusteringConfig::default();
        let clusterer = SemanticKMeans::new(config);

        // Orthogonal vectors: cosine similarity 0, Manhattan distance 2.
        let cosine_sim = clusterer.cosine_similarity(a.view(), b.view()).unwrap();
        assert_eq!(cosine_sim, 0.0);

        let manhattan = clusterer.manhattan_distance(a.view(), b.view()).unwrap();
        assert_eq!(manhattan, 2.0);
    }
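
    // A minimal sketch (added): sanity checks for the distribution-based
    // metrics. Identical distributions give a Jensen-Shannon distance of ~0;
    // disjoint distributions give a Hellinger distance of ~1.
    #[test]
    fn test_distribution_distances() {
        let clusterer = SemanticKMeans::new(SemanticClusteringConfig::default());

        let p = scirs2_core::ndarray::Array1::from_vec(vec![0.5, 0.5, 0.0]);
        let q = scirs2_core::ndarray::Array1::from_vec(vec![0.0, 0.0, 1.0]);

        let js_same = clusterer
            .jensen_shannon_distance(p.view(), p.view())
            .unwrap();
        assert!(js_same.abs() < 1e-10);

        let hellinger_disjoint = clusterer.hellinger_distance(p.view(), q.view()).unwrap();
        assert!((hellinger_disjoint - 1.0).abs() < 1e-10);
    }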

    #[test]
    fn test_text_preprocessing() {
        let config = SemanticClusteringConfig::default();
        let clusterer = SemanticKMeans::new(config);

        let matrix = Array2::from_shape_vec((2, 3), vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]).unwrap();

        let normalized = clusterer.normalize_vectors(matrix).unwrap();

        // Every row should be scaled to unit L2 norm.
        for row in normalized.rows() {
            let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_topic_clustering_basic() {
        let matrix = Array2::from_shape_vec(
            (3, 4),
            vec![2.0, 0.0, 1.0, 0.0, 0.0, 3.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0],
        )
        .unwrap();

        let text_repr = TextRepresentation::DocumentTerm {
            matrix,
            vocabulary: vec![
                "word1".to_string(),
                "word2".to_string(),
                "word3".to_string(),
                "word4".to_string(),
            ],
        };

        let result = topic_clustering(&text_repr, 2);
        assert!(result.is_ok());

        let (topics, doc_topics, labels) = result.unwrap();
        assert_eq!(topics.nrows(), 2);
        assert_eq!(doc_topics.nrows(), 3);
        assert_eq!(labels.len(), 3);
    }
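
    // A minimal sketch (added): single-linkage clustering over n samples
    // should record n - 1 merge steps, each with 4 entries.
    #[test]
    fn test_semantic_hierarchical_basic() {
        let vectors = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 1.0, 0.0, 0.1, 0.9],
        )
        .unwrap();

        let text_repr = TextRepresentation::TfIdf {
            vectors,
            vocabulary: HashMap::new(),
        };

        let linkage = semantic_hierarchical(&text_repr, SemanticSimilarity::Cosine).unwrap();
        assert_eq!(linkage.nrows(), 3);
        assert_eq!(linkage.ncols(), 4);
    }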

    #[test]
    fn test_semantic_similarity_enum() {
        assert_eq!(SemanticSimilarity::Cosine, SemanticSimilarity::Cosine);
        assert_ne!(SemanticSimilarity::Cosine, SemanticSimilarity::Euclidean);
    }
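
    // A minimal sketch (added): a term present in every document has IDF
    // ln(1) = 0, so its TF-IDF weight vanishes; a term present in one of two
    // documents is weighted by ln(2).
    #[test]
    fn test_tfidf_weighting() {
        let clusterer = SemanticKMeans::new(SemanticClusteringConfig::default());

        let matrix = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 1.0, 0.0]).unwrap();
        let tfidf = clusterer.apply_tfidf_weighting(matrix).unwrap();

        assert!(tfidf[[0, 0]].abs() < 1e-12);
        assert!((tfidf[[0, 1]] - 2.0_f64.ln()).abs() < 1e-12);
    }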
}