Skip to main content

scry_learn/
matrix.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Contiguous column-major dense matrix for ML workloads.
//!
//! [`DenseMatrix`] stores all feature data in a single `Vec<f64>` with
//! column-major layout (`data[col * n_rows + row]`), giving cache-friendly
//! column access and eliminating per-column heap allocations.
7
8use crate::error::{Result, ScryLearnError};
9
/// Dense `f64` matrix backed by one contiguous buffer in column-major order.
///
/// Element `(row, col)` lives at `data[col * n_rows + row]`.
///
/// Compared with the `Vec<Vec<f64>>` representation it replaces, this gives:
/// - column views via [`col`](Self::col) without copying,
/// - one heap allocation instead of one per column plus the outer `Vec`,
/// - sequential memory access for algorithms that sweep one column at a time.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct DenseMatrix {
    /// Every element, stored column after column.
    data: Vec<f64>,
    /// Row count (samples).
    n_rows: usize,
    /// Column count (features).
    n_cols: usize,
}
28
29impl DenseMatrix {
30    /// Create a matrix from a flat column-major buffer.
31    ///
32    /// Returns an error if `data.len() != n_rows * n_cols`.
33    pub fn new(data: Vec<f64>, n_rows: usize, n_cols: usize) -> Result<Self> {
34        if data.len() != n_rows * n_cols {
35            return Err(ScryLearnError::InvalidParameter(format!(
36                "DenseMatrix::new: data.len()={} but n_rows*n_cols={}",
37                data.len(),
38                n_rows * n_cols,
39            )));
40        }
41        Ok(Self {
42            data,
43            n_rows,
44            n_cols,
45        })
46    }
47
48    /// Create a matrix of all zeros.
49    pub fn zeros(n_rows: usize, n_cols: usize) -> Self {
50        Self {
51            data: vec![0.0; n_rows * n_cols],
52            n_rows,
53            n_cols,
54        }
55    }
56
57    /// Build from column-major `Vec<Vec<f64>>` (each inner vec is one column).
58    ///
59    /// Returns an error if columns have different lengths.
60    pub fn from_col_major(cols: Vec<Vec<f64>>) -> Result<Self> {
61        if cols.is_empty() {
62            return Ok(Self {
63                data: Vec::new(),
64                n_rows: 0,
65                n_cols: 0,
66            });
67        }
68        let n_rows = cols[0].len();
69        let n_cols = cols.len();
70        for (i, col) in cols.iter().enumerate() {
71            if col.len() != n_rows {
72                return Err(ScryLearnError::InvalidParameter(format!(
73                    "DenseMatrix::from_col_major: column {i} has {} rows, expected {n_rows}",
74                    col.len(),
75                )));
76            }
77        }
78        let mut data = Vec::with_capacity(n_rows * n_cols);
79        for col in &cols {
80            data.extend_from_slice(col);
81        }
82        Ok(Self {
83            data,
84            n_rows,
85            n_cols,
86        })
87    }
88
89    /// Build from row-major data, transposing into column-major storage.
90    pub fn from_row_major(rows: &[&[f64]], n_rows: usize, n_cols: usize) -> Self {
91        let mut data = vec![0.0; n_rows * n_cols];
92        for (i, row) in rows.iter().enumerate() {
93            for (j, &val) in row.iter().enumerate() {
94                data[j * n_rows + i] = val;
95            }
96        }
97        Self {
98            data,
99            n_rows,
100            n_cols,
101        }
102    }
103
104    /// Zero-cost slice of column `j`.
105    #[inline]
106    pub fn col(&self, j: usize) -> &[f64] {
107        let start = j * self.n_rows;
108        &self.data[start..start + self.n_rows]
109    }
110
111    /// Mutable slice of column `j`.
112    #[inline]
113    pub fn col_mut(&mut self, j: usize) -> &mut [f64] {
114        let start = j * self.n_rows;
115        &mut self.data[start..start + self.n_rows]
116    }
117
118    /// Get a single element.
119    #[inline]
120    pub fn get(&self, row: usize, col: usize) -> f64 {
121        self.data[col * self.n_rows + row]
122    }
123
124    /// Set a single element.
125    #[inline]
126    pub fn set(&mut self, row: usize, col: usize, val: f64) {
127        self.data[col * self.n_rows + row] = val;
128    }
129
130    /// Number of rows.
131    #[inline]
132    pub fn n_rows(&self) -> usize {
133        self.n_rows
134    }
135
136    /// Number of columns.
137    #[inline]
138    pub fn n_cols(&self) -> usize {
139        self.n_cols
140    }
141
142    /// The raw flat buffer in column-major order.
143    #[inline]
144    pub fn as_slice(&self) -> &[f64] {
145        &self.data
146    }
147
148    /// Iterate over values in row `i` (strided access across columns).
149    pub fn row_iter(&self, i: usize) -> impl Iterator<Item = f64> + '_ {
150        (0..self.n_cols).map(move |j| self.data[j * self.n_rows + i])
151    }
152
153    /// Collect row `i` into a `Vec<f64>`.
154    pub fn row_to_vec(&self, i: usize) -> Vec<f64> {
155        self.row_iter(i).collect()
156    }
157
158    /// Build from a reference to column-major `&[Vec<f64>]` (no ownership transfer).
159    ///
160    /// Same as [`from_col_major`](Self::from_col_major) but borrows the columns
161    /// instead of consuming them, avoiding a clone of the outer `Vec`.
162    pub fn from_col_major_ref(cols: &[Vec<f64>]) -> Result<Self> {
163        if cols.is_empty() {
164            return Ok(Self {
165                data: Vec::new(),
166                n_rows: 0,
167                n_cols: 0,
168            });
169        }
170        let n_rows = cols[0].len();
171        let n_cols = cols.len();
172        for (i, col) in cols.iter().enumerate() {
173            if col.len() != n_rows {
174                return Err(ScryLearnError::InvalidParameter(format!(
175                    "DenseMatrix::from_col_major_ref: column {i} has {} rows, expected {n_rows}",
176                    col.len(),
177                )));
178            }
179        }
180        let mut data = Vec::with_capacity(n_rows * n_cols);
181        for col in cols {
182            data.extend_from_slice(col);
183        }
184        Ok(Self {
185            data,
186            n_rows,
187            n_cols,
188        })
189    }
190
191    /// Convert back to `Vec<Vec<f64>>` column-major (backward compat).
192    pub fn to_col_vecs(&self) -> Vec<Vec<f64>> {
193        (0..self.n_cols).map(|j| self.col(j).to_vec()).collect()
194    }
195}
196
// ---------------------------------------------------------------------------
// Conversions
// ---------------------------------------------------------------------------
200
201impl From<Vec<Vec<f64>>> for DenseMatrix {
202    /// Convert from column-major `Vec<Vec<f64>>`. Panics on ragged input.
203    fn from(cols: Vec<Vec<f64>>) -> Self {
204        Self::from_col_major(cols).expect("ragged column vectors in DenseMatrix::from")
205    }
206}
207
208impl From<&[Vec<f64>]> for DenseMatrix {
209    fn from(cols: &[Vec<f64>]) -> Self {
210        let owned: Vec<Vec<f64>> = cols.to_vec();
211        Self::from(owned)
212    }
213}
214
// ---------------------------------------------------------------------------
// Serde support
// ---------------------------------------------------------------------------
218
#[cfg(feature = "serde")]
impl serde::Serialize for DenseMatrix {
    /// Emit the matrix as a three-field struct: the flat column-major `data`
    /// buffer, then `n_rows`, then `n_cols`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        let mut out = serializer.serialize_struct("DenseMatrix", 3)?;
        out.serialize_field("data", &self.data)?;
        out.serialize_field("n_rows", &self.n_rows)?;
        out.serialize_field("n_cols", &self.n_cols)?;
        out.end()
    }
}
233
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DenseMatrix {
    /// Read the `data`, `n_rows` and `n_cols` fields and pass them through
    /// [`DenseMatrix::new`]. A buffer whose length disagrees with the declared
    /// shape therefore surfaces as a deserialization error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Unvalidated mirror of the serialized fields. Kept named `Raw`
        // because serde's derived error messages may include the type name.
        #[derive(serde::Deserialize)]
        struct Raw {
            data: Vec<f64>,
            n_rows: usize,
            n_cols: usize,
        }

        let Raw {
            data,
            n_rows,
            n_cols,
        } = Raw::deserialize(deserializer)?;
        DenseMatrix::new(data, n_rows, n_cols).map_err(<D::Error as serde::de::Error>::custom)
    }
}
249
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
253
#[cfg(test)]
// Exact float comparison is intended: every value checked below is stored and
// read back verbatim (never computed), so bitwise equality is the right test.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn from_col_major_roundtrip() {
        let cols = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let m = DenseMatrix::from_col_major(cols.clone()).unwrap();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 2);
        assert_eq!(m.to_col_vecs(), cols);
    }

    #[test]
    fn col_correctness() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(m.col(0), &[1.0, 2.0]);
        assert_eq!(m.col(1), &[3.0, 4.0]);
        assert_eq!(m.col(2), &[5.0, 6.0]);
    }

    #[test]
    fn row_iter_correctness() {
        let m =
            DenseMatrix::from_col_major(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]).unwrap();
        let row0: Vec<f64> = m.row_iter(0).collect();
        assert_eq!(row0, vec![1.0, 4.0]);
        let row2: Vec<f64> = m.row_iter(2).collect();
        assert_eq!(row2, vec![3.0, 6.0]);
    }

    #[test]
    fn get_set_indexing() {
        let mut m = DenseMatrix::zeros(3, 2);
        m.set(1, 0, 42.0);
        m.set(2, 1, 99.0);
        assert_eq!(m.get(1, 0), 42.0);
        assert_eq!(m.get(2, 1), 99.0);
        assert_eq!(m.get(0, 0), 0.0);
    }

    #[test]
    fn from_vec_vec_conversion() {
        let cols = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let m: DenseMatrix = cols.into();
        assert_eq!(m.n_rows(), 2);
        assert_eq!(m.n_cols(), 2);
        assert_eq!(m.get(0, 0), 10.0);
        assert_eq!(m.get(1, 1), 40.0);
    }

    #[test]
    fn from_slice_conversion() {
        let cols = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let m: DenseMatrix = cols.as_slice().into();
        assert_eq!(m.col(0), &[1.0, 2.0]);
    }

    #[test]
    fn empty_matrix() {
        let m = DenseMatrix::from_col_major(vec![]).unwrap();
        assert_eq!(m.n_rows(), 0);
        assert_eq!(m.n_cols(), 0);
        assert_eq!(m.as_slice(), &[] as &[f64]);
    }

    #[test]
    fn zero_row_matrix() {
        let m = DenseMatrix::from_col_major(vec![vec![], vec![]]).unwrap();
        assert_eq!(m.n_rows(), 0);
        assert_eq!(m.n_cols(), 2);
    }

    #[test]
    fn single_column() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(m.n_cols(), 1);
        assert_eq!(m.col(0), &[1.0, 2.0, 3.0]);
        assert_eq!(m.row_to_vec(1), vec![2.0]);
    }

    #[test]
    fn ragged_error() {
        let result = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn new_validates_length() {
        assert!(DenseMatrix::new(vec![1.0, 2.0, 3.0], 2, 2).is_err());
        assert!(DenseMatrix::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2).is_ok());
    }

    #[test]
    fn from_row_major_transposes() {
        let rows: Vec<&[f64]> = vec![&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]];
        let m = DenseMatrix::from_row_major(&rows, 3, 2);
        // Column 0 should be [1, 3, 5], column 1 should be [2, 4, 6]
        assert_eq!(m.col(0), &[1.0, 3.0, 5.0]);
        assert_eq!(m.col(1), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn col_mut_works() {
        let mut m = DenseMatrix::zeros(3, 2);
        let col = m.col_mut(1);
        col[0] = 10.0;
        col[1] = 20.0;
        col[2] = 30.0;
        assert_eq!(m.col(1), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn from_col_major_ref_matches_owned() {
        let cols = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let borrowed = DenseMatrix::from_col_major_ref(&cols).unwrap();
        let owned = DenseMatrix::from_col_major(cols).unwrap();
        assert_eq!(borrowed.as_slice(), owned.as_slice());
        assert_eq!(borrowed.n_rows(), 3);
        assert_eq!(borrowed.n_cols(), 2);
    }

    #[test]
    fn from_col_major_ref_ragged_error() {
        let cols = vec![vec![1.0], vec![2.0, 3.0]];
        assert!(DenseMatrix::from_col_major_ref(&cols).is_err());
    }

    #[test]
    fn new_uses_column_major_layout() {
        let m = DenseMatrix::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
        assert_eq!(m.get(1, 0), 2.0);
        assert_eq!(m.get(0, 2), 5.0);
        assert_eq!(m.col(1), &[3.0, 4.0]);
        assert_eq!(m.row_to_vec(1), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    #[should_panic]
    fn from_vec_vec_panics_on_ragged() {
        let _: DenseMatrix = vec![vec![1.0, 2.0], vec![3.0]].into();
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let json = serde_json::to_string(&m).unwrap();
        let m2: DenseMatrix = serde_json::from_str(&json).unwrap();
        assert_eq!(m.as_slice(), m2.as_slice());
        assert_eq!(m.n_rows(), m2.n_rows());
        assert_eq!(m.n_cols(), m2.n_cols());
    }
}
378}