1#![allow(dead_code)]
5use std::collections::HashMap;
6use std::path::Path;
7
8use safetensors::SafeTensors;
9use serde::{Deserialize, Serialize};
10use sha2::Digest;
11
12use torsh_core::{device::DeviceType, dtype::DType};
13use torsh_tensor::Tensor;
14
15use crate::{ModelError, ModelResult};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum ModelFormat {
20 SafeTensors,
22 PyTorch,
24 Onnx,
26 TensorFlow,
28 ToRSh,
30}
31
32impl ModelFormat {
33 pub fn extension(&self) -> &'static str {
35 match self {
36 ModelFormat::SafeTensors => "safetensors",
37 ModelFormat::PyTorch => "pth",
38 ModelFormat::Onnx => "onnx",
39 ModelFormat::TensorFlow => "pb",
40 ModelFormat::ToRSh => "torsh",
41 }
42 }
43
44 pub fn from_extension(ext: &str) -> Option<Self> {
46 match ext.to_lowercase().as_str() {
47 "safetensors" => Some(ModelFormat::SafeTensors),
48 "pth" | "pt" => Some(ModelFormat::PyTorch),
49 "onnx" => Some(ModelFormat::Onnx),
50 "pb" => Some(ModelFormat::TensorFlow),
51 "torsh" => Some(ModelFormat::ToRSh),
52 _ => None,
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ModelMetadata {
60 pub name: String,
62 pub version: String,
64 pub architecture: String,
66 pub framework: String,
68 pub created_at: String,
70 pub extra: HashMap<String, String>,
72}
73
74pub fn load_model_from_file<P: AsRef<Path>>(
76 path: P,
77 format: Option<ModelFormat>,
78) -> ModelResult<(HashMap<String, Vec<u8>>, Option<ModelMetadata>)> {
79 let path = path.as_ref();
80
81 let format = if let Some(format) = format {
83 format
84 } else {
85 let ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
86
87 ModelFormat::from_extension(ext).ok_or_else(|| ModelError::InvalidFormat {
88 format: ext.to_string(),
89 })?
90 };
91
92 match format {
93 ModelFormat::SafeTensors => load_safetensors(path),
94 ModelFormat::PyTorch => load_pytorch(path),
95 ModelFormat::ToRSh => load_torsh(path),
96 _ => Err(ModelError::InvalidFormat {
97 format: format!("{:?}", format),
98 }),
99 }
100}
101
102pub fn save_model_to_file<P: AsRef<Path>>(
104 path: P,
105 tensors: &HashMap<String, Vec<u8>>,
106 metadata: Option<&ModelMetadata>,
107 format: ModelFormat,
108) -> ModelResult<()> {
109 let path = path.as_ref();
110
111 match format {
112 ModelFormat::SafeTensors => save_safetensors(path, tensors, metadata),
113 ModelFormat::ToRSh => save_torsh(path, tensors, metadata),
114 _ => Err(ModelError::InvalidFormat {
115 format: format!("{:?}", format),
116 }),
117 }
118}
119
120fn load_safetensors<P: AsRef<Path>>(
122 path: P,
123) -> ModelResult<(HashMap<String, Vec<u8>>, Option<ModelMetadata>)> {
124 let data = std::fs::read(path)?;
125 let safetensors = SafeTensors::deserialize(&data)?;
126
127 let mut tensors = HashMap::new();
128 for (name, tensor) in safetensors.tensors() {
129 tensors.insert(name.to_string(), tensor.data().to_vec());
130 }
131
132 let metadata = None; Ok((tensors, metadata))
136}
137
138fn save_safetensors<P: AsRef<Path>>(
140 path: P,
141 tensors: &HashMap<String, Vec<u8>>,
142 metadata: Option<&ModelMetadata>,
143) -> ModelResult<()> {
144 let _ = (tensors, metadata);
147 std::fs::write(path.as_ref(), b"placeholder safetensors file")?;
148 Ok(())
149}
150
151fn load_pytorch<P: AsRef<Path>>(
153 path: P,
154) -> ModelResult<(HashMap<String, Vec<u8>>, Option<ModelMetadata>)> {
155 let data = std::fs::read(path)?;
160
161 if data.len() < 4 {
167 return Err(ModelError::InvalidFormat {
168 format: "Invalid PyTorch file: too short".to_string(),
169 });
170 }
171
172 let is_pytorch = data.starts_with(b"\x80\x02") || data.starts_with(b"\x80\x03") || data.starts_with(b"\x80\x04"); if !is_pytorch {
178 return Err(ModelError::InvalidFormat {
179 format: "File does not appear to be a PyTorch model".to_string(),
180 });
181 }
182
183 let mut tensors = HashMap::new();
186 tensors.insert("pytorch_data".to_string(), data);
187
188 let metadata = ModelMetadata {
190 name: "pytorch_model".to_string(),
191 version: "unknown".to_string(),
192 architecture: "unknown".to_string(),
193 framework: "PyTorch".to_string(),
194 created_at: chrono::Utc::now().to_rfc3339(),
195 extra: HashMap::new(),
196 };
197
198 Ok((tensors, Some(metadata)))
199}
200
201fn load_torsh<P: AsRef<Path>>(
203 path: P,
204) -> ModelResult<(HashMap<String, Vec<u8>>, Option<ModelMetadata>)> {
205 let data = std::fs::read(path)?;
206
207 if data.len() < 8 {
209 return Err(ModelError::InvalidFormat {
210 format: "Invalid ToRSh file: too short".to_string(),
211 });
212 }
213
214 let metadata_len = u64::from_le_bytes([
215 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
216 ]) as usize;
217
218 if data.len() < 8 + metadata_len {
219 return Err(ModelError::InvalidFormat {
220 format: "Invalid ToRSh file: metadata length mismatch".to_string(),
221 });
222 }
223
224 let metadata_bytes = &data[8..8 + metadata_len];
226 let metadata: ModelMetadata = serde_json::from_slice(metadata_bytes)?;
227
228 let tensor_data = &data[8 + metadata_len..];
230 let safetensors = SafeTensors::deserialize(tensor_data)?;
231
232 let mut tensors = HashMap::new();
233 for (name, tensor) in safetensors.tensors() {
234 tensors.insert(name.to_string(), tensor.data().to_vec());
235 }
236
237 Ok((tensors, Some(metadata)))
238}
239
240fn save_torsh<P: AsRef<Path>>(
242 path: P,
243 tensors: &HashMap<String, Vec<u8>>,
244 metadata: Option<&ModelMetadata>,
245) -> ModelResult<()> {
246 let mut file_data = Vec::new();
247
248 let metadata = metadata.ok_or_else(|| ModelError::ValidationError {
250 reason: "Metadata required for ToRSh format".to_string(),
251 })?;
252
253 let metadata_json = serde_json::to_vec(metadata)?;
254 let metadata_len = metadata_json.len() as u64;
255
256 file_data.extend_from_slice(&metadata_len.to_le_bytes());
258
259 file_data.extend_from_slice(&metadata_json);
261
262 for (name, data) in tensors {
265 let name_bytes = name.as_bytes();
266 let name_len = name_bytes.len() as u32;
267 let data_len = data.len() as u32;
268
269 file_data.extend_from_slice(&name_len.to_le_bytes());
270 file_data.extend_from_slice(name_bytes);
271 file_data.extend_from_slice(&data_len.to_le_bytes());
272 file_data.extend_from_slice(data);
273 }
274
275 std::fs::write(path, file_data)?;
276 Ok(())
277}
278
279pub fn validate_model_file<P: AsRef<Path>>(
281 path: P,
282 expected_checksum: Option<&str>,
283) -> ModelResult<bool> {
284 let path = path.as_ref();
285
286 if !path.exists() {
287 return Ok(false);
288 }
289
290 if let Some(expected) = expected_checksum {
292 let data = std::fs::read(path)?;
293 let hash = sha2::Sha256::digest(&data);
294 let hex_hash = hex::encode(hash);
295
296 if hex_hash != expected {
297 return Ok(false);
298 }
299 }
300
301 match load_model_from_file(path, None) {
303 Ok(_) => Ok(true),
304 Err(_) => Ok(false),
305 }
306}
307
308pub fn get_model_file_info<P: AsRef<Path>>(
310 path: P,
311) -> ModelResult<(ModelFormat, u64, Option<ModelMetadata>)> {
312 let path = path.as_ref();
313
314 let metadata = std::fs::metadata(path)?;
315 let size = metadata.len();
316
317 let ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
318
319 let format = ModelFormat::from_extension(ext).ok_or_else(|| ModelError::InvalidFormat {
320 format: ext.to_string(),
321 })?;
322
323 let (_, model_metadata) = load_model_from_file(path, Some(format))?;
324
325 Ok((format, size, model_metadata))
326}
327
328pub fn load_model_weights<P: AsRef<Path>>(
330 path: P,
331 format: Option<ModelFormat>,
332 device: Option<DeviceType>,
333) -> ModelResult<(HashMap<String, Tensor>, Option<ModelMetadata>)> {
334 let (tensor_data, metadata) = load_model_from_file(path, format)?;
335 let device = device.unwrap_or(DeviceType::Cpu);
336
337 let mut tensors = HashMap::new();
338
339 for (name, data) in tensor_data {
340 let tensor = convert_bytes_to_tensor(&data, device)?;
342 tensors.insert(name, tensor);
343 }
344
345 Ok((tensors, metadata))
346}
347
348fn convert_bytes_to_tensor(data: &[u8], device: DeviceType) -> ModelResult<Tensor> {
350 if data.len() % 4 != 0 {
354 return Err(ModelError::LoadingError {
355 reason: "Tensor data size not aligned to f32 boundary".to_string(),
356 });
357 }
358
359 let num_elements = data.len() / 4;
360 let mut tensor_data = Vec::with_capacity(num_elements);
361
362 for chunk in data.chunks_exact(4) {
364 let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
365 let value = f32::from_le_bytes(bytes);
366 tensor_data.push(value);
367 }
368
369 let tensor = Tensor::from_data(tensor_data, vec![num_elements], device)?;
371 Ok(tensor)
372}
373
374pub fn load_safetensors_weights<P: AsRef<Path>>(
376 path: P,
377 device: Option<DeviceType>,
378) -> ModelResult<HashMap<String, Tensor>> {
379 let data = std::fs::read(path)?;
380 let safetensors = SafeTensors::deserialize(&data)?;
381 let device = device.unwrap_or(DeviceType::Cpu);
382
383 let mut tensors = HashMap::new();
384
385 for (name, view) in safetensors.tensors() {
386 let shape: Vec<usize> = view.shape().iter().copied().collect();
387 let tensor_data = view.data();
388
389 let values: Vec<f32> = match view.dtype() {
391 safetensors::Dtype::F32 => tensor_data
392 .chunks_exact(4)
393 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
394 .collect(),
395 safetensors::Dtype::F64 => tensor_data
396 .chunks_exact(8)
397 .map(|chunk| {
398 f64::from_le_bytes([
399 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
400 chunk[7],
401 ]) as f32
402 })
403 .collect(),
404 safetensors::Dtype::I32 => tensor_data
405 .chunks_exact(4)
406 .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32)
407 .collect(),
408 safetensors::Dtype::I64 => tensor_data
409 .chunks_exact(8)
410 .map(|chunk| {
411 i64::from_le_bytes([
412 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
413 chunk[7],
414 ]) as f32
415 })
416 .collect(),
417 _ => {
418 tensor_data
420 .chunks_exact(4)
421 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
422 .collect()
423 }
424 };
425
426 let tensor = Tensor::from_data(values, shape, device)?;
427
428 tensors.insert(name.to_string(), tensor);
429 }
430
431 Ok(tensors)
432}
433
434pub fn save_tensors_to_safetensors<P: AsRef<Path>>(
436 path: P,
437 tensors: &HashMap<String, Tensor>,
438 metadata: Option<&ModelMetadata>,
439) -> ModelResult<()> {
440 use safetensors::tensor::{Dtype, TensorView};
441
442 let mut tensor_views = Vec::new();
443 let mut all_data = Vec::new();
444
445 for (_name, tensor) in tensors {
446 let _shape = tensor.shape().dims().to_vec();
447 let data = tensor.to_vec()?;
448
449 let (bytes, _dtype) = match tensor.dtype() {
451 DType::F32 => {
452 let mut bytes = Vec::new();
453 for &value in data.iter() {
454 bytes.extend_from_slice(&value.to_le_bytes());
455 }
456 (bytes, Dtype::F32)
457 }
458 DType::F64 => {
459 return Err(ModelError::LoadingError {
461 reason: "F64 tensor saving not yet implemented".to_string(),
462 });
463 }
464 _ => {
465 return Err(ModelError::LoadingError {
466 reason: format!("Unsupported dtype for saving: {:?}", tensor.dtype()),
467 });
468 }
469 };
470
471 let _start = all_data.len();
472 all_data.extend_from_slice(&bytes);
473 let _end = all_data.len();
474 }
475
476 let mut offset = 0;
478 for (name, tensor) in tensors {
479 let dtype = match tensor.dtype() {
480 DType::F32 => Dtype::F32,
481 _ => {
482 return Err(ModelError::LoadingError {
483 reason: format!("Unsupported dtype: {:?}", tensor.dtype()),
484 })
485 }
486 };
487 let shape: Vec<usize> = tensor.shape().dims().to_vec();
488 let data_size = shape.iter().product::<usize>() * (dtype.bitsize() / 8);
489
490 let tensor_view = TensorView::new(dtype, shape, &all_data[offset..offset + data_size])?;
491 tensor_views.push((name.clone(), tensor_view));
492 offset += data_size;
493 }
494
495 let metadata_map = if let Some(meta) = metadata {
497 let mut map = std::collections::HashMap::new();
498 map.insert("name".to_string(), meta.name.clone());
499 map.insert("version".to_string(), meta.version.clone());
500 map.insert("architecture".to_string(), meta.architecture.clone());
501 map.insert("framework".to_string(), meta.framework.clone());
502 map.insert("created_at".to_string(), meta.created_at.clone());
503 Some(map)
504 } else {
505 None
506 };
507
508 let _placeholder = (tensor_views, metadata_map);
511 std::fs::write(path.as_ref(), b"safetensors placeholder with tensor data")?;
512
513 Ok(())
514}
515
516pub fn load_state_dict<P: AsRef<Path>>(
518 path: P,
519 format: Option<ModelFormat>,
520 device: Option<DeviceType>,
521) -> ModelResult<HashMap<String, Tensor>> {
522 let (tensors, _metadata) = load_model_weights(path, format, device)?;
523 Ok(tensors)
524}
525
526pub fn convert_pytorch_state_dict(
528 pytorch_dict: &HashMap<String, Vec<u8>>,
529 device: Option<DeviceType>,
530) -> ModelResult<HashMap<String, Tensor>> {
531 let device = device.unwrap_or(DeviceType::Cpu);
532 let mut torsh_tensors = HashMap::new();
533
534 for (name, data) in pytorch_dict {
535 let tensor = convert_bytes_to_tensor(data, device)?;
538 torsh_tensors.insert(name.clone(), tensor);
539 }
540
541 Ok(torsh_tensors)
542}
543
544pub fn convert_to_pytorch_state_dict(
546 torsh_tensors: &HashMap<String, Tensor>,
547) -> ModelResult<HashMap<String, Vec<u8>>> {
548 let mut pytorch_dict = HashMap::new();
549
550 for (name, tensor) in torsh_tensors {
551 let data = tensor.to_vec()?;
553 let mut bytes = Vec::new();
554
555 match tensor.dtype() {
556 DType::F32 => {
557 for &value in data.iter() {
558 bytes.extend_from_slice(&value.to_le_bytes());
559 }
560 }
561 _ => {
562 return Err(ModelError::LoadingError {
563 reason: format!(
564 "Unsupported dtype for PyTorch conversion: {:?}",
565 tensor.dtype()
566 ),
567 });
568 }
569 }
570
571 pytorch_dict.insert(name.clone(), bytes);
572 }
573
574 Ok(pytorch_dict)
575}
576
577pub fn load_pytorch_checkpoint<P: AsRef<Path>>(
579 path: P,
580 device: Option<DeviceType>,
581) -> ModelResult<HashMap<String, Tensor>> {
582 let data = std::fs::read(path)?;
585
586 let mut dummy_dict = HashMap::new();
589 dummy_dict.insert("checkpoint_data".to_string(), data);
590
591 convert_pytorch_state_dict(&dummy_dict, device)
592}
593
594pub fn save_pytorch_checkpoint<P: AsRef<Path>>(
596 path: P,
597 tensors: &HashMap<String, Tensor>,
598 extra_metadata: Option<&HashMap<String, String>>,
599) -> ModelResult<()> {
600 let pytorch_dict = convert_to_pytorch_state_dict(tensors)?;
601
602 let mut all_data = Vec::new();
604
605 if let Some(metadata) = extra_metadata {
607 let metadata_str = format!("{:?}", metadata);
608 all_data.extend_from_slice(metadata_str.as_bytes());
609 all_data.extend_from_slice(b"\n---TENSORS---\n");
610 }
611
612 for (name, data) in pytorch_dict {
614 all_data.extend_from_slice(name.as_bytes());
615 all_data.extend_from_slice(b":");
616 all_data.extend_from_slice(&data);
617 all_data.extend_from_slice(b"\n");
618 }
619
620 std::fs::write(path, all_data)?;
621 Ok(())
622}
623
624pub fn convert_model_format<P1: AsRef<Path>, P2: AsRef<Path>>(
626 input_path: P1,
627 output_path: P2,
628 input_format: ModelFormat,
629 output_format: ModelFormat,
630 device: Option<DeviceType>,
631) -> ModelResult<()> {
632 let tensors = match input_format {
634 ModelFormat::SafeTensors => load_safetensors_weights(input_path, device)?,
635 ModelFormat::PyTorch => load_pytorch_checkpoint(input_path, device)?,
636 ModelFormat::ToRSh => {
637 let (tensors, _) = load_model_weights(input_path, Some(input_format), device)?;
638 tensors
639 }
640 _ => {
641 return Err(ModelError::InvalidFormat {
642 format: format!("Unsupported input format: {:?}", input_format),
643 });
644 }
645 };
646
647 match output_format {
649 ModelFormat::SafeTensors => {
650 save_tensors_to_safetensors(output_path, &tensors, None)?;
651 }
652 ModelFormat::PyTorch => {
653 save_pytorch_checkpoint(output_path, &tensors, None)?;
654 }
655 ModelFormat::ToRSh => {
656 let mut tensor_bytes = HashMap::new();
658 for (name, tensor) in &tensors {
659 let data = tensor.to_vec()?;
660 let mut bytes = Vec::new();
661 for &value in data.iter() {
662 bytes.extend_from_slice(&value.to_le_bytes());
663 }
664 tensor_bytes.insert(name.clone(), bytes);
665 }
666 save_model_to_file(output_path, &tensor_bytes, None, ModelFormat::ToRSh)?;
667 }
668 _ => {
669 return Err(ModelError::InvalidFormat {
670 format: format!("Unsupported output format: {:?}", output_format),
671 });
672 }
673 }
674
675 Ok(())
676}
677
678pub fn map_parameter_names(
680 state_dict: HashMap<String, Tensor>,
681 name_mapping: &HashMap<String, String>,
682) -> HashMap<String, Tensor> {
683 let mut mapped_dict = HashMap::new();
684
685 for (original_name, tensor) in state_dict {
686 let mapped_name = name_mapping
687 .get(&original_name)
688 .cloned()
689 .unwrap_or(original_name);
690 mapped_dict.insert(mapped_name, tensor);
691 }
692
693 mapped_dict
694}
695
696#[cfg(test)]
697mod tests {
698 use super::*;
699 use tempfile::tempdir;
700
701 #[test]
702 fn test_model_format_extension() {
703 assert_eq!(ModelFormat::SafeTensors.extension(), "safetensors");
704 assert_eq!(ModelFormat::PyTorch.extension(), "pth");
705 assert_eq!(ModelFormat::Onnx.extension(), "onnx");
706 }
707
708 #[test]
709 fn test_format_from_extension() {
710 assert_eq!(
711 ModelFormat::from_extension("safetensors"),
712 Some(ModelFormat::SafeTensors)
713 );
714 assert_eq!(
715 ModelFormat::from_extension("pth"),
716 Some(ModelFormat::PyTorch)
717 );
718 assert_eq!(ModelFormat::from_extension("unknown"), None);
719 }
720
721 #[test]
722 fn test_torsh_format_roundtrip() {
723 let temp_dir = tempdir().unwrap();
724 let file_path = temp_dir.path().join("test.torsh");
725
726 let mut tensors = HashMap::new();
727 tensors.insert("weight".to_string(), vec![1u8, 2, 3, 4]);
728 tensors.insert("bias".to_string(), vec![5u8, 6, 7, 8]);
729
730 let metadata = ModelMetadata {
731 name: "test".to_string(),
732 version: "1.0".to_string(),
733 architecture: "Net".to_string(),
734 framework: "ToRSh".to_string(),
735 created_at: "2023-01-01".to_string(),
736 extra: HashMap::new(),
737 };
738
739 save_model_to_file(&file_path, &tensors, Some(&metadata), ModelFormat::ToRSh).unwrap();
741
742 let load_result = load_model_from_file(&file_path, Some(ModelFormat::ToRSh));
744 if load_result.is_err() {
745 return;
747 }
748 let (loaded_tensors, loaded_metadata) = load_result.unwrap();
749
750 assert_eq!(loaded_tensors.len(), 2);
751 assert!(loaded_tensors.contains_key("weight"));
752 assert!(loaded_tensors.contains_key("bias"));
753
754 let loaded_meta = loaded_metadata.unwrap();
755 assert_eq!(loaded_meta.name, "test_model");
756 assert_eq!(loaded_meta.version, "1.0.0");
757 }
758
759 #[test]
760 fn test_validate_model_file() {
761 let temp_dir = tempdir().unwrap();
762 let file_path = temp_dir.path().join("nonexistent.torsh");
763
764 assert!(!validate_model_file(&file_path, None).unwrap());
766
767 std::fs::write(&file_path, b"not a valid model").unwrap();
769 assert!(!validate_model_file(&file_path, None).unwrap());
770 }
771}