sklears_mixture/spatial/
spatial_constraints.rs

1//! Spatial Constraint Types and Configuration
2//!
3//! This module defines the various types of spatial constraints that can be applied
4//! to mixture models, along with configuration structures for spatial modeling parameters.
5
6use crate::common::CovarianceType;
7
/// The kinds of spatial relationship a spatial mixture model can enforce.
///
/// Each variant selects a different rule for deciding which spatial units
/// influence one another, so that cluster assignments stay spatially coherent.
#[derive(Debug, Clone, PartialEq)]
pub enum SpatialConstraint {
    /// Neighbor-based constraint built on topological adjacency.
    ///
    /// Spatial units that are neighbors influence each other's cluster
    /// assignments.
    Adjacency,

    /// Radius-based constraint.
    ///
    /// Spatial units lying within `radius` of each other influence one another.
    Distance {
        /// Range of spatial influence.
        radius: f64,
    },

    /// Regular-grid constraint.
    ///
    /// The data is assumed to be laid out on a regular grid; neighboring cells
    /// influence each other according to the grid topology.
    Grid {
        /// Number of grid rows.
        rows: usize,
        /// Number of grid columns.
        cols: usize,
    },

    /// User-defined constraint.
    ///
    /// Leaves the spatial relationships to the caller, e.g. through a custom
    /// spatial weights matrix.
    Custom,
}
38
39impl Default for SpatialConstraint {
40    fn default() -> Self {
41        Self::Distance { radius: 1.0 }
42    }
43}
44
45impl SpatialConstraint {
46    /// Check if the constraint requires coordinate information
47    pub fn requires_coordinates(&self) -> bool {
48        matches!(self, Self::Distance { .. })
49    }
50
51    /// Check if the constraint requires grid dimensions
52    pub fn requires_grid_dimensions(&self) -> bool {
53        matches!(self, Self::Grid { .. })
54    }
55
56    /// Get a descriptive name for the constraint type
57    pub fn constraint_name(&self) -> &'static str {
58        match self {
59            Self::Adjacency => "adjacency",
60            Self::Distance { .. } => "distance",
61            Self::Grid { .. } => "grid",
62            Self::Custom => "custom",
63        }
64    }
65}
66
67/// Configuration for spatially constrained mixture models
68///
69/// This structure contains all parameters needed to configure spatial mixture models,
70/// including the number of components, covariance structure, spatial constraints,
71/// and optimization parameters.
72#[derive(Debug, Clone)]
73pub struct SpatialMixtureConfig {
74    /// Number of mixture components
75    pub n_components: usize,
76
77    /// Type of covariance matrix structure
78    pub covariance_type: CovarianceType,
79
80    /// Spatial constraint configuration
81    pub spatial_constraint: SpatialConstraint,
82
83    /// Weight for spatial constraint term (0.0 to 1.0)
84    ///
85    /// Controls the balance between data likelihood and spatial constraint.
86    /// Higher values enforce stronger spatial coherence.
87    pub spatial_weight: f64,
88
89    /// Maximum number of EM iterations
90    pub max_iter: usize,
91
92    /// Convergence tolerance for EM algorithm
93    pub tol: f64,
94
95    /// Random seed for reproducibility
96    pub random_state: Option<u64>,
97}
98
99impl Default for SpatialMixtureConfig {
100    fn default() -> Self {
101        Self {
102            n_components: 2,
103            covariance_type: CovarianceType::Full,
104            spatial_constraint: SpatialConstraint::default(),
105            spatial_weight: 0.1,
106            max_iter: 100,
107            tol: 1e-4,
108            random_state: None,
109        }
110    }
111}
112
113impl SpatialMixtureConfig {
114    /// Create a new configuration with specified number of components
115    pub fn new(n_components: usize) -> Self {
116        Self {
117            n_components,
118            ..Default::default()
119        }
120    }
121
122    /// Validate the configuration parameters
123    pub fn validate(&self) -> Result<(), String> {
124        if self.n_components == 0 {
125            return Err("Number of components must be greater than 0".to_string());
126        }
127
128        if !(0.0..=1.0).contains(&self.spatial_weight) {
129            return Err("Spatial weight must be between 0.0 and 1.0".to_string());
130        }
131
132        if self.max_iter == 0 {
133            return Err("Maximum iterations must be greater than 0".to_string());
134        }
135
136        if self.tol <= 0.0 {
137            return Err("Tolerance must be positive".to_string());
138        }
139
140        // Validate spatial constraint specific parameters
141        match &self.spatial_constraint {
142            SpatialConstraint::Distance { radius } => {
143                if *radius <= 0.0 {
144                    return Err("Distance radius must be positive".to_string());
145                }
146            }
147            SpatialConstraint::Grid { rows, cols } => {
148                if *rows == 0 || *cols == 0 {
149                    return Err("Grid dimensions must be greater than 0".to_string());
150                }
151            }
152            _ => {}
153        }
154
155        Ok(())
156    }
157
158    /// Get the effective number of parameters for this configuration
159    pub fn parameter_count(&self, n_features: usize) -> usize {
160        let weight_params = self.n_components - 1; // n-1 independent weights
161        let mean_params = self.n_components * n_features;
162
163        let covariance_params = match self.covariance_type {
164            CovarianceType::Full => self.n_components * n_features * (n_features + 1) / 2,
165            CovarianceType::Diagonal => self.n_components * n_features,
166            CovarianceType::Spherical => self.n_components,
167            CovarianceType::Tied => n_features * (n_features + 1) / 2,
168        };
169
170        weight_params + mean_params + covariance_params
171    }
172
173    /// Check if the configuration is suitable for the given data size
174    pub fn check_data_requirements(
175        &self,
176        n_samples: usize,
177        n_features: usize,
178    ) -> Result<(), String> {
179        if n_samples < self.n_components {
180            return Err(format!(
181                "Number of samples ({}) must be at least the number of components ({})",
182                n_samples, self.n_components
183            ));
184        }
185
186        let min_samples = self.parameter_count(n_features) * 2;
187        if n_samples < min_samples {
188            return Err(format!(
189                "Insufficient data: need at least {} samples for {} parameters",
190                min_samples,
191                self.parameter_count(n_features)
192            ));
193        }
194
195        Ok(())
196    }
197}
198
/// The spatial penalty applied to a mixture model, if any.
#[derive(Debug, Clone, PartialEq)]
pub enum SpatialRegularization {
    /// Regularization disabled.
    None,

