1use crate::error::{StatsError, StatsResult};
4use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
5use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
6use scirs2_core::random::Rng;
7use scirs2_core::simd_ops::SimdUnifiedOps;
8use std::marker::PhantomData;
9
/// Variational Bayesian Gaussian mixture model.
///
/// Holds the fitting configuration and, after `fit`, the learned variational
/// posterior parameters. Components whose posterior weight stays negligible
/// are reported via `effective_components` rather than removed.
pub struct VariationalGMM<F> {
    /// Upper bound on the number of mixture components considered.
    pub max_components: usize,
    /// Hyper-parameters and stopping criteria for the variational fit.
    pub config: VariationalGMMConfig,
    /// Fitted posterior parameters; `None` until `fit` has been called.
    pub parameters: Option<VariationalGMMParameters<F>>,
    /// Evidence lower bound recorded after each iteration of `fit`.
    pub lower_bound_history: Vec<F>,
    // NOTE(review): `F` already appears in `parameters` and
    // `lower_bound_history`, so this marker looks redundant — confirm before
    // removing, since `new` initializes it.
    _phantom: PhantomData<F>,
}
22
/// Configuration for [`VariationalGMM`].
#[derive(Debug, Clone)]
pub struct VariationalGMMConfig {
    /// Maximum number of variational EM iterations.
    pub max_iter: usize,
    /// Convergence threshold on the absolute change of the lower bound.
    pub tolerance: f64,
    /// Dirichlet concentration prior on the mixture weights.
    pub alpha: f64,
    /// Wishart degrees-of-freedom prior (`fit` adds the feature count to it).
    pub nu: f64,
    /// Optional prior mean vector.
    // NOTE(review): not read by the visible fitting code — confirm intent.
    pub mean_prior: Option<Vec<f64>>,
    /// Optional prior precision matrix.
    // NOTE(review): not read by the visible fitting code — confirm intent.
    pub precision_prior: Option<Vec<Vec<f64>>>,
    /// Automatic relevance determination flag.
    // NOTE(review): not read by the visible fitting code — confirm intent.
    pub ard: bool,
    /// RNG seed for reproducible mean initialization; `None` uses the
    /// thread RNG to derive a seed.
    pub seed: Option<u64>,
}
43
44impl Default for VariationalGMMConfig {
45 fn default() -> Self {
46 Self {
47 max_iter: 100,
48 tolerance: 1e-6,
49 alpha: 1.0,
50 nu: 1.0,
51 mean_prior: None,
52 precision_prior: None,
53 ard: true,
54 seed: None,
55 }
56 }
57}
58
/// Variational posterior parameters produced by [`VariationalGMM`]'s fit.
#[derive(Debug, Clone)]
pub struct VariationalGMMParameters<F> {
    /// Dirichlet concentration per component (prior alpha plus soft counts).
    pub weight_concentration: Array1<F>,
    /// Precision scaling of each component mean's posterior.
    pub mean_precision: Array1<F>,
    /// Component means, shape `(max_components, n_features)`.
    pub means: Array2<F>,
    /// Wishart degrees of freedom per component.
    pub degrees_of_freedom: Array1<F>,
    /// Wishart scale matrices, shape `(max_components, n_features, n_features)`.
    pub scale_matrices: Array3<F>,
    /// Final value of the variational lower bound.
    pub lower_bound: F,
    /// Number of components whose normalized weight exceeds the threshold.
    pub effective_components: usize,
    /// Number of iterations actually performed.
    pub n_iter: usize,
    /// Whether the lower bound change fell below the tolerance.
    pub converged: bool,
}
81
/// Summary returned by [`VariationalGMM::fit`].
#[derive(Debug, Clone)]
pub struct VariationalGMMResult<F> {
    /// Final value of the variational lower bound.
    pub lower_bound: F,
    /// Number of components whose normalized weight exceeds the threshold.
    pub effective_components: usize,
    /// Posterior component responsibilities, shape `(n_samples, max_components)`.
    pub responsibilities: Array2<F>,
    /// Normalized mixture weights derived from the weight concentrations.
    pub weights: Array1<F>,
}
94
95impl<F> VariationalGMM<F>
96where
97 F: Float
98 + FromPrimitive
99 + SimdUnifiedOps
100 + Send
101 + Sync
102 + std::fmt::Debug
103 + std::fmt::Display
104 + std::iter::Sum<F>,
105{
106 pub fn new(max_components: usize, config: VariationalGMMConfig) -> Self {
108 Self {
109 max_components,
110 config,
111 parameters: None,
112 lower_bound_history: Vec::new(),
113 _phantom: PhantomData,
114 }
115 }
116
117 pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<VariationalGMMResult<F>> {
119 let (_n_samples, n_features) = data.dim();
120
121 let alpha_f: F = F::from(self.config.alpha)
122 .ok_or_else(|| StatsError::ComputationError("alpha conversion failed".into()))?;
123 let nu_f: F = F::from(self.config.nu)
124 .ok_or_else(|| StatsError::ComputationError("nu conversion failed".into()))?;
125 let n_feat_f: F = F::from(n_features)
126 .ok_or_else(|| StatsError::ComputationError("n_features conversion failed".into()))?;
127
128 let mut weight_concentration = Array1::from_elem(self.max_components, alpha_f);
129 let mean_precision_val = F::one();
130 let mut mean_precision = Array1::from_elem(self.max_components, mean_precision_val);
131 let mut means = self.initialize_means(data)?;
132 let mut degrees_of_freedom = Array1::from_elem(self.max_components, nu_f + n_feat_f);
133 let mut scale_matrices = Array3::zeros((self.max_components, n_features, n_features));
134 for k in 0..self.max_components {
135 for i in 0..n_features {
136 scale_matrices[[k, i, i]] = F::one();
137 }
138 }
139
140 let mut lower_bound = F::neg_infinity();
141 let mut converged = false;
142 let tol: F = F::from(self.config.tolerance)
143 .ok_or_else(|| StatsError::ComputationError("tolerance conversion failed".into()))?;
144
145 for iteration in 0..self.config.max_iter {
146 let responsibilities = self.compute_responsibilities(
147 data,
148 &means,
149 &scale_matrices,
150 °rees_of_freedom,
151 &weight_concentration,
152 )?;
153
154 let (new_wc, new_mp, new_means, new_dof, new_sm) =
155 self.update_parameters(data, &responsibilities)?;
156
157 let new_lb =
158 self.compute_lower_bound(data, &responsibilities, &new_wc, &new_means, &new_sm)?;
159
160 if iteration > 0 && (new_lb - lower_bound).abs() < tol {
161 converged = true;
162 }
163
164 weight_concentration = new_wc;
165 mean_precision = new_mp;
166 means = new_means;
167 degrees_of_freedom = new_dof;
168 scale_matrices = new_sm;
169 lower_bound = new_lb;
170 self.lower_bound_history.push(lower_bound);
171
172 if converged {
173 break;
174 }
175 }
176
177 let effective_components = self.compute_effective_components(&weight_concentration);
178 let responsibilities = self.compute_responsibilities(
179 data,
180 &means,
181 &scale_matrices,
182 °rees_of_freedom,
183 &weight_concentration,
184 )?;
185 let weights = self.compute_weights(&weight_concentration);
186
187 let parameters = VariationalGMMParameters {
188 weight_concentration,
189 mean_precision,
190 means,
191 degrees_of_freedom,
192 scale_matrices,
193 lower_bound,
194 effective_components,
195 n_iter: self.lower_bound_history.len(),
196 converged,
197 };
198 self.parameters = Some(parameters);
199
200 Ok(VariationalGMMResult {
201 lower_bound,
202 effective_components,
203 responsibilities,
204 weights,
205 })
206 }
207
208 fn initialize_means(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
209 let (n_samples, n_features) = data.dim();
210 let mut means = Array2::zeros((self.max_components, n_features));
211 use scirs2_core::random::Random;
212 let mut init_rng = scirs2_core::random::thread_rng();
213 let mut rng = match self.config.seed {
214 Some(seed) => Random::seed(seed),
215 None => Random::seed(init_rng.random()),
216 };
217 for i in 0..self.max_components {
218 let idx = rng.random_range(0..n_samples);
219 means.row_mut(i).assign(&data.row(idx));
220 }
221 Ok(means)
222 }
223
224 fn compute_responsibilities(
225 &self,
226 data: &ArrayView2<F>,
227 means: &Array2<F>,
228 scale_matrices: &Array3<F>,
229 degrees_of_freedom: &Array1<F>,
230 weight_concentration: &Array1<F>,
231 ) -> StatsResult<Array2<F>> {
232 let n_samples = data.shape()[0];
233 let mut responsibilities = Array2::zeros((n_samples, self.max_components));
234
235 for i in 0..n_samples {
236 let mut log_probs = Array1::zeros(self.max_components);
237 for k in 0..self.max_components {
238 let log_weight = weight_concentration[k].ln();
239 let log_ll = self.compute_log_likelihood_component(
240 &data.row(i),
241 &means.row(k),
242 &scale_matrices.slice(s![k, .., ..]),
243 degrees_of_freedom[k],
244 )?;
245 log_probs[k] = log_weight + log_ll;
246 }
247 let log_sum = self.log_sum_exp(&log_probs);
248 for k in 0..self.max_components {
249 responsibilities[[i, k]] = (log_probs[k] - log_sum).exp();
250 }
251 }
252 Ok(responsibilities)
253 }
254
255 fn update_parameters(
256 &self,
257 data: &ArrayView2<F>,
258 responsibilities: &Array2<F>,
259 ) -> StatsResult<(Array1<F>, Array1<F>, Array2<F>, Array1<F>, Array3<F>)> {
260 let (n_samples, n_features) = data.dim();
261
262 let alpha_f: F = F::from(self.config.alpha)
263 .ok_or_else(|| StatsError::ComputationError("alpha conversion".into()))?;
264 let nu_f: F = F::from(self.config.nu)
265 .ok_or_else(|| StatsError::ComputationError("nu conversion".into()))?;
266 let n_feat_f: F = F::from(n_features)
267 .ok_or_else(|| StatsError::ComputationError("n_features conversion".into()))?;
268 let small: F = F::from(0.1)
269 .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
270
271 let mut weight_concentration = Array1::from_elem(self.max_components, alpha_f);
272 let mean_precision = Array1::ones(self.max_components);
273 let mut means = Array2::zeros((self.max_components, n_features));
274 let mut degrees_of_freedom = Array1::from_elem(self.max_components, nu_f + n_feat_f);
275 let mut scale_matrices = Array3::zeros((self.max_components, n_features, n_features));
276
277 for k in 0..self.max_components {
278 let nk: F = responsibilities.column(k).sum();
279 weight_concentration[k] = weight_concentration[k] + nk;
280
281 if nk > F::zero() {
282 for j in 0..n_features {
283 let mut weighted_sum = F::zero();
284 for i in 0..n_samples {
285 weighted_sum = weighted_sum + responsibilities[[i, k]] * data[[i, j]];
286 }
287 means[[k, j]] = weighted_sum / nk;
288 }
289 degrees_of_freedom[k] = nu_f + nk;
290 for i in 0..n_features {
291 scale_matrices[[k, i, i]] = F::one() + small * nk;
292 }
293 }
294 }
295
296 Ok((
297 weight_concentration,
298 mean_precision,
299 means,
300 degrees_of_freedom,
301 scale_matrices,
302 ))
303 }
304
305 fn compute_lower_bound(
306 &self,
307 data: &ArrayView2<F>,
308 responsibilities: &Array2<F>,
309 weight_concentration: &Array1<F>,
310 means: &Array2<F>,
311 scale_matrices: &Array3<F>,
312 ) -> StatsResult<F> {
313 let n_samples = data.shape()[0];
314 let mut lower_bound = F::zero();
315 let ten: F = F::from(10.0)
316 .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
317 let small_kl: F = F::from(0.01)
318 .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
319
320 for i in 0..n_samples {
321 for k in 0..self.max_components {
322 if responsibilities[[i, k]] > F::zero() {
323 let log_ll = self.compute_log_likelihood_component(
324 &data.row(i),
325 &means.row(k),
326 &scale_matrices.slice(s![k, .., ..]),
327 ten,
328 )?;
329 lower_bound = lower_bound + responsibilities[[i, k]] * log_ll;
330 }
331 }
332 }
333
334 for k in 0..self.max_components {
335 let w = weight_concentration[k];
336 if w > F::zero() {
337 lower_bound = lower_bound - w * w.ln() * small_kl;
338 }
339 }
340
341 Ok(lower_bound)
342 }
343
344 fn compute_effective_components(&self, wc: &Array1<F>) -> usize {
345 let total: F = wc.sum();
346 let threshold = F::from(0.01).unwrap_or(F::zero());
347 wc.iter().filter(|&&w| w / total > threshold).count()
348 }
349
350 fn compute_weights(&self, wc: &Array1<F>) -> Array1<F> {
351 let total: F = wc.sum();
352 wc.mapv(|w| w / total)
353 }
354
355 fn compute_log_likelihood_component(
356 &self,
357 point: &ArrayView1<F>,
358 mean: &ArrayView1<F>,
359 _scale_matrix: &scirs2_core::ndarray::ArrayBase<
360 scirs2_core::ndarray::ViewRepr<&F>,
361 scirs2_core::ndarray::Dim<[usize; 2]>,
362 >,
363 _degrees_of_freedom: F,
364 ) -> StatsResult<F> {
365 let half: F = F::from(0.5)
366 .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
367 let mut sum_sq = F::zero();
368 for (x, m) in point.iter().zip(mean.iter()) {
369 let diff = *x - *m;
370 sum_sq = sum_sq + diff * diff;
371 }
372 Ok(-half * sum_sq)
373 }
374
375 fn log_sum_exp(&self, logvalues: &Array1<F>) -> F {
376 let max_val = logvalues.iter().fold(F::neg_infinity(), |a, &b| a.max(b));
377 if max_val == F::neg_infinity() {
378 return F::neg_infinity();
379 }
380 let sum: F = logvalues.iter().map(|&x| (x - max_val).exp()).sum();
381 max_val + sum.ln()
382 }
383}