torsh_cluster/initialization/
random_partition.rs1use super::InitializationStrategy;
4use crate::error::{ClusterError, ClusterResult};
5use scirs2_core::random::Random;
6use torsh_tensor::Tensor;
7
/// Random-partition centroid initialization.
///
/// A stateless marker type: its `InitializationStrategy` implementation assigns
/// every sample to a randomly drawn cluster and uses the per-cluster means as the
/// starting centroids. Because it carries no data it is trivially `Copy`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RandomPartition;
11
12impl InitializationStrategy for RandomPartition {
13 fn initialize(
14 &self,
15 data: &Tensor,
16 n_clusters: usize,
17 seed: Option<u64>,
18 ) -> ClusterResult<Tensor> {
19 let n_samples = data.shape().dims()[0];
20 let n_features = data.shape().dims()[1];
21
22 if n_clusters > n_samples {
23 return Err(ClusterError::InvalidClusters(n_clusters));
24 }
25
26 let mut rng = Random::seed(seed.unwrap_or_else(|| {
27 use std::time::{SystemTime, UNIX_EPOCH};
28 SystemTime::now()
29 .duration_since(UNIX_EPOCH)
30 .expect("system time should be after UNIX_EPOCH")
31 .as_secs()
32 }));
33
34 let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
35
36 let mut cluster_assignments = Vec::new();
38 for _ in 0..n_samples {
39 cluster_assignments.push(rng.gen_range(0..n_clusters));
40 }
41
42 let mut centroids_data = vec![0.0; n_clusters * n_features];
44 let mut cluster_counts = vec![0; n_clusters];
45
46 for i in 0..n_samples {
47 let cluster = cluster_assignments[i];
48 cluster_counts[cluster] += 1;
49 for j in 0..n_features {
50 centroids_data[cluster * n_features + j] += data_vec[i * n_features + j];
51 }
52 }
53
54 for k in 0..n_clusters {
56 if cluster_counts[k] > 0 {
57 for j in 0..n_features {
58 centroids_data[k * n_features + j] /= cluster_counts[k] as f32;
59 }
60 }
61 }
62
63 Tensor::from_vec(centroids_data, &[n_clusters, n_features])
64 .map_err(ClusterError::TensorError)
65 }
66
67 fn name(&self) -> &str {
68 "Random Partition"
69 }
70}