Skip to main content

rustsim_core/
soa.rs

1//! SoA (Structure-of-Arrays) extraction for GPU-friendly data layout.
2//!
3//! ABM agents are stored as AoS (Array-of-Structures) in `AgentStore`.
4//! For GPU kernels, we need flat contiguous arrays of each field.
5//! `SoaExtractable` lets agent types define how to extract/write-back
6//! their data to/from flat `f32` buffers suitable for GPU upload.
7//!
8//! For workloads that require `f64` precision (scientific simulations,
9//! long time horizons, or ill-conditioned dynamics), implement
10//! [`SoaExtractableF64`] instead of or alongside [`SoaExtractable`].
11//! The two traits are independent: an agent type may implement either,
12//! both, or neither. CUDA kernels that target `f64` must declare `double`
13//! parameters.
14
15use crate::agent::Agent;
16use crate::store::AgentStore;
17use crate::types::AgentId;
18
19/// Trait for agents whose numeric fields can be extracted into SoA buffers
20/// and written back from SoA buffers after GPU computation.
21///
22/// Each "column" is a `Vec<f32>` representing one field across all agents.
23/// The order of agents in the buffers matches the order of IDs returned
24/// by `extract`.
25///
26/// # Example
27///
28/// ```ignore
29/// impl SoaExtractable for Particle {
30///     fn num_columns() -> usize { 2 } // x, vx
31///     fn column_names() -> Vec<&'static str> { vec!["x", "vx"] }
32///     fn extract_row(&self, columns: &mut [Vec<f32>]) {
33///         columns[0].push(self.x);
34///         columns[1].push(self.vx);
35///     }
36///     fn write_back_row(&mut self, columns: &[&[f32]], row: usize) {
37///         self.x = columns[0][row];
38///         self.vx = columns[1][row];
39///     }
40/// }
41/// ```
42pub trait SoaExtractable: Agent {
43    /// Number of f32 columns to extract.
44    fn num_columns() -> usize;
45
46    /// Human-readable names for each column (for debugging / PTX variable naming).
47    fn column_names() -> Vec<&'static str>;
48
49    /// Push this agent's values into the column vectors.
50    fn extract_row(&self, columns: &mut [Vec<f32>]);
51
52    /// Read this agent's values back from the column slices at `row`.
53    fn write_back_row(&mut self, columns: &[&[f32]], row: usize);
54}
55
56/// Extract SoA buffers from an `AgentStore`.
57///
58/// Returns `(ids, columns)` where:
59/// - `ids[i]` is the `AgentId` for row `i`
60/// - `columns[c][i]` is column `c` for agent `i`
61pub fn extract_soa<A, S>(store: &S) -> (Vec<AgentId>, Vec<Vec<f32>>)
62where
63    A: SoaExtractable,
64    S: AgentStore<A>,
65{
66    let ids = store.iter_ids();
67    let n = ids.len();
68    let nc = A::num_columns();
69    let mut columns: Vec<Vec<f32>> = (0..nc).map(|_| Vec::with_capacity(n)).collect();
70
71    for &id in &ids {
72        if let Some(agent) = store.get(id) {
73            agent.extract_row(&mut columns);
74        }
75    }
76
77    (ids, columns)
78}
79
80/// Write SoA buffers back into an `AgentStore`.
81///
82/// `ids` and `columns` must have the same row count and match the
83/// order returned by `extract_soa`.
84pub fn write_back_soa<A, S>(store: &S, ids: &[AgentId], columns: &[Vec<f32>])
85where
86    A: SoaExtractable,
87    S: AgentStore<A>,
88{
89    let col_refs: Vec<&[f32]> = columns.iter().map(|c| c.as_slice()).collect();
90    for (row, &id) in ids.iter().enumerate() {
91        if let Some(mut agent) = store.get_mut(id) {
92            agent.write_back_row(&col_refs, row);
93        }
94    }
95}
96
97// -----------------------------------------------------------------------
98// f64 path
99// -----------------------------------------------------------------------
100
101/// Like [`SoaExtractable`] but using `f64` columns.
102///
103/// Implement this trait when your kernel needs double precision — e.g.
104/// scientific workloads, long-horizon integrators, or ill-conditioned
105/// dynamics. `f32` remains the default for parity with the CUDA batch path
106/// and most ABM workloads.
107///
108/// An agent type may implement both [`SoaExtractable`] and
109/// [`SoaExtractableF64`] independently; the two extraction paths do not
110/// interact. Use [`cast_columns_f64_to_f32`] as a convenience when you want
111/// to downcast an `f64` extraction result to `f32` for an existing `f32`
112/// kernel.
113pub trait SoaExtractableF64: Agent {
114    /// Number of `f64` columns to extract.
115    fn num_columns() -> usize;
116
117    /// Human-readable names for each column.
118    fn column_names() -> Vec<&'static str>;
119
120    /// Push this agent's values into the column vectors.
121    fn extract_row(&self, columns: &mut [Vec<f64>]);
122
123    /// Read this agent's values back from the column slices at `row`.
124    fn write_back_row(&mut self, columns: &[&[f64]], row: usize);
125}
126
127/// Extract `f64` SoA buffers from an `AgentStore`.
128pub fn extract_soa_f64<A, S>(store: &S) -> (Vec<AgentId>, Vec<Vec<f64>>)
129where
130    A: SoaExtractableF64,
131    S: AgentStore<A>,
132{
133    let ids = store.iter_ids();
134    let n = ids.len();
135    let nc = <A as SoaExtractableF64>::num_columns();
136    let mut columns: Vec<Vec<f64>> = (0..nc).map(|_| Vec::with_capacity(n)).collect();
137
138    for &id in &ids {
139        if let Some(agent) = store.get(id) {
140            <A as SoaExtractableF64>::extract_row(&agent, &mut columns);
141        }
142    }
143
144    (ids, columns)
145}
146
147/// Write `f64` SoA buffers back into an `AgentStore`.
148pub fn write_back_soa_f64<A, S>(store: &S, ids: &[AgentId], columns: &[Vec<f64>])
149where
150    A: SoaExtractableF64,
151    S: AgentStore<A>,
152{
153    let col_refs: Vec<&[f64]> = columns.iter().map(|c| c.as_slice()).collect();
154    for (row, &id) in ids.iter().enumerate() {
155        if let Some(mut agent) = store.get_mut(id) {
156            <A as SoaExtractableF64>::write_back_row(&mut agent, &col_refs, row);
157        }
158    }
159}
160
/// Narrow every `f64` column to an `f32` column (convenience helper).
///
/// Handy when the data was extracted through the `f64` path but must feed an
/// existing `f32` kernel; calling this makes the precision loss explicit at
/// the call site. Each value goes through an `as` cast, which rounds to the
/// nearest `f32`; finite values beyond the `f32` range become ±infinity.
/// Column count and per-column length are preserved.
pub fn cast_columns_f64_to_f32(columns: &[Vec<f64>]) -> Vec<Vec<f32>> {
    let mut narrowed: Vec<Vec<f32>> = Vec::with_capacity(columns.len());
    for column in columns {
        let converted: Vec<f32> = column.iter().copied().map(|value| value as f32).collect();
        narrowed.push(converted);
    }
    narrowed
}
171
/// Widen every `f32` column to an `f64` column (convenience helper).
///
/// The conversion is lossless: every `f32` value is exactly representable as
/// an `f64`. Column count and per-column length are preserved.
pub fn cast_columns_f32_to_f64(columns: &[Vec<f32>]) -> Vec<Vec<f64>> {
    let mut widened: Vec<Vec<f64>> = Vec::with_capacity(columns.len());
    for column in columns {
        let converted: Vec<f64> = column.iter().map(|&value| f64::from(value)).collect();
        widened.push(converted);
    }
    widened
}
179
#[cfg(test)]
mod tests_f64 {
    use super::*;
    use crate::store::{AgentStore, HashMapStore};

