tensorlogic_scirs_backend/
tensor_io.rs1use scirs2_core::ndarray::{ArrayD, IxDyn};
10use std::io::{BufReader, BufWriter, Read, Write};
11use std::path::Path;
12use thiserror::Error;
13
/// Four-byte signature that begins every serialized tensor record ("TLTF").
const MAGIC: &[u8; 4] = &[b'T', b'L', b'T', b'F'];

/// Format revision byte written right after [`MAGIC`]; readers reject any other value.
const VERSION: u8 = 0x01;
19
20#[derive(Debug, Error)]
22pub enum TensorIoError {
23 #[error("IO error: {0}")]
25 Io(#[from] std::io::Error),
26
27 #[error("Invalid magic bytes")]
29 InvalidMagic,
30
31 #[error("Unsupported version: {0}")]
33 UnsupportedVersion(u8),
34
35 #[error("Shape mismatch: expected {expected} elements, got {got}")]
37 ShapeMismatch { expected: usize, got: usize },
38}
39
/// Metadata describing one serialized tensor: its rank, shape and payload size.
///
/// Headers compare by value (`PartialEq`/`Eq`), so a header read from disk can
/// be checked directly against [`TensorHeader::from_tensor`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorHeader {
    /// Number of dimensions (rank); `0` for a scalar.
    pub ndim: usize,
    /// Length of each axis, in axis order (axis 0 first).
    pub shape: Vec<usize>,
    /// Total number of `f64` elements in the payload.
    pub element_count: usize,
    /// Payload size in bytes (`element_count * 8`).
    pub size_bytes: usize,
}
52
53impl TensorHeader {
54 pub fn from_tensor(tensor: &ArrayD<f64>) -> Self {
56 let shape: Vec<usize> = tensor.shape().to_vec();
57 let element_count = tensor.len();
58 Self {
59 ndim: shape.len(),
60 shape,
61 element_count,
62 size_bytes: element_count * 8,
63 }
64 }
65}
66
67pub fn save_tensor(path: &Path, tensor: &ArrayD<f64>) -> Result<(), TensorIoError> {
69 let file = std::fs::File::create(path)?;
70 let mut writer = BufWriter::new(file);
71 write_tensor(&mut writer, tensor)?;
72 writer.flush()?;
73 Ok(())
74}
75
76pub fn load_tensor(path: &Path) -> Result<ArrayD<f64>, TensorIoError> {
78 let file = std::fs::File::open(path)?;
79 let mut reader = BufReader::new(file);
80 read_tensor(&mut reader)
81}
82
83pub fn write_tensor<W: Write>(writer: &mut W, tensor: &ArrayD<f64>) -> Result<(), TensorIoError> {
85 writer.write_all(MAGIC)?;
87 writer.write_all(&[VERSION])?;
89
90 let shape = tensor.shape();
91 let ndim = shape.len() as u32;
92 writer.write_all(&ndim.to_le_bytes())?;
94
95 for &dim in shape {
97 writer.write_all(&(dim as u64).to_le_bytes())?;
98 }
99
100 for &value in tensor.iter() {
102 writer.write_all(&value.to_le_bytes())?;
103 }
104
105 Ok(())
106}
107
108pub fn read_tensor<R: Read>(reader: &mut R) -> Result<ArrayD<f64>, TensorIoError> {
110 let header = read_header(reader)?;
111
112 let mut data = vec![0u8; header.element_count * 8];
114 reader.read_exact(&mut data)?;
115
116 let values: Vec<f64> = data
117 .chunks_exact(8)
118 .map(|chunk| {
119 let mut bytes = [0u8; 8];
120 bytes.copy_from_slice(chunk);
121 f64::from_le_bytes(bytes)
122 })
123 .collect();
124
125 if values.len() != header.element_count {
126 return Err(TensorIoError::ShapeMismatch {
127 expected: header.element_count,
128 got: values.len(),
129 });
130 }
131
132 let tensor = ArrayD::from_shape_vec(IxDyn(&header.shape), values).map_err(|_| {
133 TensorIoError::ShapeMismatch {
134 expected: header.element_count,
135 got: 0,
136 }
137 })?;
138
139 Ok(tensor)
140}
141
142pub fn read_header<R: Read>(reader: &mut R) -> Result<TensorHeader, TensorIoError> {
144 let mut magic = [0u8; 4];
146 reader.read_exact(&mut magic)?;
147 if &magic != MAGIC {
148 return Err(TensorIoError::InvalidMagic);
149 }
150
151 let mut ver = [0u8; 1];
153 reader.read_exact(&mut ver)?;
154 if ver[0] != VERSION {
155 return Err(TensorIoError::UnsupportedVersion(ver[0]));
156 }
157
158 let mut ndim_bytes = [0u8; 4];
160 reader.read_exact(&mut ndim_bytes)?;
161 let ndim = u32::from_le_bytes(ndim_bytes) as usize;
162
163 let mut shape = Vec::with_capacity(ndim);
165 for _ in 0..ndim {
166 let mut dim_bytes = [0u8; 8];
167 reader.read_exact(&mut dim_bytes)?;
168 shape.push(u64::from_le_bytes(dim_bytes) as usize);
169 }
170
171 let element_count = shape.iter().copied().product::<usize>().max(1);
172 let element_count = if ndim == 0 { 1 } else { element_count };
174
175 Ok(TensorHeader {
176 ndim,
177 shape,
178 element_count,
179 size_bytes: element_count * 8,
180 })
181}
182
183pub fn save_tensors(path: &Path, tensors: &[(&str, &ArrayD<f64>)]) -> Result<(), TensorIoError> {
187 let file = std::fs::File::create(path)?;
188 let mut writer = BufWriter::new(file);
189
190 let count = tensors.len() as u32;
191 writer.write_all(&count.to_le_bytes())?;
192
193 for &(name, tensor) in tensors {
194 let name_bytes = name.as_bytes();
195 let name_len = name_bytes.len() as u32;
196 writer.write_all(&name_len.to_le_bytes())?;
197 writer.write_all(name_bytes)?;
198 write_tensor(&mut writer, tensor)?;
199 }
200
201 writer.flush()?;
202 Ok(())
203}
204
205pub fn load_tensors(path: &Path) -> Result<Vec<(String, ArrayD<f64>)>, TensorIoError> {
207 let file = std::fs::File::open(path)?;
208 let mut reader = BufReader::new(file);
209
210 let mut count_bytes = [0u8; 4];
211 reader.read_exact(&mut count_bytes)?;
212 let count = u32::from_le_bytes(count_bytes) as usize;
213
214 let mut result = Vec::with_capacity(count);
215 for _ in 0..count {
216 let mut name_len_bytes = [0u8; 4];
218 reader.read_exact(&mut name_len_bytes)?;
219 let name_len = u32::from_le_bytes(name_len_bytes) as usize;
220
221 let mut name_bytes = vec![0u8; name_len];
222 reader.read_exact(&mut name_bytes)?;
223 let name = String::from_utf8(name_bytes)
224 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
225
226 let tensor = read_tensor(&mut reader)?;
227 result.push((name, tensor));
228 }
229
230 Ok(result)
231}
232
// Unit tests for the TLTF tensor reader/writer: header construction, file and
// in-memory round-trips, multi-tensor containers, error values, and bit-exact
// preservation of special floating-point values.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr0, Array, Array1, Array2};
    use std::io::Cursor;

    // Path under the system temp dir. The process id is appended so runs from
    // different processes use different files; each test passes its own `name`.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("tensorlogic_test_{name}_{}", std::process::id()))
    }

    // `from_tensor` reports rank, shape and element count of a 3-D tensor.
    #[test]
    fn test_header_from_tensor() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3, 4]), (0..24).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let header = TensorHeader::from_tensor(&tensor);
        assert_eq!(header.ndim, 3);
        assert_eq!(header.shape, vec![2, 3, 4]);
        assert_eq!(header.element_count, 24);
    }

    // Basic 2-D save/load through the filesystem.
    #[test]
    fn test_save_load_roundtrip() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("failed to create tensor");
        let path = temp_path("roundtrip.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    // Rank-0 tensors (ndim == 0) carry exactly one element.
    #[test]
    fn test_save_load_scalar() {
        let tensor = arr0(42.5).into_dyn();
        let path = temp_path("scalar.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_save_load_1d() {
        let tensor = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]).into_dyn();
        let path = temp_path("1d.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_save_load_2d() {
        let tensor = Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f64).collect())
            .expect("failed to create tensor")
            .into_dyn();
        let path = temp_path("2d.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_save_load_3d() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 3, 4]), (0..24).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let path = temp_path("3d.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    // 10,000 elements: exercises a payload larger than a single small buffer.
    #[test]
    fn test_save_load_large() {
        let data: Vec<f64> = (0..10_000).map(|x| x as f64 * 0.001).collect();
        let tensor =
            Array::from_shape_vec(IxDyn(&[100, 100]), data).expect("failed to create tensor");
        let path = temp_path("large.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    // write_tensor/read_tensor work on any Write/Read, not just files.
    #[test]
    fn test_write_read_in_memory() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 2.0, 3.0, 4.0])
            .expect("failed to create tensor");
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        let mut cursor = Cursor::new(&buf);
        let loaded = read_tensor(&mut cursor).expect("read failed");
        assert_eq!(tensor, loaded);
    }

    // A wrong 4-byte signature must be reported as InvalidMagic, not as an I/O error.
    #[test]
    fn test_read_invalid_magic() {
        let data = b"BADMxxxxxxxx";
        let mut cursor = Cursor::new(data.as_slice());
        let result = read_tensor(&mut cursor);
        assert!(result.is_err());
        match result {
            Err(TensorIoError::InvalidMagic) => {}
            other => panic!("Expected InvalidMagic, got {other:?}"),
        }
    }

    // read_header parses only the header of a full record.
    #[test]
    fn test_read_header_only() {
        let tensor = Array::from_shape_vec(IxDyn(&[3, 5]), (0..15).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let mut buf = Vec::new();
        write_tensor(&mut buf, &tensor).expect("write failed");
        let mut cursor = Cursor::new(&buf);
        let header = read_header(&mut cursor).expect("header read failed");
        assert_eq!(header.ndim, 2);
        assert_eq!(header.shape, vec![3, 5]);
        assert_eq!(header.element_count, 15);
    }

    // Multi-tensor container: mixed ranks (1-D, 2-D, scalar) with names, in order.
    #[test]
    fn test_save_load_tensors_multi() {
        let t1 = Array1::from(vec![1.0, 2.0, 3.0]).into_dyn();
        let t2 = Array2::from_shape_vec((2, 2), vec![4.0, 5.0, 6.0, 7.0])
            .expect("failed to create tensor")
            .into_dyn();
        let t3 = arr0(99.0).into_dyn();

        let path = temp_path("multi.tltf");
        save_tensors(&path, &[("alpha", &t1), ("beta", &t2), ("gamma", &t3)]).expect("save failed");
        let loaded = load_tensors(&path).expect("load failed");
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].0, "alpha");
        assert_eq!(loaded[0].1, t1);
        assert_eq!(loaded[1].0, "beta");
        assert_eq!(loaded[1].1, t2);
        assert_eq!(loaded[2].0, "gamma");
        assert_eq!(loaded[2].1, t3);
        let _ = std::fs::remove_file(&path);
    }

    // A zero-entry container round-trips to an empty Vec.
    #[test]
    fn test_save_load_tensors_empty_list() {
        let path = temp_path("empty_multi.tltf");
        save_tensors(&path, &[]).expect("save failed");
        let loaded = load_tensors(&path).expect("load failed");
        assert!(loaded.is_empty());
        let _ = std::fs::remove_file(&path);
    }

    // Names (including underscores) are preserved in order.
    #[test]
    fn test_save_load_tensors_names_preserved() {
        let t = Array1::from(vec![1.0]).into_dyn();
        let names = ["weights", "bias", "running_mean"];
        let tensors: Vec<(&str, &ArrayD<f64>)> = names.iter().map(|n| (*n, &t)).collect();
        let path = temp_path("names.tltf");
        save_tensors(&path, &tensors).expect("save failed");
        let loaded = load_tensors(&path).expect("load failed");
        let loaded_names: Vec<&str> = loaded.iter().map(|(n, _)| n.as_str()).collect();
        assert_eq!(loaded_names, names.to_vec());
        let _ = std::fs::remove_file(&path);
    }

    // Display output of error variants includes their payload values.
    #[test]
    fn test_tensor_io_error_display() {
        let e1 = TensorIoError::InvalidMagic;
        assert!(!format!("{e1}").is_empty());

        let e2 = TensorIoError::UnsupportedVersion(99);
        assert!(format!("{e2}").contains("99"));

        let e3 = TensorIoError::ShapeMismatch {
            expected: 10,
            got: 5,
        };
        let msg = format!("{e3}");
        assert!(msg.contains("10"));
        assert!(msg.contains("5"));
    }

    // size_bytes is 8 bytes per f64 element.
    #[test]
    fn test_header_size_bytes() {
        let tensor = Array::from_shape_vec(IxDyn(&[4, 5]), (0..20).map(|x| x as f64).collect())
            .expect("failed to create tensor");
        let header = TensorHeader::from_tensor(&tensor);
        assert_eq!(header.size_bytes, header.element_count * 8);
        assert_eq!(header.size_bytes, 160);
    }

    #[test]
    fn test_save_load_negative_values() {
        let tensor = Array::from_shape_vec(IxDyn(&[4]), vec![-1.0, -100.5, -0.0, -f64::MAX])
            .expect("failed to create tensor");
        let path = temp_path("negative.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        assert_eq!(tensor, loaded);
        let _ = std::fs::remove_file(&path);
    }

    // NaN and infinities must survive bit-for-bit.
    #[test]
    fn test_save_load_special_values() {
        let tensor = Array::from_shape_vec(
            IxDyn(&[4]),
            vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0],
        )
        .expect("failed to create tensor");
        let path = temp_path("special.tltf");
        save_tensor(&path, &tensor).expect("save failed");
        let loaded = load_tensor(&path).expect("load failed");
        for (orig, load) in tensor.iter().zip(loaded.iter()) {
            // Compare raw bit patterns: NaN != NaN under `==`.
            assert_eq!(orig.to_bits(), load.to_bits());
        }
        let _ = std::fs::remove_file(&path);
    }

    // Creating a file in a missing directory surfaces as an error, not a panic.
    #[test]
    fn test_save_nonexistent_dir() {
        let path = std::path::PathBuf::from("/nonexistent_dir_xyz/tensor.tltf");
        let tensor = arr0(1.0).into_dyn();
        let result = save_tensor(&path, &tensor);
        assert!(result.is_err());
    }
}