use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::entities::ClusterId;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ClusteringMethod {
    /// Density-based hierarchical clustering; does not require a fixed number of clusters.
    HDBSCAN,

    /// K-means clustering with a fixed number of clusters `k`.
    KMeans {
        k: usize,
    },

    /// Spectral clustering with a fixed number of clusters.
    Spectral {
        n_clusters: usize,
    },

    /// Agglomerative (hierarchical) clustering with a fixed number of clusters and a linkage method.
    Agglomerative {
        n_clusters: usize,
        linkage: LinkageMethod,
    },
}

impl Default for ClusteringMethod {
    fn default() -> Self {
        Self::HDBSCAN
    }
}

impl std::fmt::Display for ClusteringMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusteringMethod::HDBSCAN => write!(f, "HDBSCAN"),
            ClusteringMethod::KMeans { k } => write!(f, "K-Means (k={})", k),
            ClusteringMethod::Spectral { n_clusters } => {
                write!(f, "Spectral (n={})", n_clusters)
            }
            ClusteringMethod::Agglomerative { n_clusters, linkage } => {
                write!(f, "Agglomerative (n={}, {:?})", n_clusters, linkage)
            }
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkageMethod {
    /// Ward's minimum-variance linkage.
    Ward,
    /// Complete (maximum) linkage.
    Complete,
    /// Average linkage.
    Average,
    /// Single (minimum) linkage.
    Single,
}

impl Default for LinkageMethod {
    fn default() -> Self {
        Self::Ward
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    Euclidean,
    /// Cosine distance.
    Cosine,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Poincaré (hyperbolic) distance.
    Poincare,
}

impl Default for DistanceMetric {
    fn default() -> Self {
        Self::Cosine
    }
}

impl std::fmt::Display for DistanceMetric {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DistanceMetric::Euclidean => write!(f, "Euclidean"),
            DistanceMetric::Cosine => write!(f, "Cosine"),
            DistanceMetric::Manhattan => write!(f, "Manhattan"),
            DistanceMetric::Poincare => write!(f, "Poincare"),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringParameters {
    /// Minimum number of points required to form a cluster.
    pub min_cluster_size: usize,

    /// Minimum number of neighbouring samples for a point to count as a core point.
    pub min_samples: usize,

    /// Optional distance threshold; interpretation depends on the clustering method.
    pub epsilon: Option<f32>,

    /// Distance metric used to compare embeddings.
    pub metric: DistanceMetric,

    /// Optional upper bound on the number of clusters.
    pub max_clusters: Option<usize>,

    /// Whether a result containing a single cluster is accepted.
    pub allow_single_cluster: bool,
}

impl Default for ClusteringParameters {
    fn default() -> Self {
        Self {
            min_cluster_size: 5,
            min_samples: 3,
            epsilon: None,
            metric: DistanceMetric::Cosine,
            max_clusters: None,
            allow_single_cluster: false,
        }
    }
}

impl ClusteringParameters {
    /// Parameters with explicit HDBSCAN-style cluster-size and sample thresholds.
    #[must_use]
    pub fn hdbscan(min_cluster_size: usize, min_samples: usize) -> Self {
        Self {
            min_cluster_size,
            min_samples,
            ..Default::default()
        }
    }

    /// Parameters suited to K-means, where density thresholds do not apply.
    #[must_use]
    pub fn kmeans() -> Self {
        Self {
            min_cluster_size: 1,
            min_samples: 1,
            allow_single_cluster: true,
            ..Default::default()
        }
    }

    /// Builder-style setter for the distance metric.
    #[must_use]
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Builder-style setter for the epsilon threshold.
    #[must_use]
    pub fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = Some(epsilon);
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
    /// Clustering algorithm to run.
    pub method: ClusteringMethod,

    /// Parameters passed to the algorithm.
    pub parameters: ClusteringParameters,

    /// Whether to compute prototype (representative) members for each cluster.
    pub compute_prototypes: bool,

    /// Number of prototypes to keep per cluster.
    pub prototypes_per_cluster: usize,

    /// Whether to compute the silhouette score of the result.
    pub compute_silhouette: bool,

    /// Optional random seed for reproducible runs.
    pub random_seed: Option<u64>,
}

impl Default for ClusteringConfig {
    fn default() -> Self {
        Self {
            method: ClusteringMethod::HDBSCAN,
            parameters: ClusteringParameters::default(),
            compute_prototypes: true,
            prototypes_per_cluster: 3,
            compute_silhouette: true,
            random_seed: None,
        }
    }
}

impl ClusteringConfig {
    /// Configuration for HDBSCAN with the given cluster-size and sample thresholds.
    #[must_use]
    pub fn hdbscan(min_cluster_size: usize, min_samples: usize) -> Self {
        Self {
            method: ClusteringMethod::HDBSCAN,
            parameters: ClusteringParameters::hdbscan(min_cluster_size, min_samples),
            ..Default::default()
        }
    }

    /// Configuration for K-means with `k` clusters.
    #[must_use]
    pub fn kmeans(k: usize) -> Self {
        Self {
            method: ClusteringMethod::KMeans { k },
            parameters: ClusteringParameters::kmeans(),
            ..Default::default()
        }
    }

    /// Builder-style setter for the random seed.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MotifConfig {
    /// Minimum motif length.
    pub min_length: usize,

    /// Maximum motif length.
    pub max_length: usize,

    /// Minimum number of occurrences for a pattern to count as a motif.
    pub min_occurrences: usize,

    /// Minimum confidence for a motif to be reported.
    pub min_confidence: f32,

    /// Whether occurrences of a motif may overlap.
    pub allow_overlap: bool,

    /// Maximum gap allowed between consecutive elements of a motif occurrence.
    pub max_gap: usize,
}

impl Default for MotifConfig {
    fn default() -> Self {
        Self {
            min_length: 2,
            max_length: 10,
            min_occurrences: 3,
            min_confidence: 0.5,
            allow_overlap: false,
            max_gap: 0,
        }
    }
}

impl MotifConfig {
    /// A strict preset: longer minimum length, more required occurrences, higher confidence.
    #[must_use]
    pub fn strict() -> Self {
        Self {
            min_length: 3,
            max_length: 8,
            min_occurrences: 5,
            min_confidence: 0.7,
            allow_overlap: false,
            max_gap: 0,
        }
    }

    /// A relaxed preset: wider length range, fewer required occurrences, overlaps and small gaps allowed.
    #[must_use]
    pub fn relaxed() -> Self {
        Self {
            min_length: 2,
            max_length: 15,
            min_occurrences: 2,
            min_confidence: 0.3,
            allow_overlap: true,
            max_gap: 2,
        }
    }

    /// Builder-style setter for the allowed motif length range.
    #[must_use]
    pub fn with_length_range(mut self, min: usize, max: usize) -> Self {
        self.min_length = min;
        self.max_length = max;
        self
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceMetrics {
    /// Entropy of the observed sequence (higher means less predictable).
    pub entropy: f32,

    /// Entropy normalized to the `[0, 1]` range.
    pub normalized_entropy: f32,

    /// Stereotypy score; 1.0 indicates a fully predictable sequence.
    pub stereotypy: f32,

    /// Number of distinct clusters visited.
    pub unique_clusters: usize,

    /// Number of distinct transitions observed.
    pub unique_transitions: usize,

    /// Total number of transitions observed.
    pub total_transitions: usize,

    /// Most frequent transition as `(from, to, probability)`, if any.
    pub dominant_transition: Option<(ClusterId, ClusterId, f32)>,

    /// Rate of immediate repetition in the sequence.
    pub repetition_rate: f32,
}

impl Default for SequenceMetrics {
    fn default() -> Self {
        Self {
            entropy: 0.0,
            normalized_entropy: 0.0,
            stereotypy: 1.0,
            unique_clusters: 0,
            unique_transitions: 0,
            total_transitions: 0,
            dominant_transition: None,
            repetition_rate: 0.0,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionMatrix {
    /// Cluster identifiers, in row/column order.
    pub cluster_ids: Vec<ClusterId>,

    /// Row-stochastic transition probabilities: `probabilities[i][j]` is the probability of moving from cluster `i` to cluster `j`.
    pub probabilities: Vec<Vec<f32>>,

    /// Raw transition counts: `observations[i][j]` is the number of observed `i -> j` transitions.
    pub observations: Vec<Vec<u32>>,

    /// Lookup from cluster id to matrix index; not serialized, rebuild with `rebuild_index_map` after deserialization.
    #[serde(skip)]
    index_map: HashMap<ClusterId, usize>,
}

impl TransitionMatrix {
    /// Creates an empty matrix over the given clusters.
    #[must_use]
    pub fn new(cluster_ids: Vec<ClusterId>) -> Self {
        let n = cluster_ids.len();
        let index_map: HashMap<ClusterId, usize> = cluster_ids
            .iter()
            .enumerate()
            .map(|(i, id)| (*id, i))
            .collect();

        Self {
            cluster_ids,
            probabilities: vec![vec![0.0; n]; n],
            observations: vec![vec![0; n]; n],
            index_map,
        }
    }

    /// Number of clusters (rows/columns) in the matrix.
    #[must_use]
    pub fn size(&self) -> usize {
        self.cluster_ids.len()
    }

    /// Matrix index of a cluster, if it is part of the matrix.
    #[must_use]
    pub fn index_of(&self, cluster_id: &ClusterId) -> Option<usize> {
        self.index_map.get(cluster_id).copied()
    }

    /// Records one observed transition; unknown cluster ids are ignored.
    pub fn record_transition(&mut self, from: &ClusterId, to: &ClusterId) {
        if let (Some(i), Some(j)) = (self.index_of(from), self.index_of(to)) {
            self.observations[i][j] += 1;
        }
    }

    /// Recomputes row probabilities from the observation counts; rows with no observations stay at zero.
    pub fn compute_probabilities(&mut self) {
        for i in 0..self.size() {
            let row_sum: u32 = self.observations[i].iter().sum();
            if row_sum > 0 {
                for j in 0..self.size() {
                    self.probabilities[i][j] = self.observations[i][j] as f32 / row_sum as f32;
                }
            }
        }
    }

    /// Transition probability between two clusters, if both are in the matrix.
    #[must_use]
    pub fn probability(&self, from: &ClusterId, to: &ClusterId) -> Option<f32> {
        match (self.index_of(from), self.index_of(to)) {
            (Some(i), Some(j)) => Some(self.probabilities[i][j]),
            _ => None,
        }
    }

    /// Raw observation count between two clusters, if both are in the matrix.
    #[must_use]
    pub fn observation_count(&self, from: &ClusterId, to: &ClusterId) -> Option<u32> {
        match (self.index_of(from), self.index_of(to)) {
            (Some(i), Some(j)) => Some(self.observations[i][j]),
            _ => None,
        }
    }

    /// All transitions with non-zero probability as `(from, to, probability)` triples.
    #[must_use]
    pub fn non_zero_transitions(&self) -> Vec<(ClusterId, ClusterId, f32)> {
        let mut transitions = Vec::new();
        for (i, from) in self.cluster_ids.iter().enumerate() {
            for (j, to) in self.cluster_ids.iter().enumerate() {
                let prob = self.probabilities[i][j];
                if prob > 0.0 {
                    transitions.push((*from, *to, prob));
                }
            }
        }
        transitions
    }

    /// Estimates the stationary distribution by power iteration (at most 1000
    /// iterations, convergence tolerance 1e-8). Returns `None` for an empty matrix;
    /// otherwise returns the last iterate even if convergence was not reached.
    #[must_use]
    pub fn stationary_distribution(&self) -> Option<Vec<f32>> {
        let n = self.size();
        if n == 0 {
            return None;
        }

        let mut dist = vec![1.0 / n as f32; n];
        let max_iterations = 1000;
        let tolerance = 1e-8;

        for _ in 0..max_iterations {
            let mut new_dist = vec![0.0; n];

            // new_dist = dist * P (left multiplication by the row-stochastic matrix).
            for j in 0..n {
                for i in 0..n {
                    new_dist[j] += dist[i] * self.probabilities[i][j];
                }
            }

            let diff: f32 = dist
                .iter()
                .zip(new_dist.iter())
                .map(|(a, b)| (a - b).abs())
                .sum();

            dist = new_dist;

            if diff < tolerance {
                return Some(dist);
            }
        }

        Some(dist)
    }

    /// Rebuilds the id-to-index lookup; call this after deserialization, since
    /// `index_map` is skipped by serde.
    pub fn rebuild_index_map(&mut self) {
        self.index_map = self
            .cluster_ids
            .iter()
            .enumerate()
            .map(|(i, id)| (*id, i))
            .collect();
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Clusters produced by the run.
    pub clusters: Vec<super::entities::Cluster>,

    /// Embeddings that were not assigned to any cluster.
    pub noise: Vec<super::entities::EmbeddingId>,

    /// Silhouette score of the clustering, if computed.
    pub silhouette_score: Option<f32>,

    /// V-measure of the clustering, if computed.
    pub v_measure: Option<f32>,

    /// Prototype (representative) members across clusters.
    pub prototypes: Vec<super::entities::Prototype>,

    /// Parameters used for the run.
    pub parameters: ClusteringParameters,

    /// Method used for the run.
    pub method: ClusteringMethod,
}

impl ClusteringResult {
    /// Number of clusters in the result.
    #[must_use]
    pub fn cluster_count(&self) -> usize {
        self.clusters.len()
    }

    /// Fraction of points classified as noise; 0.0 when the result is empty.
    #[must_use]
    pub fn noise_rate(&self) -> f32 {
        let total = self
            .clusters
            .iter()
            .map(|c| c.member_count())
            .sum::<usize>()
            + self.noise.len();
        if total == 0 {
            0.0
        } else {
            self.noise.len() as f32 / total as f32
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clustering_config_creation() {
        let config = ClusteringConfig::hdbscan(10, 5);
        assert!(matches!(config.method, ClusteringMethod::HDBSCAN));
        assert_eq!(config.parameters.min_cluster_size, 10);
        assert_eq!(config.parameters.min_samples, 5);
    }

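    // Illustrative sketch (not part of the original suite): a minimal check of
    // `ClusteringConfig::kmeans` and the `with_seed` builder, using only types
    // defined in this module.
    #[test]
    fn test_kmeans_config_with_seed() {
        let config = ClusteringConfig::kmeans(8).with_seed(42);
        assert!(matches!(config.method, ClusteringMethod::KMeans { k: 8 }));
        assert!(config.parameters.allow_single_cluster);
        assert_eq!(config.random_seed, Some(42));
    }
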
    #[test]
    fn test_transition_matrix() {
        let c1 = ClusterId::new();
        let c2 = ClusterId::new();
        let c3 = ClusterId::new();

        let mut matrix = TransitionMatrix::new(vec![c1, c2, c3]);

        // c1 -> c2 twice, c1 -> c3 once, c2 -> c1 once.
        matrix.record_transition(&c1, &c2);
        matrix.record_transition(&c1, &c2);
        matrix.record_transition(&c1, &c3);
        matrix.record_transition(&c2, &c1);

        matrix.compute_probabilities();

        assert!((matrix.probability(&c1, &c2).unwrap() - 2.0 / 3.0).abs() < 0.001);
        assert!((matrix.probability(&c1, &c3).unwrap() - 1.0 / 3.0).abs() < 0.001);
        assert!((matrix.probability(&c2, &c1).unwrap() - 1.0).abs() < 0.001);
    }

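    // Illustrative sketch (not part of the original suite): checks that
    // `stationary_distribution` returns a probability vector summing to ~1 for a
    // small, fully specified chain.
    #[test]
    fn test_stationary_distribution_sums_to_one() {
        let c1 = ClusterId::new();
        let c2 = ClusterId::new();

        let mut matrix = TransitionMatrix::new(vec![c1, c2]);
        matrix.record_transition(&c1, &c2);
        matrix.record_transition(&c2, &c1);
        matrix.record_transition(&c2, &c2);
        matrix.compute_probabilities();

        let dist = matrix.stationary_distribution().expect("matrix is non-empty");
        assert_eq!(dist.len(), 2);
        let total: f32 = dist.iter().sum();
        assert!((total - 1.0).abs() < 1e-3);
        assert!(dist.iter().all(|p| *p >= 0.0));
    }
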
    #[test]
    fn test_motif_config() {
        let config = MotifConfig::strict();
        assert_eq!(config.min_length, 3);
        assert_eq!(config.min_occurrences, 5);
        assert!(!config.allow_overlap);

        let relaxed = MotifConfig::relaxed();
        assert!(relaxed.allow_overlap);
        assert_eq!(relaxed.max_gap, 2);
    }

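    // Illustrative sketch (not part of the original suite): exercises the
    // `ClusteringParameters` builder helpers and the `ClusteringResult` helpers
    // on an empty result, using only types defined in this module.
    #[test]
    fn test_clustering_parameters_builders() {
        let params = ClusteringParameters::hdbscan(8, 4)
            .with_metric(DistanceMetric::Euclidean)
            .with_epsilon(0.25);
        assert_eq!(params.min_cluster_size, 8);
        assert_eq!(params.min_samples, 4);
        assert_eq!(params.metric, DistanceMetric::Euclidean);
        assert_eq!(params.epsilon, Some(0.25));
    }

    #[test]
    fn test_empty_clustering_result() {
        let result = ClusteringResult {
            clusters: Vec::new(),
            noise: Vec::new(),
            silhouette_score: None,
            v_measure: None,
            prototypes: Vec::new(),
            parameters: ClusteringParameters::default(),
            method: ClusteringMethod::default(),
        };
        assert_eq!(result.cluster_count(), 0);
        assert!((result.noise_rate() - 0.0).abs() < f32::EPSILON);
    }
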
    #[test]
    fn test_distance_metric_display() {
        assert_eq!(format!("{}", DistanceMetric::Cosine), "Cosine");
        assert_eq!(format!("{}", DistanceMetric::Euclidean), "Euclidean");
    }
}
616}