Skip to main content

scirs2_integrate/pinn/
types.rs

1//! Types and configuration for Physics-Informed Neural Networks.
2
/// Boundary condition types for PDE problems.
///
/// Each variant describes the constraint imposed on the solution `u` at a
/// boundary; `du/dn` denotes the derivative along the boundary normal.
///
/// `PartialEq` is derived so conditions can be compared directly (e.g. in
/// tests or when deduplicating boundary specs). It follows `f64` semantics,
/// so a condition carrying `NaN` never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum BoundaryCondition {
    /// Dirichlet boundary condition: `u = value`.
    Dirichlet {
        /// The prescribed value at the boundary.
        value: f64,
    },
    /// Neumann boundary condition: `du/dn = flux`.
    Neumann {
        /// The prescribed normal flux at the boundary.
        flux: f64,
    },
    /// Robin boundary condition: `alpha * u + beta * du/dn = value`.
    Robin {
        /// Coefficient of `u`.
        alpha: f64,
        /// Coefficient of `du/dn`.
        beta: f64,
        /// Right-hand side value.
        value: f64,
    },
    /// Periodic boundary condition (carries no parameters).
    Periodic,
}
29
/// Strategy for placing collocation points in the domain interior.
///
/// Defaults to [`CollocationStrategy::Random`]. Equality is derived so a
/// configured strategy can be checked with `==` instead of `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum CollocationStrategy {
    /// Uniformly random placement.
    #[default]
    Random,
    /// Latin Hypercube Sampling for better space-filling.
    LatinHypercube,
    /// Uniform grid placement.
    UniformGrid,
    /// Adaptive placement with more points where the PDE residual is high.
    AdaptiveResidual,
}
44
45/// Configuration for PINN training.
46#[derive(Debug, Clone)]
47pub struct PINNConfig {
48    /// Number of neurons per hidden layer (default: \[64, 64, 64\])
49    pub hidden_layers: Vec<usize>,
50    /// Adam optimizer learning rate (default: 1e-3)
51    pub learning_rate: f64,
52    /// Maximum training epochs (default: 10000)
53    pub max_epochs: usize,
54    /// Number of interior collocation points (default: 1000)
55    pub n_collocation: usize,
56    /// Number of boundary points per boundary segment (default: 100)
57    pub n_boundary: usize,
58    /// Weight for PDE residual loss (default: 1.0)
59    pub physics_weight: f64,
60    /// Weight for boundary condition loss (default: 10.0)
61    pub boundary_weight: f64,
62    /// Weight for observational data loss (default: 1.0)
63    pub data_weight: f64,
64    /// Collocation point placement strategy
65    pub collocation: CollocationStrategy,
66    /// Stop training when total loss falls below this threshold (default: 1e-6)
67    pub convergence_tol: f64,
68}
69
70impl Default for PINNConfig {
71    fn default() -> Self {
72        Self {
73            hidden_layers: vec![64, 64, 64],
74            learning_rate: 1e-3,
75            max_epochs: 10000,
76            n_collocation: 1000,
77            n_boundary: 100,
78            physics_weight: 1.0,
79            boundary_weight: 10.0,
80            data_weight: 1.0,
81            collocation: CollocationStrategy::default(),
82            convergence_tol: 1e-6,
83        }
84    }
85}
86
/// Summary of a finished PINN training run.
///
/// Besides the overall loss, the physics, boundary, and data loss components
/// are reported separately, together with the full per-epoch loss trace.
#[derive(Clone, Debug)]
pub struct PINNResult {
    /// Total training loss when training stopped.
    pub final_loss: f64,
    /// PDE-residual (physics) component of the loss.
    pub physics_loss: f64,
    /// Boundary-condition component of the loss.
    pub boundary_loss: f64,
    /// Observational-data component of the loss.
    pub data_loss: f64,
    /// How many epochs were actually run.
    pub epochs_trained: usize,
    /// `true` when the loss fell below the configured `convergence_tol`.
    pub converged: bool,
    /// Loss recorded at every epoch, in training order.
    pub loss_history: Vec<f64>,
}
105
106/// Defines a PDE problem for PINN solving.
107#[derive(Debug, Clone)]
108pub struct PDEProblem {
109    /// Number of spatial dimensions (1, 2, or 3)
110    pub spatial_dim: usize,
111    /// Domain bounds per spatial dimension: \[(x_min, x_max), (y_min, y_max), ...\]
112    pub domain: Vec<(f64, f64)>,
113    /// Boundary specifications
114    pub boundaries: Vec<Boundary>,
115    /// Whether the problem includes a time variable
116    pub has_time: bool,
117    /// Time domain bounds (t_min, t_max), required when `has_time` is true
118    pub time_domain: Option<(f64, f64)>,
119}
120
121/// A single boundary specification.
122#[derive(Debug, Clone)]
123pub struct Boundary {
124    /// Which spatial dimension this boundary constrains
125    pub dim: usize,
126    /// Whether this is the low or high end of the dimension
127    pub side: BoundarySide,
128    /// The boundary condition to apply
129    pub condition: BoundaryCondition,
130}
131
/// End of a spatial dimension on which a boundary is located.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum BoundarySide {
    /// The minimum end of the dimension's interval.
    Low,
    /// The maximum end of the dimension's interval.
    High,
}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing `f64` values that should equal a literal.
    const EPS: f64 = 1e-15;

    #[test]
    fn test_default_config() {
        let config = PINNConfig::default();
        assert_eq!(config.hidden_layers, vec![64, 64, 64]);
        assert!((config.learning_rate - 1e-3).abs() < EPS);
        assert_eq!(config.max_epochs, 10000);
        assert_eq!(config.n_collocation, 1000);
        assert_eq!(config.n_boundary, 100);
        assert!((config.physics_weight - 1.0).abs() < EPS);
        assert!((config.boundary_weight - 10.0).abs() < EPS);
        assert!((config.data_weight - 1.0).abs() < EPS);
        // Every default field is pinned, including the collocation strategy.
        assert!(matches!(config.collocation, CollocationStrategy::Random));
        assert!((config.convergence_tol - 1e-6).abs() < EPS);
    }

    #[test]
    fn test_boundary_condition_variants() {
        let dirichlet = BoundaryCondition::Dirichlet { value: 1.0 };
        let neumann = BoundaryCondition::Neumann { flux: 0.5 };
        let robin = BoundaryCondition::Robin {
            alpha: 1.0,
            beta: 2.0,
            value: 3.0,
        };
        let periodic = BoundaryCondition::Periodic;

        // Payloads must survive construction and `Clone`.
        assert!(matches!(
            dirichlet.clone(),
            BoundaryCondition::Dirichlet { value } if (value - 1.0).abs() < EPS
        ));
        assert!(matches!(
            neumann.clone(),
            BoundaryCondition::Neumann { flux } if (flux - 0.5).abs() < EPS
        ));
        assert!(matches!(
            robin.clone(),
            BoundaryCondition::Robin { alpha, beta, value }
                if (alpha - 1.0).abs() < EPS
                    && (beta - 2.0).abs() < EPS
                    && (value - 3.0).abs() < EPS
        ));
        assert!(matches!(periodic.clone(), BoundaryCondition::Periodic));

        // `Debug` output must identify the variant, not merely avoid panicking.
        assert!(format!("{:?}", dirichlet).starts_with("Dirichlet"));
        assert!(format!("{:?}", neumann).starts_with("Neumann"));
        assert!(format!("{:?}", robin).starts_with("Robin"));
        assert_eq!(format!("{:?}", periodic), "Periodic");
    }

    #[test]
    fn test_pde_problem_construction() {
        let problem = PDEProblem {
            spatial_dim: 2,
            domain: vec![(0.0, 1.0), (0.0, 1.0)],
            boundaries: vec![
                Boundary {
                    dim: 0,
                    side: BoundarySide::Low,
                    condition: BoundaryCondition::Dirichlet { value: 0.0 },
                },
                Boundary {
                    dim: 0,
                    side: BoundarySide::High,
                    condition: BoundaryCondition::Dirichlet { value: 1.0 },
                },
            ],
            has_time: false,
            time_domain: None,
        };

        assert_eq!(problem.spatial_dim, 2);
        assert_eq!(problem.domain.len(), 2);
        assert_eq!(problem.domain.len(), problem.spatial_dim);
        assert_eq!(problem.boundaries.len(), 2);
        assert!(problem.boundaries.iter().all(|b| b.dim == 0));
        assert_eq!(problem.boundaries[0].side, BoundarySide::Low);
        assert_eq!(problem.boundaries[1].side, BoundarySide::High);
        assert!(!problem.has_time);
        assert!(problem.time_domain.is_none());
    }

    #[test]
    fn test_boundary_side_equality() {
        assert_eq!(BoundarySide::Low, BoundarySide::Low);
        assert_eq!(BoundarySide::High, BoundarySide::High);
        assert_ne!(BoundarySide::Low, BoundarySide::High);
    }

    #[test]
    fn test_pinn_result_fields() {
        let result = PINNResult {
            final_loss: 0.001,
            physics_loss: 0.0005,
            boundary_loss: 0.0003,
            data_loss: 0.0002,
            epochs_trained: 5000,
            converged: true,
            loss_history: vec![1.0, 0.5, 0.1, 0.01, 0.001],
        };

        assert!(result.converged);
        assert_eq!(result.epochs_trained, 5000);
        assert_eq!(result.loss_history.len(), 5);
        assert!(result.final_loss < 0.01);
    }

    #[test]
    fn test_collocation_strategy_default() {
        let strategy = CollocationStrategy::default();
        assert!(matches!(strategy, CollocationStrategy::Random));
    }

    #[test]
    fn test_time_dependent_problem() {
        let problem = PDEProblem {
            spatial_dim: 1,
            domain: vec![(0.0, 1.0)],
            boundaries: vec![],
            has_time: true,
            time_domain: Some((0.0, 1.0)),
        };

        assert!(problem.has_time);
        // Destructure directly rather than `is_some()` + a fallback value,
        // so a missing time domain fails loudly instead of comparing (0, 0).
        match problem.time_domain {
            Some((t_min, t_max)) => {
                assert!(t_min.abs() < EPS);
                assert!((t_max - 1.0).abs() < EPS);
            }
            None => panic!("time-dependent problem must carry a time domain"),
        }
    }
}