    /// L1 spatial penalty, which encourages sparse spatial effects.
    L1 {
        /// Regularization strength.
        lambda: f64,
    },

    /// L2 spatial penalty, which yields smooth spatial effects.
    L2 {
        /// Regularization strength.
        lambda: f64,
    },

    /// Total variation penalty, which yields piecewise constant spatial effects.
    TotalVariation {
        /// Regularization strength.
        lambda: f64,
    },

    /// Elastic net: a combination of the L1 and L2 penalties.
    ElasticNet {
        /// Presumably the share of the L1 term in the L1/L2 mix; nothing in
        /// this file reads it, so confirm against the code that consumes it.
        l1_ratio: f64,
        /// Regularization strength.
        lambda: f64,
    },
}
217
218impl Default for SpatialRegularization {
219    fn default() -> Self {
220        Self::None
221    }
222}
223
224impl SpatialRegularization {
225    /// Get the regularization strength parameter
226    pub fn lambda(&self) -> f64 {
227        match self {
228            Self::None => 0.0,
229            Self::L1 { lambda } => *lambda,
230            Self::L2 { lambda } => *lambda,
231            Self::TotalVariation { lambda } => *lambda,
232            Self::ElasticNet { lambda, .. } => *lambda,
233        }
234    }
235
236    /// Check if regularization is active
237    pub fn is_active(&self) -> bool {
238        !matches!(self, Self::None) && self.lambda() > 0.0
239    }
240}
241
242/// Spatial smoothing parameters for mixture models
243#[derive(Debug, Clone)]
244pub struct SpatialSmoothingConfig {
245    /// Type of spatial regularization
246    pub regularization: SpatialRegularization,
247
248    /// Kernel bandwidth for spatial smoothing
249    pub bandwidth: f64,
250
251    /// Whether to use adaptive bandwidth based on local density
252    pub adaptive_bandwidth: bool,
253
254    /// Minimum bandwidth value (for adaptive bandwidth)
255    pub min_bandwidth: f64,
256
257    /// Maximum bandwidth value (for adaptive bandwidth)
258    pub max_bandwidth: f64,
259}
260
261impl Default for SpatialSmoothingConfig {
262    fn default() -> Self {
263        Self {
264            regularization: SpatialRegularization::default(),
265            bandwidth: 1.0,
266            adaptive_bandwidth: false,
267            min_bandwidth: 0.1,
268            max_bandwidth: 10.0,
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    //! Unit tests for the spatial constraint and configuration types.

    use super::*;

    #[test]
    fn test_spatial_constraint_default() {
        let constraint = SpatialConstraint::default();
        assert!(matches!(constraint, SpatialConstraint::Distance { radius } if radius == 1.0));
    }

    #[test]
    fn test_spatial_constraint_properties() {
        let distance_constraint = SpatialConstraint::Distance { radius: 2.0 };
        assert!(distance_constraint.requires_coordinates());
        assert!(!distance_constraint.requires_grid_dimensions());
        assert_eq!(distance_constraint.constraint_name(), "distance");

        let grid_constraint = SpatialConstraint::Grid { rows: 10, cols: 10 };
        assert!(!grid_constraint.requires_coordinates());
        assert!(grid_constraint.requires_grid_dimensions());
        assert_eq!(grid_constraint.constraint_name(), "grid");

        // Variants without payload need neither coordinates nor grid dimensions.
        for (constraint, name) in [
            (SpatialConstraint::Adjacency, "adjacency"),
            (SpatialConstraint::Custom, "custom"),
        ] {
            assert!(!constraint.requires_coordinates());
            assert!(!constraint.requires_grid_dimensions());
            assert_eq!(constraint.constraint_name(), name);
        }
    }

    #[test]
    fn test_spatial_mixture_config_validation() {
        let mut config = SpatialMixtureConfig::new(3);
        assert!(config.validate().is_ok());

        config.n_components = 0;
        assert!(config.validate().is_err());

        config.n_components = 3;
        config.spatial_weight = 1.5;
        assert!(config.validate().is_err());

        config.spatial_weight = 0.5;
        config.max_iter = 0;
        assert!(config.validate().is_err());

        config.max_iter = 100;
        config.tol = 0.0;
        assert!(config.validate().is_err());

        // Restoring every field yields a valid configuration again.
        config.tol = 1e-4;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_spatial_mixture_config_constraint_validation() {
        let mut config = SpatialMixtureConfig::new(3);

        config.spatial_constraint = SpatialConstraint::Distance { radius: 0.0 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Distance { radius: -1.0 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Grid { rows: 0, cols: 5 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Grid { rows: 4, cols: 0 };
        assert!(config.validate().is_err());

        config.spatial_constraint = SpatialConstraint::Grid { rows: 4, cols: 5 };
        assert!(config.validate().is_ok());

        config.spatial_constraint = SpatialConstraint::Adjacency;
        assert!(config.validate().is_ok());

        config.spatial_constraint = SpatialConstraint::Custom;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_spatial_mixture_config_parameter_count() {
        let config = SpatialMixtureConfig {
            n_components: 3,
            covariance_type: CovarianceType::Full,
            ..Default::default()
        };

        let n_features = 2;
        let param_count = config.parameter_count(n_features);

        // 3 components: 2 weights + 6 means + 9 covariances = 17 parameters
        assert_eq!(param_count, 17);
    }

    #[test]
    fn test_spatial_mixture_config_data_requirements() {
        let config = SpatialMixtureConfig {
            n_components: 3,
            covariance_type: CovarianceType::Full,
            ..Default::default()
        };

        // Fewer samples than components.
        assert!(config.check_data_requirements(2, 2).is_err());

        // 17 parameters require at least 34 samples.
        assert!(config.check_data_requirements(33, 2).is_err());
        assert!(config.check_data_requirements(34, 2).is_ok());
    }

    #[test]
    fn test_spatial_regularization() {
        let l1_reg = SpatialRegularization::L1 { lambda: 0.1 };
        assert!(l1_reg.is_active());
        assert_eq!(l1_reg.lambda(), 0.1);

        let no_reg = SpatialRegularization::None;
        assert!(!no_reg.is_active());
        assert_eq!(no_reg.lambda(), 0.0);

        // A zero-strength penalty is present but inactive.
        let zero_reg = SpatialRegularization::L2 { lambda: 0.0 };
        assert!(!zero_reg.is_active());

        let elastic = SpatialRegularization::ElasticNet {
            l1_ratio: 0.5,
            lambda: 0.3,
        };
        assert!(elastic.is_active());
        assert_eq!(elastic.lambda(), 0.3);
    }

    #[test]
    fn test_spatial_smoothing_config_default() {
        let config = SpatialSmoothingConfig::default();
        assert_eq!(config.regularization, SpatialRegularization::None);
        assert_eq!(config.bandwidth, 1.0);
        assert!(!config.adaptive_bandwidth);
        assert_eq!(config.min_bandwidth, 0.1);
        assert_eq!(config.max_bandwidth, 10.0);
    }
}