1use std::sync::Arc;
17
18use shape_value::heap_value::MatrixData;
19
20pub const MATRIX_DATA_OFFSET: i32 = 0;
21pub const MATRIX_ROWS_OFFSET: i32 = 8;
22pub const MATRIX_COLS_OFFSET: i32 = 12;
23pub const MATRIX_TOTAL_LEN_OFFSET: i32 = 16;
24pub const MATRIX_OWNER_OFFSET: i32 = 24;
25
26#[repr(C)]
32pub struct JitMatrix {
33 pub data: *const f64,
36 pub rows: u32,
38 pub cols: u32,
40 pub total_len: u64,
42 owner: *const MatrixData,
45}
46
47impl JitMatrix {
48 pub fn from_arc(arc: &Arc<MatrixData>) -> Self {
53 let mat = arc.as_ref();
54 let data = mat.data.as_slice().as_ptr();
55 let rows = mat.rows;
56 let cols = mat.cols;
57 let total_len = mat.data.len() as u64;
58 let owner = Arc::into_raw(Arc::clone(arc));
60 Self {
61 data,
62 rows,
63 cols,
64 total_len,
65 owner,
66 }
67 }
68
69 pub fn to_arc(&self) -> Arc<MatrixData> {
74 assert!(!self.owner.is_null(), "JitMatrix has null owner");
75 let arc = unsafe { Arc::from_raw(self.owner) };
77 let cloned = Arc::clone(&arc);
78 std::mem::forget(arc);
80 cloned
81 }
82}
83
84impl Drop for JitMatrix {
85 fn drop(&mut self) {
86 if !self.owner.is_null() {
87 unsafe {
89 let _ = Arc::from_raw(self.owner);
90 }
91 }
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use shape_value::aligned_vec::AlignedVec;

    /// Builds a `rows x cols` matrix whose flat buffer holds `0.0, 1.0, 2.0, ...`.
    fn make_test_matrix(rows: u32, cols: u32) -> Arc<MatrixData> {
        let n = (rows as usize) * (cols as usize);
        let mut data = AlignedVec::with_capacity(n);
        for i in 0..n {
            data.push(i as f64);
        }
        Arc::new(MatrixData::from_flat(data, rows, cols))
    }

    #[test]
    fn test_layout() {
        assert_eq!(std::mem::offset_of!(JitMatrix, data), MATRIX_DATA_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, rows), MATRIX_ROWS_OFFSET as usize);
        assert_eq!(std::mem::offset_of!(JitMatrix, cols), MATRIX_COLS_OFFSET as usize);
        assert_eq!(
            std::mem::offset_of!(JitMatrix, total_len),
            MATRIX_TOTAL_LEN_OFFSET as usize
        );
        assert_eq!(std::mem::offset_of!(JitMatrix, owner), MATRIX_OWNER_OFFSET as usize);
        assert_eq!(std::mem::size_of::<JitMatrix>(), 32);
    }

    #[test]
    fn test_round_trip() {
        let arc = make_test_matrix(3, 4);
        let jm = JitMatrix::from_arc(&arc);
        assert_eq!(jm.rows, 3);
        assert_eq!(jm.cols, 4);
        assert_eq!(jm.total_len, 12);
        // The view must alias the original buffer, not a copy of it.
        assert_eq!(jm.data, arc.data.as_slice().as_ptr());

        // SAFETY: `jm` holds a strong reference to the matrix, so `data`
        // points to `total_len` initialized `f64`s while `jm` is alive.
        let slice = unsafe { std::slice::from_raw_parts(jm.data, jm.total_len as usize) };
        assert_eq!(slice[0], 0.0);
        assert_eq!(slice[11], 11.0);

        let recovered = jm.to_arc();
        // `to_arc` must hand back the same allocation, not a clone of the data.
        assert!(Arc::ptr_eq(&recovered, &arc));
        assert_eq!(recovered.rows, 3);
        assert_eq!(recovered.cols, 4);
        assert_eq!(recovered.data[0], 0.0);
        assert_eq!(recovered.data[11], 11.0);

        assert_eq!(arc.data[5], 5.0);
    }

    #[test]
    fn test_arc_refcount() {
        let arc = make_test_matrix(2, 2);
        assert_eq!(Arc::strong_count(&arc), 1);

        let jm = JitMatrix::from_arc(&arc);
        assert_eq!(Arc::strong_count(&arc), 2);

        let recovered = jm.to_arc();
        assert_eq!(Arc::strong_count(&arc), 3);
        drop(recovered);
        assert_eq!(Arc::strong_count(&arc), 2);

        drop(jm);
        assert_eq!(Arc::strong_count(&arc), 1);
    }

    #[test]
    fn test_keeps_matrix_alive_after_source_arc_dropped() {
        let arc = make_test_matrix(2, 3);
        let jm = JitMatrix::from_arc(&arc);
        drop(arc);

        // SAFETY: `jm` still owns a strong reference, so the buffer outlives
        // the caller's `arc`.
        let slice = unsafe { std::slice::from_raw_parts(jm.data, jm.total_len as usize) };
        assert_eq!(slice[5], 5.0);

        let recovered = jm.to_arc();
        assert_eq!(Arc::strong_count(&recovered), 2);
        drop(jm);
        assert_eq!(Arc::strong_count(&recovered), 1);
    }
}