Skip to main content

tensorlogic_scirs_backend/
tensor_io.rs

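//! Tensor binary serialization and deserialization.
//!
//! Saves and loads `ArrayD<f64>` tensors in a simple binary format:
//! `[magic(4)] [version(1)] [ndim(4)] [shape(ndim*8)] [data(nelems*8)]`
//!
//! The magic is the ASCII bytes `TLTF` and the version is a single byte. `ndim` is a
//! `u32`, each shape entry is a `u64`, and the data section holds the elements as `f64`
//! in row-major (logical) order. All multi-byte values are little-endian.
//!
//! For multi-tensor files (which carry no magic or version of their own):
//! `[count(4)] [name_len(4)][name(bytes)][tensor]...`
//! where `count` and `name_len` are little-endian `u32`, `name` is UTF-8, and each
//! `[tensor]` is a complete single-tensor record in the format above.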
8
9use scirs2_core::ndarray::{ArrayD, IxDyn};
10use std::io::{BufReader, BufWriter, Read, Write};
11use std::path::Path;
12use thiserror::Error;
13
/// Four-byte signature ("TLTF") that opens every serialized tensor record.
const MAGIC: &[u8; 4] = &[b'T', b'L', b'T', b'F'];

/// Format revision written by [`write_tensor`]; [`read_header`] accepts only this value.
const VERSION: u8 = 0x01;
19
20/// Errors that can occur during tensor I/O operations.
21#[derive(Debug, Error)]
22pub enum TensorIoError {
23    /// An underlying I/O error occurred.
24    #[error("IO error: {0}")]
25    Io(#[from] std::io::Error),
26
27    /// The file does not start with the expected magic bytes.
28    #[error("Invalid magic bytes")]
29    InvalidMagic,
30
31    /// The file version is not supported by this implementation.
32    #[error("Unsupported version: {0}")]
33    UnsupportedVersion(u8),
34
35    /// The number of elements implied by the shape does not match the data.
36    #[error("Shape mismatch: expected {expected} elements, got {got}")]
37    ShapeMismatch { expected: usize, got: usize },
38}
39
/// Header metadata for a serialized tensor.
///
/// Produced by [`read_header`] when decoding a stream and by
/// [`TensorHeader::from_tensor`] for an in-memory tensor. Derives `PartialEq`/`Eq`
/// so headers can be compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorHeader {
    /// Number of dimensions (0 for a scalar, whose shape is empty).
    pub ndim: usize,
    /// Length of each dimension, in the order reported by `ArrayD::shape`.
    pub shape: Vec<usize>,
    /// Total number of elements (product of `shape`).
    pub element_count: usize,
    /// Size of the data section in bytes (`element_count * 8`).
    pub size_bytes: usize,
}
52
53impl TensorHeader {
54    /// Create a header from an existing tensor.
55    pub fn from_tensor(tensor: &ArrayD<f64>) -> Self {
56        let shape: Vec<usize> = tensor.shape().to_vec();
57        let element_count = tensor.len();
58        Self {
59            ndim: shape.len(),
60            shape,
61            element_count,
62            size_bytes: element_count * 8,
63        }
64    }
65}
66
67/// Save a tensor to a binary file at the given path.
68pub fn save_tensor(path: &Path, tensor: &ArrayD<f64>) -> Result<(), TensorIoError> {
69    let file = std::fs::File::create(path)?;
70    let mut writer = BufWriter::new(file);
71    write_tensor(&mut writer, tensor)?;
72    writer.flush()?;
73    Ok(())
74}
75
76/// Load a tensor from a binary file at the given path.
77pub fn load_tensor(path: &Path) -> Result<ArrayD<f64>, TensorIoError> {
78    let file = std::fs::File::open(path)?;
79    let mut reader = BufReader::new(file);
80    read_tensor(&mut reader)
81}
82
83/// Write a tensor to any [`Write`] implementation.
84pub fn write_tensor<W: Write>(writer: &mut W, tensor: &ArrayD<f64>) -> Result<(), TensorIoError> {
85    // Magic
86    writer.write_all(MAGIC)?;
87    // Version
88    writer.write_all(&[VERSION])?;
89
90    let shape = tensor.shape();
91    let ndim = shape.len() as u32;
92    // ndim as little-endian u32
93    writer.write_all(&ndim.to_le_bytes())?;
94
95    // shape: each dimension as little-endian u64
96    for &dim in shape {
97        writer.write_all(&(dim as u64).to_le_bytes())?;
98    }
99
100    // Data: iterate in standard (row-major) order, write each f64 as little-endian
101    for &value in tensor.iter() {
102        writer.write_all(&value.to_le_bytes())?;
103    }
104
105    Ok(())
106}
107
108/// Read a tensor from any [`Read`] implementation.
109pub fn read_tensor<R: Read>(reader: &mut R) -> Result<ArrayD<f64>, TensorIoError> {
110    let header = read_header(reader)?;
111
112    // Read data
113    let mut data = vec![0u8; header.element_count * 8];
114    reader.read_exact(&mut data)?;
115
116    let values: Vec<f64> = data
117        .chunks_exact(8)
118        .map(|chunk| {
119            let mut bytes = [0u8; 8];
120            bytes.copy_from_slice(chunk);
121            f64::from_le_bytes(bytes)
122        })
123        .collect();
124
125    if values.len() != header.element_count {
126        return Err(TensorIoError::ShapeMismatch {
127            expected: header.element_count,
128            got: values.len(),
129        });
130    }
131
132    let tensor = ArrayD::from_shape_vec(IxDyn(&header.shape), values).map_err(|_| {
133        TensorIoError::ShapeMismatch {
134            expected: header.element_count,
135            got: 0,
136        }
137    })?;
138
139    Ok(tensor)
140}
141
142/// Read just the header from a reader without consuming the data section.
143pub fn read_header<R: Read>(reader: &mut R) -> Result<TensorHeader, TensorIoError> {
144    // Magic
145    let mut magic = [0u8; 4];
146    reader.read_exact(&mut magic)?;
147    if &magic != MAGIC {
148        return Err(TensorIoError::InvalidMagic);
149    }
150
151    // Version
152    let mut ver = [0u8; 1];
153    reader.read_exact(&mut ver)?;
154    if ver[0] != VERSION {
155        return Err(TensorIoError::UnsupportedVersion(ver[0]));
156    }
157
158    // ndim
159    let mut ndim_bytes = [0u8; 4];
160    reader.read_exact(&mut ndim_bytes)?;
161    let ndim = u32::from_le_bytes(ndim_bytes) as usize;
162
163    // shape
164    let mut shape = Vec::with_capacity(ndim);
165    for _ in 0..ndim {
166        let mut dim_bytes = [0u8; 8];
167        reader.read_exact(&mut dim_bytes)?;
168        shape.push(u64::from_le_bytes(dim_bytes) as usize);
169    }
170
171    let element_count = shape.iter().copied().product::<usize>().max(1);
172    // For 0-d tensors (scalar), element_count is 1
173    let element_count = if ndim == 0 { 1 } else { element_count };
174
175    Ok(TensorHeader {
176        ndim,
177        shape,
178        element_count,
179        size_bytes: element_count * 8,
180    })
181}
182
183/// Save multiple named tensors to a single binary file.
184///
185/// Format: `[count(4)] [name_len(4)][name(bytes)][tensor]...`
186pub fn save_tensors(path: &Path, tensors: &[(&str, &ArrayD<f64>)]) -> Result<(), TensorIoError> {
187    let file = std::fs::File::create(path)?;
188    let mut writer = BufWriter::new(file);
189
190    let count = tensors.len() as u32;
191    writer.write_all(&count.to_le_bytes())?;
192
193    for &(name, tensor) in tensors {
194        let name_bytes = name.as_bytes();
195        let name_len = name_bytes.len() as u32;
196        writer.write_all(&name_len.to_le_bytes())?;
197        writer.write_all(name_bytes)?;
198        write_tensor(&mut writer, tensor)?;
199    }
200
201    writer.flush()?;
202    Ok(())
203}
204
205/// Load all named tensors from a multi-tensor binary file.
206pub fn load_tensors(path: &Path) -> Result<Vec<(String, ArrayD<f64>)>, TensorIoError> {
207    let file = std::fs::File::open(path)?;
208    let mut reader = BufReader::new(file);
209
210    let mut count_bytes = [0u8; 4];
211    reader.read_exact(&mut count_bytes)?;
212    let count = u32::from_le_bytes(count_bytes) as usize;
213
214    let mut result = Vec::with_capacity(count);
215    for _ in 0..count {
216        // Read name
217        let mut name_len_bytes = [0u8; 4];
218        reader.read_exact(&mut name_len_bytes)?;
219        let name_len = u32::from_le_bytes(name_len_bytes) as usize;
220
221        let mut name_bytes = vec![0u8; name_len];
222        reader.read_exact(&mut name_bytes)?;
223        let name = String::from_utf8(name_bytes)
224            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
225
226        let tensor = read_tensor(&mut reader)?;
227        result.push((name, tensor));
228    }
229
230    Ok(result)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr0, Array, Array1, Array2};
    use std::io::Cursor;

    /// A uniquely named file in the system temp directory that is deleted on drop,
    /// so it is cleaned up even when an assertion fails partway through a test.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        /// Reserve a path unique to this test name and process.
        fn new(name: &str) -> Self {
            TempFile(
                std::env::temp_dir()
                    .join(format!("tensorlogic_test_{name}_{}", std::process::id())),
            )
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the test failed before saving.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Save `tensor` to a fresh temp file, load it back, and return the loaded tensor.
    fn roundtrip_file(name: &str, tensor: &ArrayD<f64>) -> ArrayD<f64> {
        let tmp = TempFile::new(name);
        save_tensor(tmp.path(), tensor).expect("save failed");
        load_tensor(tmp.path()).expect("load failed")
    }

    #[test]
    fn test_header_from_tensor() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3, 4]), (0..24).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let header = TensorHeader::from_tensor(&tensor);
        assert_eq!(header.ndim, 3);
        assert_eq!(header.shape, vec![2, 3, 4]);
        assert_eq!(header.element_count, 24);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("failed to create tensor");
        assert_eq!(roundtrip_file("roundtrip.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_scalar() {
        let tensor = arr0(42.5).into_dyn();
        assert_eq!(roundtrip_file("scalar.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_1d() {
        let tensor = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]).into_dyn();
        assert_eq!(roundtrip_file("1d.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_2d() {
        let tensor = Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f64).collect())
            .expect("failed to create tensor")
            .into_dyn();
        assert_eq!(roundtrip_file("2d.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_3d() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3, 4]), (0..24).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        assert_eq!(roundtrip_file("3d.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_large() {
        let data: Vec<f64> = (0..10_000).map(|x| x as f64 * 0.001).collect();
        let tensor =
            Array::from_shape_vec(IxDyn(&[100, 100]), data).expect("failed to create tensor");
        assert_eq!(roundtrip_file("large.tltf", &tensor), tensor);
    }

    #[test]
    fn test_write_read_in_memory() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create tensor");
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        let mut cursor = Cursor::new(&buf);
        let loaded = read_tensor(&mut cursor).expect("read failed");
        assert_eq!(tensor, loaded);
    }

    #[test]
    fn test_read_invalid_magic() {
        let data = b"BADMxxxxxxxx";
        let mut cursor = Cursor::new(data.as_slice());
        match read_tensor(&mut cursor) {
            Err(TensorIoError::InvalidMagic) => {}
            other => panic!("Expected InvalidMagic, got {other:?}"),
        }
    }

    #[test]
    fn test_read_unsupported_version() {
        let tensor = arr0(1.0).into_dyn();
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        // The version byte immediately follows the 4-byte magic.
        buf[4] = VERSION + 1;
        let mut cursor = Cursor::new(&buf);
        match read_tensor(&mut cursor) {
            Err(TensorIoError::UnsupportedVersion(v)) => assert_eq!(v, VERSION + 1),
            other => panic!("Expected UnsupportedVersion, got {other:?}"),
        }
    }

    #[test]
    fn test_read_truncated_data() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create tensor");
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        // Drop the last byte of the data section.
        buf.pop();
        let mut cursor = Cursor::new(&buf);
        match read_tensor(&mut cursor) {
            Err(TensorIoError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof),
            other => panic!("Expected UnexpectedEof, got {other:?}"),
        }
    }

    #[test]
    fn test_read_header_only() {
        let tensor = Array::from_shape_vec(IxDyn(&[3, 5]), (0..15).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        let mut cursor = Cursor::new(&buf);
        let header = read_header(&mut cursor).expect("header read failed");
        assert_eq!(header.ndim, 2);
        assert_eq!(header.shape, vec![3, 5]);
        assert_eq!(header.element_count, 15);
    }

    #[test]
    fn test_save_load_tensors_multi() {
        let t1 = Array1::from(vec![1.0, 2.0, 3.0]).into_dyn();
        let t2 = Array2::from_shape_vec((2, 2), vec![4.0, 5.0, 6.0, 7.0])
            .expect("failed to create tensor")
            .into_dyn();
        let t3 = arr0(99.0).into_dyn();

        let tmp = TempFile::new("multi.tltf");
        save_tensors(tmp.path(), &[("alpha", &t1), ("beta", &t2), ("gamma", &t3)])
            .expect("save failed");
        let loaded = load_tensors(tmp.path()).expect("load failed");
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].0, "alpha");
        assert_eq!(loaded[0].1, t1);
        assert_eq!(loaded[1].0, "beta");
        assert_eq!(loaded[1].1, t2);
        assert_eq!(loaded[2].0, "gamma");
        assert_eq!(loaded[2].1, t3);
    }

    #[test]
    fn test_save_load_tensors_empty_list() {
        let tmp = TempFile::new("empty_multi.tltf");
        save_tensors(tmp.path(), &[]).expect("save failed");
        let loaded = load_tensors(tmp.path()).expect("load failed");
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_save_load_tensors_names_preserved() {
        let t = Array1::from(vec![1.0]).into_dyn();
        let names = ["weights", "bias", "running_mean"];
        let tensors: Vec<(&str, &ArrayD<f64>)> = names.iter().map(|n| (*n, &t)).collect();
        let tmp = TempFile::new("names.tltf");
        save_tensors(tmp.path(), &tensors).expect("save failed");
        let loaded = load_tensors(tmp.path()).expect("load failed");
        let loaded_names: Vec<&str> = loaded.iter().map(|(n, _)| n.as_str()).collect();
        assert_eq!(loaded_names, names.to_vec());
    }

    #[test]
    fn test_tensor_io_error_display() {
        let e1 = TensorIoError::InvalidMagic;
        assert!(!format!("{e1}").is_empty());

        let e2 = TensorIoError::UnsupportedVersion(99);
        assert!(format!("{e2}").contains("99"));

        let e3 = TensorIoError::ShapeMismatch {
            expected: 10,
            got: 5,
        };
        let msg = format!("{e3}");
        assert!(msg.contains("10"));
        assert!(msg.contains("5"));
    }

    #[test]
    fn test_header_size_bytes() {
        let tensor = Array::from_shape_vec(IxDyn(&[4, 5]), (0..20).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let header = TensorHeader::from_tensor(&tensor);
        assert_eq!(header.size_bytes, header.element_count * 8);
        assert_eq!(header.size_bytes, 160);
    }

    #[test]
    fn test_save_load_negative_values() {
        let tensor = Array::from_shape_vec(IxDyn(&[4]), vec![-1.0, -100.5, -0.0, -f64::MAX])
            .expect("failed to create tensor");
        assert_eq!(roundtrip_file("negative.tltf", &tensor), tensor);
    }

    #[test]
    fn test_save_load_special_values() {
        let tensor = Array::from_shape_vec(
            IxDyn(&[4]),
            vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0],
        )
        .expect("failed to create tensor");
        let loaded = roundtrip_file("special.tltf", &tensor);
        // `zip` would stop at the shorter side, so pin the shape before comparing.
        assert_eq!(loaded.shape(), tensor.shape());
        // NaN != NaN, so compare bitwise
        for (orig, load) in tensor.iter().zip(loaded.iter()) {
            assert_eq!(orig.to_bits(), load.to_bits());
        }
    }

    #[test]
    fn test_save_nonexistent_dir() {
        let path = std::path::PathBuf::from("/nonexistent_dir_xyz/tensor.tltf");
        let tensor = arr0(1.0).into_dyn();
        let result = save_tensor(&path, &tensor);
        assert!(result.is_err());
    }
}