sklears_mixture/spatial/
spatial_constraints.rs1use crate::common::CovarianceType;
/// Kind of spatial relationship used to relate samples to one another.
#[derive(Debug, Clone, PartialEq)]
pub enum SpatialConstraint {
    /// Adjacency-based constraint; carries no parameters.
    Adjacency,

    /// Distance-based constraint with neighbourhood `radius`.
    ///
    /// This is the only variant for which `requires_coordinates` returns `true`.
    Distance { radius: f64 },

    /// Grid-based constraint over a `rows` x `cols` layout.
    ///
    /// This is the only variant for which `requires_grid_dimensions` returns `true`.
    Grid { rows: usize, cols: usize },

    /// Custom constraint; carries no parameters.
    Custom,
}

impl Default for SpatialConstraint {
    /// The default constraint is distance-based with a radius of `1.0`.
    fn default() -> Self {
        SpatialConstraint::Distance { radius: 1.0 }
    }
}

impl SpatialConstraint {
    /// Reports whether this constraint needs sample coordinates.
    ///
    /// Only [`SpatialConstraint::Distance`] does.
    pub fn requires_coordinates(&self) -> bool {
        matches!(*self, SpatialConstraint::Distance { .. })
    }

    /// Reports whether this constraint needs grid dimensions.
    ///
    /// Only [`SpatialConstraint::Grid`] does.
    pub fn requires_grid_dimensions(&self) -> bool {
        matches!(*self, SpatialConstraint::Grid { .. })
    }

    /// Returns a short lowercase identifier for the variant.
    pub fn constraint_name(&self) -> &'static str {
        match *self {
            SpatialConstraint::Distance { .. } => "distance",
            SpatialConstraint::Grid { .. } => "grid",
            SpatialConstraint::Adjacency => "adjacency",
            SpatialConstraint::Custom => "custom",
        }
    }
}

67#[derive(Debug, Clone)]
73pub struct SpatialMixtureConfig {
74 pub n_components: usize,
76
77 pub covariance_type: CovarianceType,
79
80 pub spatial_constraint: SpatialConstraint,
82
83 pub spatial_weight: f64,
88
89 pub max_iter: usize,
91
92 pub tol: f64,
94
95 pub random_state: Option<u64>,
97}
98
99impl Default for SpatialMixtureConfig {
100 fn default() -> Self {
101 Self {
102 n_components: 2,
103 covariance_type: CovarianceType::Full,
104 spatial_constraint: SpatialConstraint::default(),
105 spatial_weight: 0.1,
106 max_iter: 100,
107 tol: 1e-4,
108 random_state: None,
109 }
110 }
111}
112
113impl SpatialMixtureConfig {
114 pub fn new(n_components: usize) -> Self {
116 Self {
117 n_components,
118 ..Default::default()
119 }
120 }
121
122 pub fn validate(&self) -> Result<(), String> {
124 if self.n_components == 0 {
125 return Err("Number of components must be greater than 0".to_string());
126 }
127
128 if !(0.0..=1.0).contains(&self.spatial_weight) {
129 return Err("Spatial weight must be between 0.0 and 1.0".to_string());
130 }
131
132 if self.max_iter == 0 {
133 return Err("Maximum iterations must be greater than 0".to_string());
134 }
135
136 if self.tol <= 0.0 {
137 return Err("Tolerance must be positive".to_string());
138 }
139
140 match &self.spatial_constraint {
142 SpatialConstraint::Distance { radius } => {
143 if *radius <= 0.0 {
144 return Err("Distance radius must be positive".to_string());
145 }
146 }
147 SpatialConstraint::Grid { rows, cols } => {
148 if *rows == 0 || *cols == 0 {
149 return Err("Grid dimensions must be greater than 0".to_string());
150 }
151 }
152 _ => {}
153 }
154
155 Ok(())
156 }
157
158 pub fn parameter_count(&self, n_features: usize) -> usize {
160 let weight_params = self.n_components - 1; let mean_params = self.n_components * n_features;
162
163 let covariance_params = match self.covariance_type {
164 CovarianceType::Full => self.n_components * n_features * (n_features + 1) / 2,
165 CovarianceType::Diagonal => self.n_components * n_features,
166 CovarianceType::Spherical => self.n_components,
167 CovarianceType::Tied => n_features * (n_features + 1) / 2,
168 };
169
170 weight_params + mean_params + covariance_params
171 }
172
173 pub fn check_data_requirements(
175 &self,
176 n_samples: usize,
177 n_features: usize,
178 ) -> Result<(), String> {
179 if n_samples < self.n_components {
180 return Err(format!(
181 "Number of samples ({}) must be at least the number of components ({})",
182 n_samples, self.n_components
183 ));
184 }
185
186 let min_samples = self.parameter_count(n_features) * 2;
187 if n_samples < min_samples {
188 return Err(format!(
189 "Insufficient data: need at least {} samples for {} parameters",
190 min_samples,
191 self.parameter_count(n_features)
192 ));
193 }
194
195 Ok(())
196 }
197}
198
/// Penalty used for spatial smoothing, together with its strength `lambda`.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum SpatialRegularization {
    /// No penalty; this is the default and its strength is always `0.0`.
    #[default]
    None,

    /// L1 penalty with strength `lambda`.
    L1 { lambda: f64 },

    /// L2 penalty with strength `lambda`.
    L2 { lambda: f64 },

    /// Total-variation penalty with strength `lambda`.
    TotalVariation { lambda: f64 },

    /// Elastic-net penalty with strength `lambda`.
    ///
    /// `l1_ratio` is presumably the share of the L1 term (not validated or read here).
    ElasticNet { l1_ratio: f64, lambda: f64 },
}

