1use super::core::{
5 compute_checksum, ComputationGraph, FileHeader, Loadable, ModelMetadata, Saveable,
6 SerializationError, SerializationResult, TensorMetadata,
7};
8use crate::tensor::Tensor;
9use num_traits::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::{File, OpenOptions};
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::{Path, PathBuf};
15
16pub fn save<P: AsRef<Path>>(obj: &dyn Saveable, path: P) -> SerializationResult<()> {
19 let path = path.as_ref();
20
21 if let Some(parent) = path.parent() {
23 std::fs::create_dir_all(parent)?;
24 }
25
26 let file = OpenOptions::new()
27 .create(true)
28 .write(true)
29 .truncate(true)
30 .open(path)?;
31 let mut writer = BufWriter::new(file);
32
33 writer.write_all(b"RUSTORCH")?;
35
36 let metadata = obj.metadata();
38 let mut header = FileHeader::new(obj.type_id().to_string(), metadata);
39
40 let object_data = obj.save_binary()?;
42 header.checksum = compute_checksum(&object_data);
43
44 let header_data =
46 bincode::serialize(&header).map_err(|e| SerializationError::FormatError(e.to_string()))?;
47 let header_size = header_data.len() as u64;
48
49 writer.write_all(&header_size.to_le_bytes())?;
50 writer.write_all(&header_data)?;
51
52 writer.write_all(&object_data)?;
54 writer.flush()?;
55
56 Ok(())
57}
58
59pub fn load<P: AsRef<Path>, T: Loadable>(path: P) -> SerializationResult<T> {
62 let file = File::open(path.as_ref())?;
63 let mut reader = BufReader::new(file);
64
65 let mut magic = [0u8; 8];
67 reader.read_exact(&mut magic)?;
68 if &magic != b"RUSTORCH" {
69 return Err(SerializationError::FormatError(
70 "Invalid RusTorch file format".to_string(),
71 ));
72 }
73
74 let mut header_size_bytes = [0u8; 8];
76 reader.read_exact(&mut header_size_bytes)?;
77 let header_size = u64::from_le_bytes(header_size_bytes);
78
79 let mut header_data = vec![0u8; header_size as usize];
81 reader.read_exact(&mut header_data)?;
82 let header: FileHeader = bincode::deserialize(&header_data)
83 .map_err(|e| SerializationError::FormatError(e.to_string()))?;
84
85 header.validate()?;
87
88 if header.object_type != T::expected_type_id() {
90 return Err(SerializationError::TypeMismatch {
91 expected: T::expected_type_id().to_string(),
92 found: header.object_type,
93 });
94 }
95
96 T::validate_version(&header.version)?;
98
99 let mut object_data = Vec::new();
101 reader.read_to_end(&mut object_data)?;
102
103 let computed_checksum = compute_checksum(&object_data);
105 if computed_checksum != header.checksum {
106 return Err(SerializationError::CorruptionError(
107 "Checksum mismatch".to_string(),
108 ));
109 }
110
111 T::load_binary(&object_data)
113}
114
115#[derive(Debug, Clone)]
118pub struct StateDict<T: Float> {
119 pub parameters: HashMap<String, Tensor<T>>,
120 pub buffers: HashMap<String, Tensor<T>>,
121 pub metadata: ModelMetadata,
122}
123
124impl<T: Float + 'static> StateDict<T> {
125 pub fn new() -> Self {
128 Self {
129 parameters: HashMap::new(),
130 buffers: HashMap::new(),
131 metadata: ModelMetadata {
132 model_type: "unknown".to_string(),
133 parameters: HashMap::new(),
134 buffers: HashMap::new(),
135 config: HashMap::new(),
136 training_state: false,
137 },
138 }
139 }
140
141 pub fn add_parameter(&mut self, name: String, tensor: Tensor<T>) {
144 let metadata = TensorMetadata {
145 shape: tensor.shape().to_vec(),
146 dtype: std::any::type_name::<T>().to_string(),
147 device: "cpu".to_string(), requires_grad: true,
149 data_offset: 0, data_size: tensor.numel() as u64 * std::mem::size_of::<T>() as u64,
151 };
152
153 self.metadata.parameters.insert(name.clone(), metadata);
154 self.parameters.insert(name, tensor);
155 }
156
157 pub fn add_buffer(&mut self, name: String, tensor: Tensor<T>) {
160 let metadata = TensorMetadata {
161 shape: tensor.shape().to_vec(),
162 dtype: std::any::type_name::<T>().to_string(),
163 device: "cpu".to_string(),
164 requires_grad: false,
165 data_offset: 0,
166 data_size: tensor.numel() as u64 * std::mem::size_of::<T>() as u64,
167 };
168
169 self.metadata.buffers.insert(name.clone(), metadata);
170 self.buffers.insert(name, tensor);
171 }
172
173 pub fn get_parameter(&self, name: &str) -> Option<&Tensor<T>> {
176 self.parameters.get(name)
177 }
178
179 pub fn get_buffer(&self, name: &str) -> Option<&Tensor<T>> {
182 self.buffers.get(name)
183 }
184
185 pub fn is_training(&self) -> bool {
188 self.metadata.training_state
189 }
190
191 pub fn set_training(&mut self, training: bool) {
194 self.metadata.training_state = training;
195 }
196}
197
198impl<T: Float + 'static> Saveable for StateDict<T> {
199 fn save_binary(&self) -> SerializationResult<Vec<u8>> {
200 let mut buffer = Vec::new();
201
202 let metadata_json = serde_json::to_string(&self.metadata)
204 .map_err(|e| SerializationError::FormatError(e.to_string()))?;
205 let metadata_bytes = metadata_json.as_bytes();
206 buffer.extend_from_slice(&(metadata_bytes.len() as u64).to_le_bytes());
207 buffer.extend_from_slice(metadata_bytes);
208
209 buffer.extend_from_slice(&(self.parameters.len() as u32).to_le_bytes());
211 for (name, tensor) in &self.parameters {
212 let name_bytes = name.as_bytes();
213 buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
214 buffer.extend_from_slice(name_bytes);
215
216 let tensor_data = tensor.save_binary()?;
217 buffer.extend_from_slice(&(tensor_data.len() as u64).to_le_bytes());
218 buffer.extend_from_slice(&tensor_data);
219 }
220
221 buffer.extend_from_slice(&(self.buffers.len() as u32).to_le_bytes());
223 for (name, tensor) in &self.buffers {
224 let name_bytes = name.as_bytes();
225 buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
226 buffer.extend_from_slice(name_bytes);
227
228 let tensor_data = tensor.save_binary()?;
229 buffer.extend_from_slice(&(tensor_data.len() as u64).to_le_bytes());
230 buffer.extend_from_slice(&tensor_data);
231 }
232
233 Ok(buffer)
234 }
235
236 fn type_id(&self) -> &'static str {
237 "state_dict"
238 }
239
240 fn metadata(&self) -> HashMap<String, String> {
241 let mut meta = HashMap::new();
242 meta.insert("model_type".to_string(), self.metadata.model_type.clone());
243 meta.insert(
244 "num_parameters".to_string(),
245 self.parameters.len().to_string(),
246 );
247 meta.insert("num_buffers".to_string(), self.buffers.len().to_string());
248 meta.insert(
249 "training_state".to_string(),
250 self.metadata.training_state.to_string(),
251 );
252 meta
253 }
254}
255
256impl<T: Float + 'static> Loadable for StateDict<T> {
257 fn load_binary(data: &[u8]) -> SerializationResult<Self> {
258 if data.is_empty() {
259 return Ok(Self::new());
260 }
261
262 let mut offset = 0;
263 let mut state_dict = Self::new();
264
265 if data.len() < offset + 8 {
267 return Ok(state_dict);
268 }
269 let metadata_len =
270 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
271 SerializationError::FormatError("Invalid metadata length".to_string())
272 })?) as usize;
273 offset += 8;
274
275 if data.len() < offset + metadata_len {
276 return Ok(state_dict);
277 }
278 let metadata_str =
279 std::str::from_utf8(&data[offset..offset + metadata_len]).map_err(|_| {
280 SerializationError::FormatError("Invalid metadata encoding".to_string())
281 })?;
282 state_dict.metadata = serde_json::from_str(metadata_str)
283 .map_err(|e| SerializationError::FormatError(e.to_string()))?;
284 offset += metadata_len;
285
286 if data.len() < offset + 4 {
288 return Ok(state_dict);
289 }
290 let params_count =
291 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
292 SerializationError::FormatError("Invalid parameters count".to_string())
293 })?);
294 offset += 4;
295
296 for _ in 0..params_count {
297 if data.len() < offset + 4 {
299 break;
300 }
301 let name_len =
302 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
303 SerializationError::FormatError("Invalid parameter name length".to_string())
304 })?) as usize;
305 offset += 4;
306
307 if data.len() < offset + name_len {
308 break;
309 }
310 let name =
311 String::from_utf8(data[offset..offset + name_len].to_vec()).map_err(|_| {
312 SerializationError::FormatError("Invalid parameter name encoding".to_string())
313 })?;
314 offset += name_len;
315
316 if data.len() < offset + 8 {
318 break;
319 }
320 let tensor_data_len =
321 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
322 SerializationError::FormatError("Invalid tensor data length".to_string())
323 })?) as usize;
324 offset += 8;
325
326 if data.len() < offset + tensor_data_len {
327 break;
328 }
329 let tensor_data = &data[offset..offset + tensor_data_len];
330 if let Ok(tensor) = Tensor::<T>::load_binary(tensor_data) {
331 state_dict.parameters.insert(name, tensor);
332 }
333 offset += tensor_data_len;
334 }
335
336 if data.len() < offset + 4 {
338 return Ok(state_dict);
339 }
340 let buffers_count =
341 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
342 SerializationError::FormatError("Invalid buffers count".to_string())
343 })?);
344 offset += 4;
345
346 for _ in 0..buffers_count {
347 if data.len() < offset + 4 {
349 break;
350 }
351 let name_len =
352 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
353 SerializationError::FormatError("Invalid buffer name length".to_string())
354 })?) as usize;
355 offset += 4;
356
357 if data.len() < offset + name_len {
358 break;
359 }
360 let name =
361 String::from_utf8(data[offset..offset + name_len].to_vec()).map_err(|_| {
362 SerializationError::FormatError("Invalid buffer name encoding".to_string())
363 })?;
364 offset += name_len;
365
366 if data.len() < offset + 8 {
368 break;
369 }
370 let tensor_data_len =
371 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
372 SerializationError::FormatError("Invalid tensor data length".to_string())
373 })?) as usize;
374 offset += 8;
375
376 if data.len() < offset + tensor_data_len {
377 break;
378 }
379 let tensor_data = &data[offset..offset + tensor_data_len];
380 if let Ok(tensor) = Tensor::<T>::load_binary(tensor_data) {
381 state_dict.buffers.insert(name, tensor);
382 }
383 offset += tensor_data_len;
384 }
385
386 Ok(state_dict)
387 }
388
389 fn expected_type_id() -> &'static str {
390 "state_dict"
391 }
392}
393
394#[derive(Debug, Clone)]
397pub struct SafeTensorFormat<T: Float> {
398 pub tensors: HashMap<String, Tensor<T>>,
399 pub metadata: HashMap<String, String>,
400}
401
402impl<T: Float + 'static> SafeTensorFormat<T> {
403 pub fn new() -> Self {
406 Self {
407 tensors: HashMap::new(),
408 metadata: HashMap::new(),
409 }
410 }
411
412 pub fn add_tensor(&mut self, name: String, tensor: Tensor<T>) {
415 self.tensors.insert(name, tensor);
416 }
417
418 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()> {
421 let mut header_data = HashMap::new();
423
424 for (name, tensor) in &self.tensors {
425 let shape: Vec<usize> = tensor.shape().to_vec();
426 header_data.insert(
427 name.clone(),
428 serde_json::json!({
429 "dtype": self.get_dtype_string(),
430 "shape": shape,
431 "data_offsets": [0, tensor.numel() * std::mem::size_of::<T>()]
432 }),
433 );
434 }
435
436 header_data.insert("__metadata__".to_string(), serde_json::json!(self.metadata));
438
439 let header_json = serde_json::to_string(&header_data)
440 .map_err(|e| SerializationError::FormatError(e.to_string()))?;
441
442 let file = OpenOptions::new()
443 .create(true)
444 .write(true)
445 .truncate(true)
446 .open(path)?;
447 let mut writer = BufWriter::new(file);
448
449 let header_size = header_json.len() as u64;
451 writer.write_all(&header_size.to_le_bytes())?;
452 writer.write_all(header_json.as_bytes())?;
453
454 for (_, tensor) in &self.tensors {
456 if let Some(data_slice) = tensor.data.as_slice() {
457 let bytes = unsafe {
458 std::slice::from_raw_parts(
459 data_slice.as_ptr() as *const u8,
460 data_slice.len() * std::mem::size_of::<T>(),
461 )
462 };
463 writer.write_all(bytes)?;
464 }
465 }
466
467 writer.flush()?;
468 Ok(())
469 }
470
471 fn get_dtype_string(&self) -> String {
472 match std::mem::size_of::<T>() {
473 4 => "F32".to_string(),
474 8 => "F64".to_string(),
475 _ => "UNKNOWN".to_string(),
476 }
477 }
478}
479
480#[derive(Debug, Clone)]
483pub struct ModelCheckpoint<T: Float> {
484 pub epoch: usize,
485 pub step: usize,
486 pub model_state: StateDict<T>,
487 pub optimizer_state: HashMap<String, Vec<u8>>,
488 pub scheduler_state: HashMap<String, Vec<u8>>,
489 pub metrics: HashMap<String, f64>,
490 pub timestamp: u64,
491}
492
493impl<T: Float + 'static> ModelCheckpoint<T> {
494 pub fn new(epoch: usize, step: usize, model_state: StateDict<T>) -> Self {
497 Self {
498 epoch,
499 step,
500 model_state,
501 optimizer_state: HashMap::new(),
502 scheduler_state: HashMap::new(),
503 metrics: HashMap::new(),
504 timestamp: std::time::SystemTime::now()
505 .duration_since(std::time::UNIX_EPOCH)
506 .unwrap_or_default()
507 .as_secs(),
508 }
509 }
510
511 pub fn add_optimizer_state(&mut self, name: String, state: Vec<u8>) {
514 self.optimizer_state.insert(name, state);
515 }
516
517 pub fn add_metric(&mut self, name: String, value: f64) {
520 self.metrics.insert(name, value);
521 }
522
523 pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()> {
526 save(self, path)
527 }
528
529 pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> SerializationResult<Self> {
532 load(path)
533 }
534}
535
536impl<T: Float + 'static> Saveable for ModelCheckpoint<T> {
537 fn save_binary(&self) -> SerializationResult<Vec<u8>> {
538 let mut buffer = Vec::new();
539
540 buffer.extend_from_slice(&(self.epoch as u64).to_le_bytes());
542 buffer.extend_from_slice(&(self.step as u64).to_le_bytes());
543
544 let model_state_data = self.model_state.save_binary()?;
546 buffer.extend_from_slice(&(model_state_data.len() as u64).to_le_bytes());
547 buffer.extend_from_slice(&model_state_data);
548
549 buffer.extend_from_slice(&(self.optimizer_state.len() as u32).to_le_bytes());
551 for (key, value) in &self.optimizer_state {
552 let key_bytes = key.as_bytes();
553 buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
554 buffer.extend_from_slice(key_bytes);
555 buffer.extend_from_slice(&(value.len() as u64).to_le_bytes());
556 buffer.extend_from_slice(value);
557 }
558
559 buffer.extend_from_slice(&(self.scheduler_state.len() as u32).to_le_bytes());
561 for (key, value) in &self.scheduler_state {
562 let key_bytes = key.as_bytes();
563 buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
564 buffer.extend_from_slice(key_bytes);
565 buffer.extend_from_slice(&(value.len() as u64).to_le_bytes());
566 buffer.extend_from_slice(value);
567 }
568
569 buffer.extend_from_slice(&(self.metrics.len() as u32).to_le_bytes());
571 for (key, value) in &self.metrics {
572 let key_bytes = key.as_bytes();
573 buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
574 buffer.extend_from_slice(key_bytes);
575 buffer.extend_from_slice(&value.to_le_bytes());
576 }
577
578 buffer.extend_from_slice(&self.timestamp.to_le_bytes());
580
581 Ok(buffer)
582 }
583
584 fn type_id(&self) -> &'static str {
585 "model_checkpoint"
586 }
587
588 fn metadata(&self) -> HashMap<String, String> {
589 let mut meta = HashMap::new();
590 meta.insert("epoch".to_string(), self.epoch.to_string());
591 meta.insert("step".to_string(), self.step.to_string());
592 meta.insert("timestamp".to_string(), self.timestamp.to_string());
593 meta.insert(
594 "model_type".to_string(),
595 self.model_state.metadata.model_type.clone(),
596 );
597 meta
598 }
599}
600
601impl<T: Float + 'static> Loadable for ModelCheckpoint<T> {
602 fn load_binary(data: &[u8]) -> SerializationResult<Self> {
603 if data.is_empty() {
604 return Ok(Self::new(0, 0, StateDict::new()));
605 }
606
607 let mut offset = 0;
608 let mut checkpoint = Self::new(0, 0, StateDict::new());
609
610 if data.len() < offset + 16 {
612 return Ok(checkpoint);
613 }
614 checkpoint.epoch = u64::from_le_bytes(
615 data[offset..offset + 8]
616 .try_into()
617 .map_err(|_| SerializationError::FormatError("Invalid epoch".to_string()))?,
618 ) as usize;
619 offset += 8;
620
621 checkpoint.step = u64::from_le_bytes(
622 data[offset..offset + 8]
623 .try_into()
624 .map_err(|_| SerializationError::FormatError("Invalid step".to_string()))?,
625 ) as usize;
626 offset += 8;
627
628 if data.len() < offset + 8 {
630 return Ok(checkpoint);
631 }
632 let model_state_len =
633 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
634 SerializationError::FormatError("Invalid model state length".to_string())
635 })?) as usize;
636 offset += 8;
637
638 if data.len() < offset + model_state_len {
639 return Ok(checkpoint);
640 }
641 let model_state_data = &data[offset..offset + model_state_len];
642 if let Ok(model_state) = StateDict::<T>::load_binary(model_state_data) {
643 checkpoint.model_state = model_state;
644 }
645 offset += model_state_len;
646
647 if data.len() < offset + 4 {
649 return Ok(checkpoint);
650 }
651 let optimizer_count =
652 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
653 SerializationError::FormatError("Invalid optimizer count".to_string())
654 })?);
655 offset += 4;
656
657 for _ in 0..optimizer_count {
658 if data.len() < offset + 4 {
660 break;
661 }
662 let key_len =
663 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
664 SerializationError::FormatError("Invalid key length".to_string())
665 })?) as usize;
666 offset += 4;
667
668 if data.len() < offset + key_len {
669 break;
670 }
671 let key = String::from_utf8(data[offset..offset + key_len].to_vec())
672 .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
673 offset += key_len;
674
675 if data.len() < offset + 8 {
677 break;
678 }
679 let value_len =
680 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
681 SerializationError::FormatError("Invalid value length".to_string())
682 })?) as usize;
683 offset += 8;
684
685 if data.len() < offset + value_len {
686 break;
687 }
688 let value = data[offset..offset + value_len].to_vec();
689 checkpoint.optimizer_state.insert(key, value);
690 offset += value_len;
691 }
692
693 if data.len() < offset + 4 {
695 return Ok(checkpoint);
696 }
697 let scheduler_count =
698 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
699 SerializationError::FormatError("Invalid scheduler count".to_string())
700 })?);
701 offset += 4;
702
703 for _ in 0..scheduler_count {
704 if data.len() < offset + 4 {
706 break;
707 }
708 let key_len =
709 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
710 SerializationError::FormatError("Invalid key length".to_string())
711 })?) as usize;
712 offset += 4;
713
714 if data.len() < offset + key_len {
715 break;
716 }
717 let key = String::from_utf8(data[offset..offset + key_len].to_vec())
718 .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
719 offset += key_len;
720
721 if data.len() < offset + 8 {
723 break;
724 }
725 let value_len =
726 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
727 SerializationError::FormatError("Invalid value length".to_string())
728 })?) as usize;
729 offset += 8;
730
731 if data.len() < offset + value_len {
732 break;
733 }
734 let value = data[offset..offset + value_len].to_vec();
735 checkpoint.scheduler_state.insert(key, value);
736 offset += value_len;
737 }
738
739 if data.len() < offset + 4 {
741 return Ok(checkpoint);
742 }
743 let metrics_count =
744 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
745 SerializationError::FormatError("Invalid metrics count".to_string())
746 })?);
747 offset += 4;
748
749 for _ in 0..metrics_count {
750 if data.len() < offset + 4 {
752 break;
753 }
754 let key_len =
755 u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
756 SerializationError::FormatError("Invalid key length".to_string())
757 })?) as usize;
758 offset += 4;
759
760 if data.len() < offset + key_len {
761 break;
762 }
763 let key = String::from_utf8(data[offset..offset + key_len].to_vec())
764 .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
765 offset += key_len;
766
767 if data.len() < offset + 8 {
769 break;
770 }
771 let value = f64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
772 SerializationError::FormatError("Invalid metric value".to_string())
773 })?);
774 checkpoint.metrics.insert(key, value);
775 offset += 8;
776 }
777
778 if data.len() >= offset + 8 {
780 checkpoint.timestamp =
781 u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
782 SerializationError::FormatError("Invalid timestamp".to_string())
783 })?);
784 }
785
786 Ok(checkpoint)
787 }
788
789 fn expected_type_id() -> &'static str {
790 "model_checkpoint"
791 }
792}
793
794impl<T: Float + 'static> Saveable for Tensor<T> {
797 fn save_binary(&self) -> SerializationResult<Vec<u8>> {
798 let mut buffer = Vec::new();
799
800 let shape = self.shape();
802 buffer.extend_from_slice(&(shape.len() as u32).to_le_bytes());
803 for &dim in shape {
804 buffer.extend_from_slice(&(dim as u64).to_le_bytes());
805 }
806
807 if let Some(data_slice) = self.data.as_slice() {
809 let byte_len = data_slice.len() * std::mem::size_of::<T>();
810 buffer.extend_from_slice(&(byte_len as u64).to_le_bytes());
811 let bytes =
812 unsafe { std::slice::from_raw_parts(data_slice.as_ptr() as *const u8, byte_len) };
813 buffer.extend_from_slice(bytes);
814 } else {
815 buffer.extend_from_slice(&(0u64).to_le_bytes());
816 }
817
818 Ok(buffer)
819 }
820
821 fn type_id(&self) -> &'static str {
822 "tensor"
823 }
824
825 fn metadata(&self) -> HashMap<String, String> {
826 self.get_metadata()
827 }
828}
829
830impl<T: Float + 'static> Loadable for Tensor<T> {
831 fn load_binary(data: &[u8]) -> SerializationResult<Self> {
832 let mut cursor = 0;
833
834 if data.len() < 4 {
835 return Err(SerializationError::FormatError(
836 "Insufficient data for tensor shape".to_string(),
837 ));
838 }
839
840 let shape_len = u32::from_le_bytes([
842 data[cursor],
843 data[cursor + 1],
844 data[cursor + 2],
845 data[cursor + 3],
846 ]) as usize;
847 cursor += 4;
848
849 let mut shape = Vec::new();
851 for _ in 0..shape_len {
852 if cursor + 8 > data.len() {
853 return Err(SerializationError::FormatError(
854 "Insufficient data for tensor shape".to_string(),
855 ));
856 }
857 let dim = u64::from_le_bytes([
858 data[cursor],
859 data[cursor + 1],
860 data[cursor + 2],
861 data[cursor + 3],
862 data[cursor + 4],
863 data[cursor + 5],
864 data[cursor + 6],
865 data[cursor + 7],
866 ]) as usize;
867 shape.push(dim);
868 cursor += 8;
869 }
870
871 if cursor + 8 > data.len() {
873 return Err(SerializationError::FormatError(
874 "Insufficient data for tensor data length".to_string(),
875 ));
876 }
877 let data_len = u64::from_le_bytes([
878 data[cursor],
879 data[cursor + 1],
880 data[cursor + 2],
881 data[cursor + 3],
882 data[cursor + 4],
883 data[cursor + 5],
884 data[cursor + 6],
885 data[cursor + 7],
886 ]) as usize;
887 cursor += 8;
888
889 if cursor + data_len > data.len() {
891 return Err(SerializationError::FormatError(
892 "Insufficient data for tensor data".to_string(),
893 ));
894 }
895
896 let expected_elements = shape.iter().product::<usize>();
897 let actual_elements = data_len / std::mem::size_of::<T>();
898
899 if actual_elements != expected_elements {
900 return Err(SerializationError::FormatError(format!(
901 "Shape/data mismatch: shape requires {} elements, data has {}",
902 expected_elements, actual_elements
903 )));
904 }
905
906 let element_size = std::mem::size_of::<T>();
908 let ptr = data[cursor..cursor + data_len].as_ptr();
909
910 if (ptr as usize) % std::mem::align_of::<T>() != 0 {
912 let mut aligned_data = vec![0u8; data_len];
914 aligned_data.copy_from_slice(&data[cursor..cursor + data_len]);
915 let float_data = unsafe {
916 std::slice::from_raw_parts(aligned_data.as_ptr() as *const T, actual_elements)
917 };
918 return Ok(Tensor::from_vec(float_data.to_vec(), shape));
919 }
920
921 let float_data = unsafe { std::slice::from_raw_parts(ptr as *const T, actual_elements) };
922
923 Ok(Tensor::from_vec(float_data.to_vec(), shape))
924 }
925
926 fn expected_type_id() -> &'static str {
927 "tensor"
928 }
929}
930
931pub fn detect_format<P: AsRef<Path>>(path: P) -> SerializationResult<String> {
934 let file = File::open(path.as_ref())?;
935 let mut reader = BufReader::new(file);
936
937 let mut magic = [0u8; 16];
939 reader.read_exact(&mut magic)?;
940
941 if &magic[0..8] == b"RUSTORCH" {
942 Ok("rustorch".to_string())
943 } else if &magic[0..4] == b"PKG\x00" {
944 Ok("pickle".to_string())
945 } else if &magic[0..8] == b"safetens" {
946 Ok("safetensors".to_string())
947 } else {
948 Err(SerializationError::FormatError(
949 "Unknown file format".to_string(),
950 ))
951 }
952}
953
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A parameter added to a state dict can be looked up by name.
    #[test]
    fn test_state_dict_creation() {
        let mut state_dict = StateDict::<f32>::new();

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        state_dict.add_parameter("weight".to_string(), tensor);

        assert!(state_dict.get_parameter("weight").is_some());
        assert_eq!(state_dict.parameters.len(), 1);
    }

    /// A tensor survives a binary save/load round trip unchanged.
    #[test]
    fn test_tensor_save_load() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);

        let binary_data = tensor.save_binary().unwrap();
        let loaded_tensor = Tensor::<f32>::load_binary(&binary_data).unwrap();

        assert_eq!(tensor.shape(), loaded_tensor.shape());
        assert_eq!(tensor.data.as_slice(), loaded_tensor.data.as_slice());
    }

    /// `detect_format` recognises the RusTorch magic and rejects unknown data.
    ///
    /// Both files are 16 bytes long so the test does not depend on how shorter
    /// files are handled.
    #[test]
    fn test_format_detection() {
        let dir = std::env::temp_dir();
        let pid = std::process::id();

        let known = dir.join(format!("rustorch_detect_known_{}.bin", pid));
        let mut bytes = b"RUSTORCH".to_vec();
        bytes.extend_from_slice(&[0u8; 8]);
        std::fs::write(&known, &bytes).unwrap();
        let detected = detect_format(&known);
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&known).ok();
        assert_eq!(detected.unwrap(), "rustorch");

        let unknown = dir.join(format!("rustorch_detect_unknown_{}.bin", pid));
        std::fs::write(&unknown, [0u8; 16]).unwrap();
        let detected = detect_format(&unknown);
        std::fs::remove_file(&unknown).ok();
        assert!(detected.is_err());
    }

    /// A new checkpoint keeps the epoch, step and model state it was given.
    #[test]
    fn test_model_checkpoint() {
        let mut state_dict = StateDict::<f32>::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        state_dict.add_parameter("test_param".to_string(), tensor);

        let checkpoint = ModelCheckpoint::new(5, 100, state_dict);

        assert_eq!(checkpoint.epoch, 5);
        assert_eq!(checkpoint.step, 100);
        assert!(checkpoint.model_state.get_parameter("test_param").is_some());
    }

    /// A tensor added to the safetensors container is stored under its name.
    #[test]
    fn test_safe_tensor_format() {
        let mut safe_format = SafeTensorFormat::<f32>::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        safe_format.add_tensor("test_tensor".to_string(), tensor);

        assert_eq!(safe_format.tensors.len(), 1);
        assert!(safe_format.tensors.contains_key("test_tensor"));
    }
}