scirs2_cluster/
text_clustering.rs

//! Text clustering algorithms with semantic similarity support
//!
//! This module provides clustering algorithms specialized for text data. They can use
//! semantic similarity measures (cosine, Jaccard, Jensen-Shannon, ...) as well as
//! traditional distance metrics. It covers document clustering, sentence clustering,
//! and topic modeling.
6
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::numeric::{Float, FromPrimitive, Zero};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::fmt::Debug;
12
13use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
14
15use crate::error::{ClusteringError, Result};
16use crate::vq::euclidean_distance;
17use statrs::statistics::Statistics;
18
19/// Text representation types for clustering
20#[derive(Debug, Clone)]
21#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
22pub enum TextRepresentation {
23    /// Term Frequency-Inverse Document Frequency vectors
24    TfIdf {
25        /// TF-IDF matrix (documents x terms)
26        vectors: Array2<f64>,
27        /// Vocabulary mapping
28        vocabulary: HashMap<String, usize>,
29    },
30    /// Word embeddings (Word2Vec, GloVe, etc.)
31    WordEmbeddings {
32        /// Embedding vectors (documents x embedding_dim)
33        vectors: Array2<f64>,
34        /// Embedding dimension
35        embedding_dim: usize,
36    },
37    /// Contextualized embeddings (BERT, RoBERTa, etc.)
38    ContextualEmbeddings {
39        /// Embedding vectors (documents x embedding_dim)
40        vectors: Array2<f64>,
41        /// Model name used for embeddings
42        model_name: String,
43    },
44    /// Document-Term matrix for traditional approaches
45    DocumentTerm {
46        /// Document-term matrix (documents x terms)
47        matrix: Array2<f64>,
48        /// Term vocabulary
49        vocabulary: Vec<String>,
50    },
51}
52
/// Metric used to compare two document vectors.
///
/// `SemanticKMeans` turns the similarity measures (cosine, Jaccard, Pearson)
/// into distances. The remaining variants are used as distances directly.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub enum SemanticSimilarity {
    /// Cosine of the angle between the vectors; the usual choice for text.
    Cosine,
    /// Jaccard overlap of the non-zero features (binary/sparse data).
    Jaccard,
    /// Straight-line (L2) distance.
    Euclidean,
    /// City-block (L1) distance.
    Manhattan,
    /// Pearson correlation coefficient.
    Pearson,
    /// Jensen-Shannon divergence between the vectors viewed as distributions.
    JensenShannon,
    /// Hellinger distance between the vectors viewed as distributions.
    Hellinger,
    /// Bhattacharyya distance between the vectors viewed as distributions.
    Bhattacharyya,
}
74
75/// Configuration for semantic text clustering
76#[derive(Debug, Clone)]
77#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
78pub struct SemanticClusteringConfig {
79    /// Similarity metric to use
80    pub similarity_metric: SemanticSimilarity,
81    /// Number of clusters (for k-means style algorithms)
82    pub n_clusters: Option<usize>,
83    /// Minimum similarity threshold for hierarchical clustering
84    pub similarity_threshold: f64,
85    /// Maximum number of iterations
86    pub max_iterations: usize,
87    /// Convergence tolerance
88    pub tolerance: f64,
89    /// Use dimensionality reduction preprocessing
90    pub use_dimension_reduction: bool,
91    /// Target dimensions for dimensionality reduction
92    pub target_dimensions: Option<usize>,
93    /// Preprocessing options
94    pub preprocessing: TextPreprocessing,
95}
96
97impl Default for SemanticClusteringConfig {
98    fn default() -> Self {
99        Self {
100            similarity_metric: SemanticSimilarity::Cosine,
101            n_clusters: Some(10),
102            similarity_threshold: 0.5,
103            max_iterations: 100,
104            tolerance: 1e-4,
105            use_dimension_reduction: false,
106            target_dimensions: None,
107            preprocessing: TextPreprocessing::default(),
108        }
109    }
110}
111
/// Options controlling how document vectors are prepared before clustering.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(SerdeSerialize, SerdeDeserialize))]
pub struct TextPreprocessing {
    /// Rescale every document vector to unit length.
    pub normalize_vectors: bool,
    /// Drop feature columns whose variance is (numerically) zero.
    pub remove_zero_variance: bool,
    /// Apply TF-IDF weighting; intended for raw term-frequency input.
    pub apply_tfidf: bool,
    // NOTE(review): `min_df`, `max_df` and `max_features` are not read by the
    // preprocessing pipelines visible in this module section. Confirm they are
    // honoured elsewhere before relying on them.
    /// Minimum document frequency for terms.
    pub min_df: f64,
    /// Maximum document frequency for terms.
    pub max_df: f64,
    /// Maximum number of features to keep.
    pub max_features: Option<usize>,
}