impl SpatialRegularization {
    /// Returns the penalty strength, or `0.0` for [`SpatialRegularization::None`].
    pub fn lambda(&self) -> f64 {
        match self {
            Self::None => 0.0,
            Self::L1 { lambda }
            | Self::L2 { lambda }
            | Self::TotalVariation { lambda }
            | Self::ElasticNet { lambda, .. } => *lambda,
        }
    }

    /// Returns `true` when a penalty is selected and its strength is strictly positive.
    ///
    /// Zero, negative and NaN strengths all count as inactive.
    pub fn is_active(&self) -> bool {
        // `None` always reports a strength of 0.0, so this check alone excludes it.
        self.lambda() > 0.0
    }
}

242#[derive(Debug, Clone)]
244pub struct SpatialSmoothingConfig {
245 pub regularization: SpatialRegularization,
247
248 pub bandwidth: f64,
250
251 pub adaptive_bandwidth: bool,
253
254 pub min_bandwidth: f64,
256
257 pub max_bandwidth: f64,
259}
260
261impl Default for SpatialSmoothingConfig {
262 fn default() -> Self {
263 Self {
264 regularization: SpatialRegularization::default(),
265 bandwidth: 1.0,
266 adaptive_bandwidth: false,
267 min_bandwidth: 0.1,
268 max_bandwidth: 10.0,
269 }
270 }
271}
272
// Unit tests for the spatial constraint and configuration types.
// The unjustified `#[allow(non_snake_case)]` was removed: nothing in here needs it.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spatial_constraint_default() {
        let constraint = SpatialConstraint::default();
        assert!(matches!(constraint, SpatialConstraint::Distance { radius } if radius == 1.0));
    }

    #[test]
    fn test_spatial_constraint_properties() {
        let distance_constraint = SpatialConstraint::Distance { radius: 2.0 };
        assert!(distance_constraint.requires_coordinates());
        assert!(!distance_constraint.requires_grid_dimensions());
        assert_eq!(distance_constraint.constraint_name(), "distance");

        let grid_constraint = SpatialConstraint::Grid { rows: 10, cols: 10 };
        assert!(!grid_constraint.requires_coordinates());
        assert!(grid_constraint.requires_grid_dimensions());
        assert_eq!(grid_constraint.constraint_name(), "grid");

        let adjacency = SpatialConstraint::Adjacency;
        assert!(!adjacency.requires_coordinates());
        assert!(!adjacency.requires_grid_dimensions());
        assert_eq!(adjacency.constraint_name(), "adjacency");

        let custom = SpatialConstraint::Custom;
        assert!(!custom.requires_coordinates());
        assert!(!custom.requires_grid_dimensions());
        assert_eq!(custom.constraint_name(), "custom");
    }

    #[test]
    fn test_spatial_mixture_config_validation() {
        let mut config = SpatialMixtureConfig::new(3);
        assert!(config.validate().is_ok());

        config.n_components = 0;
        assert!(config.validate().is_err());

        config.n_components = 3;
        config.spatial_weight = 1.5;
        assert!(config.validate().is_err());

        config.spatial_weight = 0.5;
        config.max_iter = 0;
        assert!(config.validate().is_err());

        config.max_iter = 100;
        config.tol = 0.0;
        assert!(config.validate().is_err());

        config.tol = 1e-4;
        config.spatial_constraint = SpatialConstraint::Distance { radius: -1.0 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Grid { rows: 0, cols: 4 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Grid { rows: 4, cols: 4 };
        assert!(config.validate().is_ok());

        config.spatial_constraint = SpatialConstraint::Adjacency;
        assert!(config.validate().is_ok());

        config.spatial_constraint = SpatialConstraint::Custom;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_spatial_mixture_config_parameter_count() {
        let config = SpatialMixtureConfig {
            n_components: 3,
            covariance_type: CovarianceType::Full,
            ..Default::default()
        };

        // 2 free weights + 6 means + 3 * 3 full-covariance entries.
        let n_features = 2;
        let param_count = config.parameter_count(n_features);

        assert_eq!(param_count, 17);
    }

    #[test]
    fn test_check_data_requirements() {
        let config = SpatialMixtureConfig::new(3);
        assert!(config.check_data_requirements(2, 2).is_err());
        // 17 parameters for 2 features with full covariance -> at least 34 samples.
        assert!(config.check_data_requirements(33, 2).is_err());
        assert!(config.check_data_requirements(34, 2).is_ok());
    }

    #[test]
    fn test_spatial_regularization() {
        let l1_reg = SpatialRegularization::L1 { lambda: 0.1 };
        assert!(l1_reg.is_active());
        assert_eq!(l1_reg.lambda(), 0.1);

        let no_reg = SpatialRegularization::None;
        assert!(!no_reg.is_active());
        assert_eq!(no_reg.lambda(), 0.0);

        assert_eq!(SpatialRegularization::default(), SpatialRegularization::None);

        let elastic = SpatialRegularization::ElasticNet {
            l1_ratio: 0.5,
            lambda: 0.3,
        };
        assert!(elastic.is_active());
        assert_eq!(elastic.lambda(), 0.3);
    }

    #[test]
    fn test_spatial_smoothing_config_default() {
        let config = SpatialSmoothingConfig::default();
        assert_eq!(config.regularization, SpatialRegularization::None);
        assert_eq!(config.bandwidth, 1.0);
        assert!(!config.adaptive_bandwidth);
        assert_eq!(config.min_bandwidth, 0.1);
        assert_eq!(config.max_bandwidth, 10.0);
    }
}