Skip to main content

scirs2_integrate/gpu_fem/
dispatch.rs

//! Auto-dispatch for GPU/CPU FEM stiffness assembly.
//!
//! Selects the GPU path when the `gpu_fem` feature is enabled, GPU use is
//! requested in the configuration, and the mesh is large enough; falls back to
//! the CPU Rayon-parallel path otherwise.
5
6use crate::error::IntegrateError;
7use crate::gpu_fem::stiffness::{
8    assemble_stiffness_cpu, Element2D, MeshElement2D, StiffnessMatrix,
9};
10use scirs2_core::ndarray::Array2;
11
12// ─────────────────────────────────────────────────────────────────────────────
13// Configuration
14// ─────────────────────────────────────────────────────────────────────────────
15
/// Configuration for the FEM assembly dispatcher.
///
/// The dispatcher only considers the GPU backend when `use_gpu` is `true`
/// and the number of elements is at least `gpu_threshold`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FemAssemblyConfig {
    /// Minimum number of elements at which the GPU path is attempted
    /// (inclusive bound; default 10 000).
    pub gpu_threshold: usize,
    /// Whether to attempt GPU acceleration at all.
    pub use_gpu: bool,
}
24
25impl Default for FemAssemblyConfig {
26    fn default() -> Self {
27        FemAssemblyConfig {
28            gpu_threshold: 10_000,
29            use_gpu: true,
30        }
31    }
32}
33
34// ─────────────────────────────────────────────────────────────────────────────
35// Error type
36// ─────────────────────────────────────────────────────────────────────────────
37
/// Errors specific to the GPU FEM dispatch layer.
///
/// Variants compare by value, so tests and callers can use `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuFemError {
    /// GPU device not available on this machine.
    GpuNotAvailable,
    /// GPU assembly failed; the payload is a diagnostic message.
    AssemblyFailed(String),
}
46
47impl std::fmt::Display for GpuFemError {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            GpuFemError::GpuNotAvailable => write!(f, "GPU not available"),
51            GpuFemError::AssemblyFailed(msg) => write!(f, "GPU assembly failed: {msg}"),
52        }
53    }
54}
55
56impl std::error::Error for GpuFemError {}
57
58// ─────────────────────────────────────────────────────────────────────────────
59// Auto-dispatch (Element2D path)
60// ─────────────────────────────────────────────────────────────────────────────
61
62/// Assemble the stiffness matrix using the best available backend.
63///
64/// When `config.use_gpu` is `true`, the mesh is larger than
65/// `config.gpu_threshold`, and the `gpu_fem` feature is compiled in, the GPU
66/// path is tried first.  On failure (or if conditions are not met) the
67/// function falls back to the CPU Rayon-parallel path, which always succeeds.
68pub fn assemble_stiffness_auto(
69    elements: &[Element2D],
70    d_matrix: &Array2<f64>,
71    n_nodes: usize,
72    config: &FemAssemblyConfig,
73) -> Result<StiffnessMatrix, IntegrateError> {
74    #[cfg(feature = "gpu_fem")]
75    if config.use_gpu && elements.len() >= config.gpu_threshold {
76        match crate::gpu_fem::wgpu_backend::assemble_stiffness_gpu(elements, d_matrix, n_nodes) {
77            Ok(result) => return Ok(result),
78            Err(_) => {
79                // Fall through to CPU path
80            }
81        }
82    }
83
84    // CPU path is always available
85    assemble_stiffness_cpu(elements, d_matrix, n_nodes)
86}
87
88/// Assemble from a mesh with explicit node connectivity.
89pub fn assemble_stiffness_mesh_auto(
90    mesh_elements: &[MeshElement2D],
91    d_matrix: &Array2<f64>,
92    n_nodes: usize,
93    config: &FemAssemblyConfig,
94) -> Result<StiffnessMatrix, IntegrateError> {
95    #[cfg(feature = "gpu_fem")]
96    if config.use_gpu && mesh_elements.len() >= config.gpu_threshold {
97        // GPU path for mesh-structured assembly (stub — falls through)
98        let _ = (&mesh_elements, n_nodes); // suppress unused warning in stub
99    }
100
101    crate::gpu_fem::stiffness::assemble_stiffness_mesh(mesh_elements, d_matrix, n_nodes)
102}
103
104// ─────────────────────────────────────────────────────────────────────────────
105// Tests
106// ─────────────────────────────────────────────────────────────────────────────
107
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// 3×3 isotropic material matrix in plane-stress form with E = 1, ν = 0.3.
    fn isotropic_d() -> Array2<f64> {
        let e = 1.0_f64;
        let nu = 0.3_f64;
        let c = e / (1.0 - nu * nu);
        array![
            [c, c * nu, 0.0],
            [c * nu, c, 0.0],
            [0.0, 0.0, c * (1.0 - nu) / 2.0],
        ]
    }

    /// A one-element mesh: the unit right triangle.
    fn single_triangle() -> Vec<Element2D> {
        vec![Element2D {
            nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            material_id: 0,
        }]
    }

    #[test]
    fn test_default_config_values() {
        let config = FemAssemblyConfig::default();
        assert_eq!(config.gpu_threshold, 10_000);
        assert!(config.use_gpu);
    }

    #[test]
    fn test_gpu_fem_error_display() {
        assert_eq!(GpuFemError::GpuNotAvailable.to_string(), "GPU not available");
        assert_eq!(
            GpuFemError::AssemblyFailed("oom".to_string()).to_string(),
            "GPU assembly failed: oom"
        );
    }

    #[test]
    fn test_auto_dispatch_routes_to_cpu_below_threshold() {
        let d = isotropic_d();
        let elements = single_triangle();
        let config = FemAssemblyConfig {
            gpu_threshold: 10_000,
            use_gpu: true,
        };
        // With only 1 element the threshold is not reached, so the CPU path runs.
        let km = assemble_stiffness_auto(&elements, &d, 3, &config).expect("auto dispatch failed");
        assert!(!km.vals.is_empty());
    }

    #[test]
    fn test_auto_dispatch_use_gpu_false() {
        let d = isotropic_d();
        let elements = single_triangle();
        // Threshold 0 would admit any mesh; `use_gpu: false` must still win.
        let config = FemAssemblyConfig {
            gpu_threshold: 0,
            use_gpu: false,
        };
        let km = assemble_stiffness_auto(&elements, &d, 3, &config).expect("auto dispatch failed");
        assert!(!km.vals.is_empty());
    }

    #[test]
    fn test_auto_dispatch_large_mesh_cpu_fallback() {
        // 10 elements, each treated as 3 independent nodes → n_nodes = 30
        let d = isotropic_d();
        let config = FemAssemblyConfig {
            gpu_threshold: 5,
            use_gpu: false, // force CPU even if above threshold
        };
        let n_elems = 10_usize;
        let elements: Vec<Element2D> = (0..n_elems)
            .map(|k| {
                let x0 = k as f64;
                Element2D {
                    nodes: [[x0, 0.0], [x0 + 1.0, 0.0], [x0 + 0.5, 1.0]],
                    material_id: 0,
                }
            })
            .collect();
        // n_nodes must be 3 * n_elems so every element's DOFs fit within [0, 2*n_nodes)
        let km = assemble_stiffness_auto(&elements, &d, 3 * n_elems, &config)
            .expect("auto dispatch failed");
        assert!(!km.vals.is_empty());
    }

    #[test]
    fn test_cpu_two_element_assembly_distinct_dofs() {
        // Confirm that two Element2D assemblies write to *different* DOF ranges.
        // Element 0 → DOFs [0..5], element 1 → DOFs [6..11].
        let d = isotropic_d();
        let config = FemAssemblyConfig {
            gpu_threshold: 10_000,
            use_gpu: false,
        };
        let elements = vec![
            Element2D {
                nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
                material_id: 0,
            },
            Element2D {
                nodes: [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
                material_id: 0,
            },
        ];
        // 2 elements × 3 nodes = 6 independent nodes → n_nodes = 6
        let km = assemble_stiffness_auto(&elements, &d, 6, &config)
            .expect("two-element assembly failed");
        let dense = km.to_dense();
        let view = dense.view();

        // Sum of |K[i, j]| over the given row/column ranges, in row-major order.
        let abs_block_sum = |rows: std::ops::Range<usize>, cols: std::ops::Range<usize>| -> f64 {
            rows.flat_map(|i| cols.clone().map(move |j| (i, j)))
                .map(|(i, j)| view[[i, j]].abs())
                .sum()
        };

        // Top-left 6×6 block (element 0 DOFs) must be non-zero
        let block0_norm = abs_block_sum(0..6, 0..6);
        // Bottom-right 6×6 block (element 1 DOFs) must also be non-zero
        let block1_norm = abs_block_sum(6..12, 6..12);
        // Cross-block must be zero (no shared nodes between independent elements)
        let cross_norm = abs_block_sum(0..6, 6..12);

        assert!(block0_norm > 0.0, "Element 0 DOF block is zero");
        assert!(
            block1_norm > 0.0,
            "Element 1 DOF block is zero — DOFs were not distinct"
        );
        assert!(
            cross_norm < 1e-12,
            "Cross-block is non-zero — shared DOFs detected unexpectedly"
        );
    }
}