    /// Minimal two-field agent used to exercise the `f64` SoA path.
    #[derive(Clone, Debug)]
    struct Body {
        id: AgentId,
        x: f64,
        vx: f64,
    }

    impl Agent for Body {
        fn id(&self) -> AgentId {
            self.id
        }
    }

    impl SoaExtractableF64 for Body {
        fn num_columns() -> usize {
            2
        }

        fn column_names() -> Vec<&'static str> {
            vec!["x", "vx"]
        }

        fn extract_row(&self, columns: &mut [Vec<f64>]) {
            columns[0].push(self.x);
            columns[1].push(self.vx);
        }

        fn write_back_row(&mut self, columns: &[&[f64]], row: usize) {
            self.x = columns[0][row];
            self.vx = columns[1][row];
        }
    }

    #[test]
    fn extract_and_write_back_preserve_f64_precision() {
        // 1 + 1e-10 is not representable in f32, so it survives only an f64 path.
        let fine = 1.0_f64 + 1.0e-10_f64;
        let mut store: HashMapStore<Body> = HashMapStore::new();
        store.insert(Body {
            id: 1,
            x: fine,
            vx: 2.0,
        });

        let (ids, mut columns) = extract_soa_f64::<Body, _>(&store);
        assert_eq!(ids.len(), 1);
        assert_eq!(columns[0][0], fine);

        let doubled = fine * 2.0;
        columns[0][0] = doubled;
        write_back_soa_f64::<Body, _>(&store, &ids, &columns);

        let agent = store.get(1).unwrap();
        assert_eq!(agent.x, doubled);
    }

    #[test]
    fn f32_f64_cast_round_trip_loses_low_bits() {
        let original = vec![vec![1.0_f64 + 1.0e-10_f64]];
        let narrowed = cast_columns_f64_to_f32(&original);
        let round_tripped = cast_columns_f32_to_f64(&narrowed);
        // The 1e-10 offset is far below f32 resolution near 1.0 (about 1.19e-7),
        // so it rounds away and the value comes back as exactly 1.0.
        assert_eq!(round_tripped[0][0], 1.0);
    }
}