scirs2_integrate/gpu_fem/
dispatch.rs1use crate::error::IntegrateError;
7use crate::gpu_fem::stiffness::{
8 assemble_stiffness_cpu, Element2D, MeshElement2D, StiffnessMatrix,
9};
10use scirs2_core::ndarray::Array2;
11
/// Settings that steer the automatic CPU/GPU choice made by the stiffness dispatchers.
///
/// `PartialEq`/`Eq` are derived so that configurations can be compared directly
/// (for example in tests or when checking against [`Default::default`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FemAssemblyConfig {
    /// Minimum element count at which the GPU path is considered; the dispatcher
    /// compares the input length against this value with `>=`.
    pub gpu_threshold: usize,
    /// Whether the GPU path may be attempted at all. When `false`, assembly always
    /// goes to the CPU implementation regardless of `gpu_threshold`.
    pub use_gpu: bool,
}
24
25impl Default for FemAssemblyConfig {
26 fn default() -> Self {
27 FemAssemblyConfig {
28 gpu_threshold: 10_000,
29 use_gpu: true,
30 }
31 }
32}
33
/// Errors describing why a GPU assembly attempt could not produce a result.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare error values
/// directly instead of matching on formatted strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuFemError {
    /// Displayed as "GPU not available".
    GpuNotAvailable,
    /// GPU assembly failed; the payload is a free-form message that is appended
    /// to the displayed error text.
    AssemblyFailed(String),
}
46
47impl std::fmt::Display for GpuFemError {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 GpuFemError::GpuNotAvailable => write!(f, "GPU not available"),
51 GpuFemError::AssemblyFailed(msg) => write!(f, "GPU assembly failed: {msg}"),
52 }
53 }
54}
55
// Marker impl: with `Debug` derived and `Display` implemented, this lets `GpuFemError`
// be used wherever a `std::error::Error` is expected (e.g. `Box<dyn std::error::Error>`).
// No trait methods are overridden, so the provided defaults apply.
impl std::error::Error for GpuFemError {}
57
58pub fn assemble_stiffness_auto(
69 elements: &[Element2D],
70 d_matrix: &Array2<f64>,
71 n_nodes: usize,
72 config: &FemAssemblyConfig,
73) -> Result<StiffnessMatrix, IntegrateError> {
74 #[cfg(feature = "gpu_fem")]
75 if config.use_gpu && elements.len() >= config.gpu_threshold {
76 match crate::gpu_fem::wgpu_backend::assemble_stiffness_gpu(elements, d_matrix, n_nodes) {
77 Ok(result) => return Ok(result),
78 Err(_) => {
79 }
81 }
82 }
83
84 assemble_stiffness_cpu(elements, d_matrix, n_nodes)
86}
87
88pub fn assemble_stiffness_mesh_auto(
90 mesh_elements: &[MeshElement2D],
91 d_matrix: &Array2<f64>,
92 n_nodes: usize,
93 config: &FemAssemblyConfig,
94) -> Result<StiffnessMatrix, IntegrateError> {
95 #[cfg(feature = "gpu_fem")]
96 if config.use_gpu && mesh_elements.len() >= config.gpu_threshold {
97 let _ = (&mesh_elements, n_nodes); }
100
101 crate::gpu_fem::stiffness::assemble_stiffness_mesh(mesh_elements, d_matrix, n_nodes)
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// 3x3 isotropic constitutive matrix built from E = 1.0 and nu = 0.3
    /// (the entries follow the usual plane-stress layout).
    fn isotropic_d() -> Array2<f64> {
        let (young, poisson) = (1.0_f64, 0.3_f64);
        let scale = young / (1.0 - poisson * poisson);
        let coupling = scale * poisson;
        let shear = scale * (1.0 - poisson) / 2.0;
        array![
            [scale, coupling, 0.0],
            [coupling, scale, 0.0],
            [0.0, 0.0, shear],
        ]
    }

    /// One right triangle with vertices (0,0), (1,0), (0,1).
    fn single_triangle() -> Vec<Element2D> {
        let triangle = Element2D {
            nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            material_id: 0,
        };
        vec![triangle]
    }

    /// Runs the dispatcher with the shared constitutive matrix and checks that the
    /// resulting matrix has a non-empty `vals`.
    fn assert_assembles_nonempty(
        elements: &[Element2D],
        n_nodes: usize,
        config: &FemAssemblyConfig,
    ) {
        let d = isotropic_d();
        let km =
            assemble_stiffness_auto(elements, &d, n_nodes, config).expect("auto dispatch failed");
        assert!(!km.vals.is_empty());
    }

    #[test]
    fn test_auto_dispatch_routes_to_cpu_below_threshold() {
        // One element against a 10 000-element threshold.
        let config = FemAssemblyConfig {
            use_gpu: true,
            gpu_threshold: 10_000,
        };
        assert_assembles_nonempty(&single_triangle(), 3, &config);
    }

    #[test]
    fn test_auto_dispatch_use_gpu_false() {
        // Threshold of zero would always qualify, but the GPU switch is off.
        let config = FemAssemblyConfig {
            use_gpu: false,
            gpu_threshold: 0,
        };
        assert_assembles_nonempty(&single_triangle(), 3, &config);
    }

    #[test]
    fn test_auto_dispatch_large_mesh_cpu_fallback() {
        let config = FemAssemblyConfig {
            use_gpu: false,
            gpu_threshold: 5,
        };
        // Ten triangles, i.e. above the threshold of five.
        let n_elems = 10_usize;
        let mut elements: Vec<Element2D> = Vec::with_capacity(n_elems);
        for k in 0..n_elems {
            elements.push(Element2D {
                nodes: [
                    [k as f64, 0.0],
                    [k as f64 + 1.0, 0.0],
                    [k as f64 + 0.5, 1.0],
                ],
                material_id: 0,
            });
        }
        assert_assembles_nonempty(&elements, 3 * n_elems, &config);
    }

    #[test]
    fn test_cpu_two_element_assembly_distinct_dofs() {
        let config = FemAssemblyConfig {
            use_gpu: false,
            gpu_threshold: 10_000,
        };
        let lower = Element2D {
            nodes: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            material_id: 0,
        };
        let upper = Element2D {
            nodes: [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
            material_id: 0,
        };
        let elements = vec![lower, upper];

        let d = isotropic_d();
        let km = assemble_stiffness_auto(&elements, &d, 6, &config)
            .expect("two-element assembly failed");
        let dense = km.to_dense();
        let view = dense.view();

        // Sum of |K[i, j]| over a rectangular index block, visited row by row.
        let abs_block_sum = |rows: std::ops::Range<usize>, cols: std::ops::Range<usize>| -> f64 {
            rows.flat_map(|i| cols.clone().map(move |j| (i, j)))
                .map(|(i, j)| view[[i, j]].abs())
                .sum()
        };

        let block0_norm = abs_block_sum(0..6, 0..6);
        let block1_norm = abs_block_sum(6..12, 6..12);
        assert!(block0_norm > 0.0, "Element 0 DOF block is zero");
        assert!(
            block1_norm > 0.0,
            "Element 1 DOF block is zero — DOFs were not distinct"
        );

        let cross_norm = abs_block_sum(0..6, 6..12);
        assert!(
            cross_norm < 1e-12,
            "Cross-block is non-zero — shared DOFs detected unexpectedly"
        );
    }
}