torsh_cluster/initialization/
forgy.rs1use super::InitializationStrategy;
4use crate::error::{ClusterError, ClusterResult};
5use scirs2_core::random::Random;
6use torsh_tensor::Tensor;
7
/// Forgy initialization strategy.
///
/// A stateless marker type: it carries no configuration, so every instance
/// behaves the same and `Forgy::default()` is equivalent to `Forgy`.
#[derive(Debug, Default)]
pub struct Forgy;
11
12impl InitializationStrategy for Forgy {
13 fn initialize(
14 &self,
15 data: &Tensor,
16 n_clusters: usize,
17 seed: Option<u64>,
18 ) -> ClusterResult<Tensor> {
19 let n_samples = data.shape().dims()[0];
20 let n_features = data.shape().dims()[1];
21
22 if n_clusters > n_samples {
23 return Err(ClusterError::InvalidClusters(n_clusters));
24 }
25
26 let mut rng = Random::seed(seed.unwrap_or_else(|| {
27 use std::time::{SystemTime, UNIX_EPOCH};
28 SystemTime::now()
29 .duration_since(UNIX_EPOCH)
30 .expect("system time should be after UNIX_EPOCH")
31 .as_secs()
32 }));
33
34 let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
35 let mut selected = std::collections::HashSet::new();
36 let mut centroids_data = Vec::with_capacity(n_clusters * n_features);
37
38 for _ in 0..n_clusters {
39 let mut idx = rng.gen_range(0..n_samples);
40 while selected.contains(&idx) {
41 idx = rng.gen_range(0..n_samples);
42 }
43 selected.insert(idx);
44
45 for j in 0..n_features {
46 centroids_data.push(data_vec[idx * n_features + j]);
47 }
48 }
49
50 Tensor::from_vec(centroids_data, &[n_clusters, n_features])
51 .map_err(ClusterError::TensorError)
52 }
53
54 fn name(&self) -> &str {
55 "Forgy"
56 }
57}