impl Default for TextPreprocessing {
    /// Normalisation, zero-variance removal and TF-IDF enabled;
    /// `min_df = 0.01`, `max_df = 0.95`, no feature cap.
    fn default() -> Self {
        TextPreprocessing {
            max_features: None,
            max_df: 0.95,
            min_df: 0.01,
            apply_tfidf: true,
            remove_zero_variance: true,
            normalize_vectors: true,
        }
    }
}
142
143/// Semantic K-means clustering for text data
144pub struct SemanticKMeans {
145    config: SemanticClusteringConfig,
146    centroids: Option<Array2<f64>>,
147    labels: Option<Array1<usize>>,
148    inertia: Option<f64>,
149    n_iterations: Option<usize>,
150}
151
152impl SemanticKMeans {
153    /// Create a new semantic K-means clusterer
154    pub fn new(config: SemanticClusteringConfig) -> Self {
155        Self {
156            config,
157            centroids: None,
158            labels: None,
159            inertia: None,
160            n_iterations: None,
161        }
162    }
163
164    /// Fit the model to text data
165    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
166        let vectors = self.extract_vectors(text_repr)?;
167        let preprocessed = self.preprocess_vectors(vectors)?;
168
169        let n_clusters = self.config.n_clusters.unwrap_or(10);
170        self.fit_kmeans(preprocessed.view(), n_clusters)?;
171
172        Ok(())
173    }
174
175    /// Extract numerical vectors from text representation
176    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
177        match text_repr {
178            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
179            TextRepresentation::WordEmbeddings { vectors, .. } => Ok(vectors.clone()),
180            TextRepresentation::ContextualEmbeddings { vectors, .. } => Ok(vectors.clone()),
181            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
182        }
183    }
184
185    /// Preprocess vectors according to configuration
186    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
187        let mut processed = vectors;
188
189        // Apply TF-IDF if requested
190        if self.config.preprocessing.apply_tfidf {
191            processed = self.apply_tfidf_weighting(processed)?;
192        }
193
194        // Remove zero-variance features
195        if self.config.preprocessing.remove_zero_variance {
196            processed = self.remove_zero_variance_features(processed)?;
197        }
198
199        // Normalize vectors
200        if self.config.preprocessing.normalize_vectors {
201            processed = self.normalize_vectors(processed)?;
202        }
203
204        // Dimensionality reduction
205        if self.config.use_dimension_reduction {
206            if let Some(target_dim) = self.config.target_dimensions {
207                processed = self.reduce_dimensions(processed, target_dim)?;
208            }
209        }
210
211        Ok(processed)
212    }
213
214    /// Apply TF-IDF weighting to document-term matrix
215    fn apply_tfidf_weighting(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
216        let (n_docs, n_terms) = matrix.dim();
217        let mut tfidf_matrix = matrix.clone();
218
219        // Compute document frequencies
220        let mut df = Array1::zeros(n_terms);
221        for term_idx in 0..n_terms {
222            let mut doc_count = 0;
223            for doc_idx in 0..n_docs {
224                if matrix[[doc_idx, term_idx]] > 0.0 {
225                    doc_count += 1;
226                }
227            }
228            df[term_idx] = doc_count as f64;
229        }
230
231        // Compute TF-IDF
232        for doc_idx in 0..n_docs {
233            for term_idx in 0..n_terms {
234                let tf = matrix[[doc_idx, term_idx]];
235                if tf > 0.0 && df[term_idx] > 0.0 {
236                    let idf = (n_docs as f64 / df[term_idx]).ln();
237                    tfidf_matrix[[doc_idx, term_idx]] = tf * idf;
238                }
239            }
240        }
241
242        Ok(tfidf_matrix)
243    }
244
245    /// Remove features with zero variance
246    fn remove_zero_variance_features(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
247        let mut feature_mask = Vec::new();
248
249        for col_idx in 0..matrix.ncols() {
250            let column = matrix.column(col_idx);
251            let mean = column.mean();
252            let variance =
253                column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64;
254
255            feature_mask.push(variance > 1e-10); // Keep non-zero variance features
256        }
257
258        let valid_features: Vec<usize> = feature_mask
259            .iter()
260            .enumerate()
261            .filter_map(|(i, &keep)| if keep { Some(i) } else { None })
262            .collect();
263
264        if valid_features.is_empty() {
265            return Err(ClusteringError::InvalidInput(
266                "All features have zero variance".to_string(),
267            ));
268        }
269
270        let filtered = matrix.select(Axis(1), &valid_features);
271        Ok(filtered)
272    }
273
274    /// Normalize vectors to unit length
275    fn normalize_vectors(&self, matrix: Array2<f64>) -> Result<Array2<f64>> {
276        let mut normalized = matrix.clone();
277
278        for mut row in normalized.rows_mut() {
279            let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
280            if norm > 1e-10 {
281                row.mapv_inplace(|x| x / norm);
282            }
283        }
284
285        Ok(normalized)
286    }
287
288    /// Reduce dimensionality using PCA (simplified)
289    fn reduce_dimensions(&self, matrix: Array2<f64>, target_dim: usize) -> Result<Array2<f64>> {
290        let (_n_samples, n_features) = matrix.dim();
291
292        if target_dim >= n_features {
293            return Ok(matrix);
294        }
295
296        // Simplified dimensionality reduction: just keep first N dimensions
297        // In practice, this would use proper PCA or other dimensionality reduction
298        let reduced = matrix.slice(s![.., 0..target_dim]).to_owned();
299        Ok(reduced)
300    }
301
302    /// Fit K-means clustering with semantic similarity
303    fn fit_kmeans(&mut self, data: ArrayView2<f64>, k: usize) -> Result<()> {
304        let (n_samples, n_features) = data.dim();
305
306        if k > n_samples {
307            return Err(ClusteringError::InvalidInput(
308                "Number of clusters cannot exceed number of samples".to_string(),
309            ));
310        }
311
312        // Initialize centroids using k-means++
313        let mut centroids = Array2::zeros((k, n_features));
314        self.initialize_centroids_plus_plus(&mut centroids, data)?;
315
316        let mut labels = Array1::zeros(n_samples);
317        let mut prev_inertia = f64::INFINITY;
318        let mut n_iter = 0;
319
320        for iter in 0..self.config.max_iterations {
321            n_iter = iter + 1;
322
323            // Assign points to clusters
324            let mut inertia = 0.0;
325            for (i, sample) in data.rows().into_iter().enumerate() {
326                let (best_cluster, distance) =
327                    self.find_closest_centroid(sample, centroids.view())?;
328                labels[i] = best_cluster;
329                inertia += distance;
330            }
331
332            // Check convergence
333            if (prev_inertia - inertia).abs() < self.config.tolerance {
334                break;
335            }
336            prev_inertia = inertia;
337
338            // Update centroids
339            self.update_centroids(&mut centroids, data, labels.view())?;
340        }
341
342        self.centroids = Some(centroids);
343        self.labels = Some(labels);
344        self.inertia = Some(prev_inertia);
345        self.n_iterations = Some(n_iter);
346
347        Ok(())
348    }
349
350    /// Initialize centroids using k-means++ method
351    fn initialize_centroids_plus_plus(
352        &self,
353        centroids: &mut Array2<f64>,
354        data: ArrayView2<f64>,
355    ) -> Result<()> {
356        let n_samples = data.nrows();
357        let k = centroids.nrows();
358
359        // Choose first centroid randomly
360        centroids.row_mut(0).assign(&data.row(0));
361
362        // Choose remaining centroids
363        for i in 1..k {
364            let mut distances = Array1::zeros(n_samples);
365            let mut total_distance = 0.0;
366
367            // Calculate distances to nearest existing centroid
368            for (j, sample) in data.rows().into_iter().enumerate() {
369                let mut min_dist = f64::INFINITY;
370                for centroid_idx in 0..i {
371                    let dist = self.compute_distance(sample, centroids.row(centroid_idx))?;
372                    if dist < min_dist {
373                        min_dist = dist;
374                    }
375                }
376                distances[j] = min_dist * min_dist;
377                total_distance += distances[j];
378            }
379
380            // Select next centroid probabilistically
381            if total_distance > 0.0 {
382                let target = total_distance * 0.5; // Simplified: use middle point
383                let mut cumsum = 0.0;
384                for (j, &dist) in distances.iter().enumerate() {
385                    cumsum += dist;
386                    if cumsum >= target {
387                        centroids.row_mut(i).assign(&data.row(j));
388                        break;
389                    }
390                }
391            } else {
392                // Fallback: use next available point
393                if i < n_samples {
394                    centroids.row_mut(i).assign(&data.row(i));
395                }
396            }
397        }
398
399        Ok(())
400    }
401
402    /// Find the closest centroid to a sample
403    fn find_closest_centroid(
404        &self,
405        sample: ArrayView1<f64>,
406        centroids: ArrayView2<f64>,
407    ) -> Result<(usize, f64)> {
408        let mut min_distance = f64::INFINITY;
409        let mut best_cluster = 0;
410
411        for (i, centroid) in centroids.rows().into_iter().enumerate() {
412            let distance = self.compute_distance(sample, centroid)?;
413            if distance < min_distance {
414                min_distance = distance;
415                best_cluster = i;
416            }
417        }
418
419        Ok((best_cluster, min_distance))
420    }
421
422    /// Compute distance based on configured similarity metric
423    fn compute_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
424        match self.config.similarity_metric {
425            SemanticSimilarity::Cosine => {
426                let similarity = self.cosine_similarity(a, b)?;
427                Ok(1.0 - similarity) // Convert similarity to distance
428            }
429            SemanticSimilarity::Euclidean => Ok(euclidean_distance(a, b)),
430            SemanticSimilarity::Manhattan => Ok(self.manhattan_distance(a, b)?),
431            SemanticSimilarity::Jaccard => {
432                let similarity = self.jaccard_similarity(a, b)?;
433                Ok(1.0 - similarity)
434            }
435            SemanticSimilarity::Pearson => {
436                let correlation = self.pearson_correlation(a, b)?;
437                Ok(1.0 - correlation.abs()) // Use absolute correlation
438            }
439            SemanticSimilarity::JensenShannon => self.jensen_shannon_distance(a, b),
440            SemanticSimilarity::Hellinger => self.hellinger_distance(a, b),
441            SemanticSimilarity::Bhattacharyya => self.bhattacharyya_distance(a, b),
442        }
443    }
444
445    /// Compute cosine similarity
446    fn cosine_similarity(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
447        let dot_product = a.dot(&b);
448        let norm_a = (a.dot(&a)).sqrt();
449        let norm_b = (b.dot(&b)).sqrt();
450
451        if norm_a == 0.0 || norm_b == 0.0 {
452            Ok(0.0)
453        } else {
454            Ok(dot_product / (norm_a * norm_b))
455        }
456    }
457
458    /// Compute Manhattan distance
459    fn manhattan_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
460        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).abs()).sum())
461    }
462
463    /// Compute Jaccard similarity
464    fn jaccard_similarity(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
465        let threshold = 1e-10;
466        let mut intersection = 0.0;
467        let mut union = 0.0;
468
469        for (&x, &y) in a.iter().zip(b.iter()) {
470            let x_present = x > threshold;
471            let y_present = y > threshold;
472
473            if x_present && y_present {
474                intersection += 1.0;
475            }
476            if x_present || y_present {
477                union += 1.0;
478            }
479        }
480
481        if union == 0.0 {
482            Ok(1.0) // Both vectors are zero
483        } else {
484            Ok(intersection / union)
485        }
486    }
487
488    /// Compute Pearson correlation coefficient
489    fn pearson_correlation(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
490        let n = a.len() as f64;
491        if n < 2.0 {
492            return Ok(0.0);
493        }
494
495        let mean_a = a.mean();
496        let mean_b = b.mean();
497
498        let mut numerator = 0.0;
499        let mut sum_sq_a = 0.0;
500        let mut sum_sq_b = 0.0;
501
502        for (&x, &y) in a.iter().zip(b.iter()) {
503            let diff_a = x - mean_a;
504            let diff_b = y - mean_b;
505            numerator += diff_a * diff_b;
506            sum_sq_a += diff_a * diff_a;
507            sum_sq_b += diff_b * diff_b;
508        }
509
510        let denominator = (sum_sq_a * sum_sq_b).sqrt();
511        if denominator == 0.0 {
512            Ok(0.0)
513        } else {
514            Ok(numerator / denominator)
515        }
516    }
517
518    /// Compute Jensen-Shannon distance
519    fn jensen_shannon_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
520        // Normalize to probability distributions
521        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
522        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();
523
524        if sum_a == 0.0 || sum_b == 0.0 {
525            return Ok(1.0);
526        }
527
528        let p: Vec<f64> = a.iter().map(|&x| x.max(0.0) / sum_a).collect();
529        let q: Vec<f64> = b.iter().map(|&x| x.max(0.0) / sum_b).collect();
530
531        // Compute average distribution
532        let m: Vec<f64> = p
533            .iter()
534            .zip(q.iter())
535            .map(|(&x, &y)| (x + y) / 2.0)
536            .collect();
537
538        // Compute KL divergences
539        let kl_pm = self.kl_divergence(&p, &m);
540        let kl_qm = self.kl_divergence(&q, &m);
541
542        // Jensen-Shannon divergence
543        let js = (kl_pm + kl_qm) / 2.0;
544        Ok(js.sqrt()) // Return Jensen-Shannon distance
545    }
546
547    /// Compute KL divergence between two probability distributions
548    fn kl_divergence(&self, p: &[f64], q: &[f64]) -> f64 {
549        let mut kl = 0.0;
550        for (&pi, &qi) in p.iter().zip(q.iter()) {
551            if pi > 1e-10 && qi > 1e-10 {
552                kl += pi * (pi / qi).ln();
553            }
554        }
555        kl
556    }
557
558    /// Compute Hellinger distance
559    fn hellinger_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
560        // Normalize to probability distributions
561        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
562        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();
563
564        if sum_a == 0.0 || sum_b == 0.0 {
565            return Ok(1.0);
566        }
567
568        let mut sum_sqrt_products = 0.0;
569        for (&x, &y) in a.iter().zip(b.iter()) {
570            let p = x.max(0.0) / sum_a;
571            let q = y.max(0.0) / sum_b;
572            sum_sqrt_products += (p * q).sqrt();
573        }
574
575        let hellinger = (1.0 - sum_sqrt_products).sqrt();
576        Ok(hellinger)
577    }
578
579    /// Compute Bhattacharyya distance
580    fn bhattacharyya_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
581        // Normalize to probability distributions
582        let sum_a: f64 = a.iter().map(|&x| x.max(0.0)).sum();
583        let sum_b: f64 = b.iter().map(|&x| x.max(0.0)).sum();
584
585        if sum_a == 0.0 || sum_b == 0.0 {
586            return Ok(f64::INFINITY);
587        }
588
589        let mut bc = 0.0; // Bhattacharyya coefficient
590        for (&x, &y) in a.iter().zip(b.iter()) {
591            let p = x.max(0.0) / sum_a;
592            let q = y.max(0.0) / sum_b;
593            bc += (p * q).sqrt();
594        }
595
596        if bc <= 0.0 {
597            Ok(f64::INFINITY)
598        } else {
599            Ok(-bc.ln())
600        }
601    }
602
603    /// Update centroids
604    fn update_centroids(
605        &self,
606        centroids: &mut Array2<f64>,
607        data: ArrayView2<f64>,
608        labels: ArrayView1<usize>,
609    ) -> Result<()> {
610        let k = centroids.nrows();
611        let n_features = centroids.ncols();
612
613        // Reset centroids
614        centroids.fill(0.0);
615        let mut cluster_sizes = vec![0; k];
616
617        // Accumulate points for each cluster
618        for (i, &label) in labels.iter().enumerate() {
619            if label < k {
620                for j in 0..n_features {
621                    centroids[[label, j]] += data[[i, j]];
622                }
623                cluster_sizes[label] += 1;
624            }
625        }
626
627        // Average to get centroids
628        for i in 0..k {
629            if cluster_sizes[i] > 0 {
630                for j in 0..n_features {
631                    centroids[[i, j]] /= cluster_sizes[i] as f64;
632                }
633            }
634        }
635
636        Ok(())
637    }
638
639    /// Predict cluster assignments for new text data
640    pub fn predict(&self, text_repr: &TextRepresentation) -> Result<Array1<usize>> {
641        let vectors = self.extract_vectors(text_repr)?;
642        let preprocessed = self.preprocess_vectors(vectors)?;
643
644        if let Some(ref centroids) = self.centroids {
645            let mut labels = Array1::zeros(preprocessed.nrows());
646
647            for (i, sample) in preprocessed.rows().into_iter().enumerate() {
648                let (best_cluster, _distance) =
649                    self.find_closest_centroid(sample, centroids.view())?;
650                labels[i] = best_cluster;
651            }
652
653            Ok(labels)
654        } else {
655            Err(ClusteringError::InvalidInput(
656                "Model has not been fitted yet".to_string(),
657            ))
658        }
659    }
660
661    /// Get cluster centroids
662    pub fn cluster_centers(&self) -> Option<&Array2<f64>> {
663        self.centroids.as_ref()
664    }
665
666    /// Get inertia (sum of distances to centroids)
667    pub fn inertia(&self) -> Option<f64> {
668        self.inertia
669    }
670
671    /// Get number of iterations performed
672    pub fn n_iterations(&self) -> Option<usize> {
673        self.n_iterations
674    }
675}
676
677/// Hierarchical clustering for text with semantic similarity
678pub struct SemanticHierarchical {
679    config: SemanticClusteringConfig,
680    linkage_matrix: Option<Array2<f64>>,
681    n_clusters: Option<usize>,
682}
683
684impl SemanticHierarchical {
685    /// Create a new semantic hierarchical clusterer
686    pub fn new(config: SemanticClusteringConfig) -> Self {
687        Self {
688            config,
689            linkage_matrix: None,
690            n_clusters: None,
691        }
692    }
693
694    /// Fit hierarchical clustering to text data
695    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
696        let vectors = self.extract_vectors(text_repr)?;
697        let preprocessed = self.preprocess_vectors(vectors)?;
698
699        self.fit_hierarchical(preprocessed.view())?;
700        Ok(())
701    }
702
703    /// Extract and preprocess vectors (same as SemanticKMeans)
704    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
705        match text_repr {
706            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
707            TextRepresentation::WordEmbeddings { vectors, .. } => Ok(vectors.clone()),
708            TextRepresentation::ContextualEmbeddings { vectors, .. } => Ok(vectors.clone()),
709            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
710        }
711    }
712
713    /// Preprocess vectors (simplified version)
714    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
715        // Simplified preprocessing for hierarchical clustering
716        if self.config.preprocessing.normalize_vectors {
717            let mut normalized = vectors.clone();
718            for mut row in normalized.rows_mut() {
719                let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
720                if norm > 1e-10 {
721                    row.mapv_inplace(|x| x / norm);
722                }
723            }
724            Ok(normalized)
725        } else {
726            Ok(vectors)
727        }
728    }
729
730    /// Fit hierarchical clustering using single linkage
731    fn fit_hierarchical(&mut self, data: ArrayView2<f64>) -> Result<()> {
732        let n_samples = data.nrows();
733
734        if n_samples < 2 {
735            return Err(ClusteringError::InvalidInput(
736                "Need at least 2 samples for hierarchical clustering".to_string(),
737            ));
738        }
739
740        // Compute distance matrix
741        let mut distance_matrix = Array2::zeros((n_samples, n_samples));
742        for i in 0..n_samples {
743            for j in i + 1..n_samples {
744                let distance = self.compute_distance(data.row(i), data.row(j))?;
745                distance_matrix[[i, j]] = distance;
746                distance_matrix[[j, i]] = distance;
747            }
748        }
749
750        // Perform single linkage clustering (simplified)
751        let mut clusters: Vec<HashSet<usize>> = (0..n_samples)
752            .map(|i| {
753                let mut set = HashSet::new();
754                set.insert(i);
755                set
756            })
757            .collect();
758
759        let mut linkage_steps = Vec::new();
760
761        while clusters.len() > 1 {
762            let mut min_distance = f64::INFINITY;
763            let mut merge_i = 0;
764            let mut merge_j = 1;
765
766            // Find closest clusters
767            for i in 0..clusters.len() {
768                for j in i + 1..clusters.len() {
769                    let distance =
770                        self.cluster_distance(&clusters[i], &clusters[j], &distance_matrix);
771                    if distance < min_distance {
772                        min_distance = distance;
773                        merge_i = i;
774                        merge_j = j;
775                    }
776                }
777            }
778
779            // Record linkage step
780            linkage_steps.push([
781                merge_i as f64,
782                merge_j as f64,
783                min_distance,
784                (clusters[merge_i].len() + clusters[merge_j].len()) as f64,
785            ]);
786
787            // Merge clusters
788            let cluster_j = clusters.remove(merge_j);
789            clusters[merge_i].extend(cluster_j);
790        }
791
792        // Convert to linkage matrix
793        let linkage_matrix = Array2::from_shape_vec(
794            (linkage_steps.len(), 4),
795            linkage_steps.into_iter().flatten().collect(),
796        )
797        .unwrap();
798
799        self.linkage_matrix = Some(linkage_matrix);
800        Ok(())
801    }
802
803    /// Compute distance between clusters (single linkage)
804    fn cluster_distance(
805        &self,
806        cluster_a: &HashSet<usize>,
807        cluster_b: &HashSet<usize>,
808        distance_matrix: &Array2<f64>,
809    ) -> f64 {
810        let mut min_distance = f64::INFINITY;
811
812        for &i in cluster_a {
813            for &j in cluster_b {
814                let distance = distance_matrix[[i, j]];
815                if distance < min_distance {
816                    min_distance = distance;
817                }
818            }
819        }
820
821        min_distance
822    }
823
824    /// Compute distance based on similarity metric
825    fn compute_distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> Result<f64> {
826        match self.config.similarity_metric {
827            SemanticSimilarity::Cosine => {
828                let dot_product = a.dot(&b);
829                let norm_a = (a.dot(&a)).sqrt();
830                let norm_b = (b.dot(&b)).sqrt();
831
832                if norm_a == 0.0 || norm_b == 0.0 {
833                    Ok(1.0)
834                } else {
835                    let similarity = dot_product / (norm_a * norm_b);
836                    Ok(1.0 - similarity)
837                }
838            }
839            SemanticSimilarity::Euclidean => Ok(euclidean_distance(a, b)),
840            _ => {
841                // For other metrics, use Euclidean as fallback
842                Ok(euclidean_distance(a, b))
843            }
844        }
845    }
846
847    /// Get linkage matrix
848    pub fn linkage_matrix(&self) -> Option<&Array2<f64>> {
849        self.linkage_matrix.as_ref()
850    }
851}
852
853/// Topic modeling-based clustering using semantic similarity
854pub struct TopicBasedClustering {
855    config: SemanticClusteringConfig,
856    topics: Option<Array2<f64>>,
857    document_topic_distributions: Option<Array2<f64>>,
858    n_topics: usize,
859}
860
861impl TopicBasedClustering {
862    /// Create a new topic-based clusterer
863    pub fn new(config: SemanticClusteringConfig, n_topics: usize) -> Self {
864        Self {
865            config,
866            topics: None,
867            document_topic_distributions: None,
868            n_topics,
869        }
870    }
871
872    /// Fit topic-based clustering (simplified NMF-like approach)
873    pub fn fit(&mut self, text_repr: &TextRepresentation) -> Result<()> {
874        let vectors = self.extract_vectors(text_repr)?;
875        let preprocessed = self.preprocess_vectors(vectors)?;
876
877        self.fit_topics(preprocessed.view())?;
878        Ok(())
879    }
880
881    /// Extract vectors from text representation
882    fn extract_vectors(&self, text_repr: &TextRepresentation) -> Result<Array2<f64>> {
883        match text_repr {
884            TextRepresentation::TfIdf { vectors, .. } => Ok(vectors.clone()),
885            TextRepresentation::DocumentTerm { matrix, .. } => Ok(matrix.clone()),
886            _ => Err(ClusteringError::InvalidInput(
887                "Topic modeling requires TF-IDF or document-term matrix".to_string(),
888            )),
889        }
890    }
891
892    /// Preprocess vectors
893    fn preprocess_vectors(&self, vectors: Array2<f64>) -> Result<Array2<f64>> {
894        // Ensure non-negative values for topic modeling
895        let mut processed = vectors.mapv(|x| x.max(0.0));
896
897        // Normalize documents
898        for mut row in processed.rows_mut() {
899            let sum: f64 = row.sum();
900            if sum > 1e-10 {
901                row.mapv_inplace(|x| x / sum);
902            }
903        }
904
905        Ok(processed)
906    }
907
908    /// Fit topic model using simplified NMF
909    fn fit_topics(&mut self, data: ArrayView2<f64>) -> Result<()> {
910        let (n_docs, n_terms) = data.dim();
911
912        // Initialize topics and document-topic distributions randomly
913        let mut topics = Array2::from_elem((self.n_topics, n_terms), 1.0 / n_terms as f64);
914        let mut doc_topics = Array2::from_elem((n_docs, self.n_topics), 1.0 / self.n_topics as f64);
915
916        // Simplified NMF iterations
917        for _iter in 0..self.config.max_iterations {
918            // Update document-topic distributions
919            for doc_idx in 0..n_docs {
920                for topic_idx in 0..self.n_topics {
921                    let mut numerator = 0.0;
922                    let mut denominator = 0.0;
923
924                    for term_idx in 0..n_terms {
925                        let observed = data[[doc_idx, term_idx]];
926                        let expected = topics[[topic_idx, term_idx]];
927
928                        if expected > 1e-10 {
929                            numerator += observed * expected;
930                            denominator += expected;
931                        }
932                    }
933
934                    if denominator > 1e-10 {
935                        doc_topics[[doc_idx, topic_idx]] = numerator / denominator;
936                    }
937                }
938
939                // Normalize document-topic distribution
940                let sum: f64 = doc_topics.row(doc_idx).sum();
941                if sum > 1e-10 {
942                    for topic_idx in 0..self.n_topics {
943                        doc_topics[[doc_idx, topic_idx]] /= sum;
944                    }
945                }
946            }
947
948            // Update topics
949            for topic_idx in 0..self.n_topics {
950                for term_idx in 0..n_terms {
951                    let mut numerator = 0.0;
952                    let mut denominator = 0.0;
953
954                    for doc_idx in 0..n_docs {
955                        let observed = data[[doc_idx, term_idx]];
956                        let doc_topic_weight = doc_topics[[doc_idx, topic_idx]];
957
958                        numerator += observed * doc_topic_weight;
959                        denominator += doc_topic_weight;
960                    }
961
962                    if denominator > 1e-10 {
963                        topics[[topic_idx, term_idx]] = numerator / denominator;
964                    }
965                }
966
967                // Normalize topic distribution
968                let sum: f64 = topics.row(topic_idx).sum();
969                if sum > 1e-10 {
970                    for term_idx in 0..n_terms {
971                        topics[[topic_idx, term_idx]] /= sum;
972                    }
973                }
974            }
975        }
976
977        self.topics = Some(topics);
978        self.document_topic_distributions = Some(doc_topics);
979        Ok(())
980    }
981
982    /// Get cluster assignments based on dominant topics
983    pub fn predict(&self, text_repr: &TextRepresentation) -> Result<Array1<usize>> {
984        if let Some(ref doc_topics) = self.document_topic_distributions {
985            let mut labels = Array1::zeros(doc_topics.nrows());
986
987            for (doc_idx, doc_topic_dist) in doc_topics.rows().into_iter().enumerate() {
988                let mut max_prob = 0.0;
989                let mut best_topic = 0;
990
991                for (topic_idx, &prob) in doc_topic_dist.iter().enumerate() {
992                    if prob > max_prob {
993                        max_prob = prob;
994                        best_topic = topic_idx;
995                    }
996                }
997
998                labels[doc_idx] = best_topic;
999            }
1000
1001            Ok(labels)
1002        } else {
1003            Err(ClusteringError::InvalidInput(
1004                "Model has not been fitted yet".to_string(),
1005            ))
1006        }
1007    }
1008
1009    /// Get topics (term distributions)
1010    pub fn topics(&self) -> Option<&Array2<f64>> {
1011        self.topics.as_ref()
1012    }
1013
1014    /// Get document-topic distributions
1015    pub fn document_topic_distributions(&self) -> Option<&Array2<f64>> {
1016        self.document_topic_distributions.as_ref()
1017    }
1018}
1019
// Convenience functions for text clustering
1021
1022/// Perform semantic K-means clustering on text data
1023#[allow(dead_code)]
1024pub fn semantic_kmeans(
1025    text_repr: &TextRepresentation,
1026    n_clusters: usize,
1027    similarity_metric: SemanticSimilarity,
1028) -> Result<(Array2<f64>, Array1<usize>)> {
1029    let config = SemanticClusteringConfig {
1030        n_clusters: Some(n_clusters),
1031        similarity_metric,
1032        ..Default::default()
1033    };
1034
1035    let mut clusterer = SemanticKMeans::new(config);
1036    clusterer.fit(text_repr)?;
1037
1038    let centers = clusterer
1039        .cluster_centers()
1040        .ok_or_else(|| {
1041            ClusteringError::ComputationError("Failed to get cluster centers".to_string())
1042        })?
1043        .clone();
1044
1045    let labels = clusterer.predict(text_repr)?;
1046
1047    Ok((centers, labels))
1048}
1049
1050/// Perform hierarchical clustering on text data
1051#[allow(dead_code)]
1052pub fn semantic_hierarchical(
1053    text_repr: &TextRepresentation,
1054    similarity_metric: SemanticSimilarity,
1055) -> Result<Array2<f64>> {
1056    let config = SemanticClusteringConfig {
1057        similarity_metric,
1058        ..Default::default()
1059    };
1060
1061    let mut clusterer = SemanticHierarchical::new(config);
1062    clusterer.fit(text_repr)?;
1063
1064    clusterer
1065        .linkage_matrix()
1066        .ok_or_else(|| {
1067            ClusteringError::ComputationError("Failed to get linkage matrix".to_string())
1068        })
1069        .cloned()
1070}
1071
1072/// Perform topic-based clustering on text data
1073#[allow(dead_code)]
1074pub fn topic_clustering(
1075    text_repr: &TextRepresentation,
1076    n_topics: usize,
1077) -> Result<(Array2<f64>, Array2<f64>, Array1<usize>)> {
1078    let config = SemanticClusteringConfig::default();
1079    let mut clusterer = TopicBasedClustering::new(config, n_topics);
1080
1081    clusterer.fit(text_repr)?;
1082
1083    let _topics = clusterer
1084        .topics()
1085        .ok_or_else(|| ClusteringError::ComputationError("Failed to get topics".to_string()))?
1086        .clone();
1087
1088    let doc_topics = clusterer
1089        .document_topic_distributions()
1090        .ok_or_else(|| {
1091            ClusteringError::ComputationError(
1092                "Failed to get document-topic distributions".to_string(),
1093            )
1094        })?
1095        .clone();
1096
1097    let labels = clusterer.predict(text_repr)?;
1098
1099    Ok((_topics, doc_topics, labels))
1100}
1101
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_semantic_kmeans_basic() {
        // Create sample TF-IDF vectors
        let vectors = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 1.0, 0.0, 0.1, 0.9],
        )
        .unwrap();

        let text_repr = TextRepresentation::TfIdf {
            vectors,
            vocabulary: HashMap::new(),
        };

        let result = semantic_kmeans(&text_repr, 2, SemanticSimilarity::Cosine);
        assert!(result.is_ok());

        let (centers, labels) = result.unwrap();
        assert_eq!(centers.nrows(), 2);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = scirs2_core::ndarray::Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let b = scirs2_core::ndarray::Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let config = SemanticClusteringConfig::default();
        let clusterer = SemanticKMeans::new(config);

        // Test cosine similarity
        let cosine_sim = clusterer.cosine_similarity(a.view(), b.view()).unwrap();
        assert_eq!(cosine_sim, 0.0); // Orthogonal vectors

        // Test Manhattan distance
        let manhattan = clusterer.manhattan_distance(a.view(), b.view()).unwrap();
        assert_eq!(manhattan, 2.0);
    }

    #[test]
    fn test_text_preprocessing() {
        let config = SemanticClusteringConfig::default();
        let clusterer = SemanticKMeans::new(config);

        let matrix = Array2::from_shape_vec((2, 3), vec![3.0, 4.0, 0.0, 1.0, 2.0, 2.0]).unwrap();

        let normalized = clusterer.normalize_vectors(matrix).unwrap();

        // Check that vectors are normalized
        for row in normalized.rows() {
            let norm = (row.iter().map(|&x| x * x).sum::<f64>()).sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_topic_clustering_basic() {
        // Create sample document-term matrix
        let matrix = Array2::from_shape_vec(
            (3, 4),
            vec![2.0, 0.0, 1.0, 0.0, 0.0, 3.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0],
        )
        .unwrap();

        let text_repr = TextRepresentation::DocumentTerm {
            matrix,
            vocabulary: vec![
                "word1".to_string(),
                "word2".to_string(),
                "word3".to_string(),
                "word4".to_string(),
            ],
        };

        let result = topic_clustering(&text_repr, 2);
        assert!(result.is_ok());

        let (topics, doc_topics, labels) = result.unwrap();
        assert_eq!(topics.nrows(), 2);
        assert_eq!(doc_topics.nrows(), 3);
        assert_eq!(labels.len(), 3);
        // Every label must refer to one of the requested topics.
        assert!(labels.iter().all(|&label| label < 2));
    }

    #[test]
    fn test_topic_predict_requires_fit() {
        let clusterer = TopicBasedClustering::new(SemanticClusteringConfig::default(), 2);
        let text_repr = TextRepresentation::DocumentTerm {
            matrix: Array2::from_elem((1, 2), 1.0),
            vocabulary: vec!["a".to_string(), "b".to_string()],
        };

        assert!(matches!(
            clusterer.predict(&text_repr),
            Err(ClusteringError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_semantic_similarity_enum() {
        assert_eq!(SemanticSimilarity::Cosine, SemanticSimilarity::Cosine);
        assert_ne!(SemanticSimilarity::Cosine, SemanticSimilarity::Euclidean);
    }
}