#![allow(dead_code)]

use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use tracing::{debug, info, warn};

use super::types::{DType, Device, LayerInfo, ModelMetadata, TensorInfo, TorshModel};

const TORSH_FORMAT_VERSION: &str = "0.1.0";

const TORSH_MAGIC: &[u8; 8] = b"TORSH001";

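// On-disk layout, as written by `save_model` and read back by `load_model`:
//
//   [8-byte magic "TORSH001"]
//   [ModelHeader JSON]            '\n'
//   [ModelMetadata JSON]          '\n'
//   [Vec<LayerInfo> JSON]         '\n'
//   [Vec<SerializedTensor> JSON]  '\n'
//   [raw little-endian tensor data]
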
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ModelHeader {
    magic: [u8; 8],
    version: String,
    metadata_offset: u64,
    weights_offset: u64,
    num_layers: usize,
    num_tensors: usize,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct SerializedTensor {
    name: String,
    shape: Vec<usize>,
    dtype: String,
    requires_grad: bool,
    device: String,
    data_offset: u64,
    data_size: u64,
}

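/// Saves `model` in the ToRSh format at `path`. The file is first written to
/// a `.torsh.tmp` sibling and then renamed into place, so a crash mid-write
/// never leaves a truncated model at the final path.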
pub async fn save_model(model: &TorshModel, path: &Path) -> Result<()> {
    info!("Saving ToRSh model to {}", path.display());

    verify_model(model)?;

    let metadata_json =
        serde_json::to_string(&model.metadata).context("Failed to serialize model metadata")?;

    let layers_json =
        serde_json::to_string(&model.layers).context("Failed to serialize model layers")?;

    let mut serialized_tensors = Vec::new();
    let mut tensor_data = Vec::new();
    let mut current_offset = 0u64;

    for (name, tensor_info) in &model.weights {
        let elements: usize = tensor_info.shape.iter().product();
        let data_size = (elements * tensor_info.dtype.size_bytes()) as u64;

        debug!(
            "Serializing tensor '{}' with shape {:?} ({} bytes)",
            name, tensor_info.shape, data_size
        );

        serialized_tensors.push(SerializedTensor {
            name: name.clone(),
            shape: tensor_info.shape.clone(),
            dtype: tensor_info.dtype.name().to_string(),
            requires_grad: tensor_info.requires_grad,
            device: tensor_info.device.name(),
            data_offset: current_offset,
            data_size,
        });

        // TensorInfo carries only metadata, so placeholder bytes of the
        // correct size stand in for actual weights (see serialize_tensor_data).
        let tensor_bytes = serialize_tensor_data(tensor_info)?;
        tensor_data.extend_from_slice(&tensor_bytes);
        current_offset += tensor_bytes.len() as u64;
    }

    let tensors_json = serde_json::to_string(&serialized_tensors)
        .context("Failed to serialize tensor metadata")?;

    // Compute section offsets from the start of the file (including the
    // 8-byte magic). The serialized header's length depends on the offset
    // values inside it, so re-serialize until the length stabilizes; this
    // converges after a couple of rounds.
    let mut header_len = 0u64;
    let header_json = loop {
        let metadata_offset = TORSH_MAGIC.len() as u64 + header_len + 1;
        let weights_offset = metadata_offset
            + metadata_json.len() as u64
            + 1
            + layers_json.len() as u64
            + 1
            + tensors_json.len() as u64
            + 1;
        let header = ModelHeader {
            magic: *TORSH_MAGIC,
            version: TORSH_FORMAT_VERSION.to_string(),
            metadata_offset,
            weights_offset,
            num_layers: model.layers.len(),
            num_tensors: model.weights.len(),
        };
        let json =
            serde_json::to_string(&header).context("Failed to serialize model header")?;
        if json.len() as u64 == header_len {
            break json;
        }
        header_len = json.len() as u64;
    };

    // Assemble the file: magic bytes, then one JSON line per section, then
    // the raw tensor data.
    let mut file_content = Vec::new();

    file_content.extend_from_slice(TORSH_MAGIC);

    file_content.extend_from_slice(header_json.as_bytes());
    file_content.push(b'\n');

    file_content.extend_from_slice(metadata_json.as_bytes());
    file_content.push(b'\n');

    file_content.extend_from_slice(layers_json.as_bytes());
    file_content.push(b'\n');

    file_content.extend_from_slice(tensors_json.as_bytes());
    file_content.push(b'\n');

    file_content.extend_from_slice(&tensor_data);

    let temp_path = path.with_extension("torsh.tmp");
    tokio::fs::write(&temp_path, &file_content)
        .await
        .with_context(|| {
            format!(
                "Failed to write temporary model file: {}",
                temp_path.display()
            )
        })?;

    tokio::fs::rename(&temp_path, path).await.with_context(|| {
        format!(
            "Failed to move model file to final location: {}",
            path.display()
        )
    })?;

    let file_size_mb = file_content.len() as f64 / (1024.0 * 1024.0);

    info!(
        "Successfully saved model with {} layers, {} tensors ({:.2} MB)",
        model.layers.len(),
        model.weights.len(),
        file_size_mb
    );

    Ok(())
}

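/// Produces placeholder weight bytes for `tensor_info`. `TensorInfo` tracks
/// only shape/dtype/device metadata, so random values (or zeros for
/// unsupported dtypes) of the correct byte length stand in for real data.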
fn serialize_tensor_data(tensor_info: &TensorInfo) -> Result<Vec<u8>> {
    let elements: usize = tensor_info.shape.iter().product();
    let bytes_per_element = tensor_info.dtype.size_bytes();
    let total_bytes = elements * bytes_per_element;

    use scirs2_core::random::thread_rng;
    let mut rng = thread_rng();

    let mut data = Vec::with_capacity(total_bytes);

    match tensor_info.dtype {
        DType::F32 => {
            for _ in 0..elements {
                let value: f32 = rng.gen_range(-1.0..1.0);
                data.extend_from_slice(&value.to_le_bytes());
            }
        }
        DType::F64 => {
            for _ in 0..elements {
                let value: f64 = rng.gen_range(-1.0..1.0);
                data.extend_from_slice(&value.to_le_bytes());
            }
        }
        DType::F16 | DType::BF16 => {
            // Not a real IEEE half / bfloat16 encoding: a scaled i16 merely
            // keeps the two-bytes-per-element size correct for this
            // placeholder data.
            for _ in 0..elements {
                let value: f32 = rng.gen_range(-1.0..1.0);
                let half_value = (value * 32768.0) as i16;
                data.extend_from_slice(&half_value.to_le_bytes());
            }
        }
        DType::I8 => {
            for _ in 0..elements {
                let value: i8 = rng.gen_range(-128..127);
                data.push(value as u8);
            }
        }
        DType::I32 => {
            for _ in 0..elements {
                let value: i32 = rng.gen_range(-1000..1000);
                data.extend_from_slice(&value.to_le_bytes());
            }
        }
        _ => {
            // Remaining dtypes are zero-filled at the correct length.
            data.resize(total_bytes, 0);
        }
    }

    Ok(data)
}

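/// Loads a ToRSh model from `path`, validating the magic bytes and parsing
/// the header, metadata, layer, and tensor-index JSON lines in order. Raw
/// tensor bytes are not materialized; only `TensorInfo` metadata is kept.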
pub async fn load_model(path: &Path) -> Result<TorshModel> {
    info!("Loading ToRSh model from {}", path.display());

    let file_content = tokio::fs::read(path)
        .await
        .with_context(|| format!("Failed to read model file: {}", path.display()))?;

    if file_content.len() < 8 {
        anyhow::bail!("Invalid model file: too small (< 8 bytes)");
    }

    let magic = &file_content[0..8];
    if magic != TORSH_MAGIC {
        anyhow::bail!(
            "Invalid model file: incorrect magic bytes. Expected {:?}, got {:?}",
            TORSH_MAGIC,
            magic
        );
    }

    debug!("Verified ToRSh model magic bytes");

    // The JSON sections are newline-delimited. The lossy conversion is safe
    // here because only the four leading JSON lines are ever consumed, never
    // the binary weight data that follows them.
    let content_after_magic = &file_content[8..];
    let content_str = String::from_utf8_lossy(content_after_magic);
    let mut lines = content_str.lines();

    let header_line = lines
        .next()
        .ok_or_else(|| anyhow::anyhow!("Missing model header"))?;
    let header: ModelHeader =
        serde_json::from_str(header_line).with_context(|| "Failed to parse model header")?;

    debug!(
        "Loaded model header: version {}, {} layers, {} tensors",
        header.version, header.num_layers, header.num_tensors
    );

    if header.version != TORSH_FORMAT_VERSION {
        warn!(
            "Model format version mismatch: file is {}, current is {}",
            header.version, TORSH_FORMAT_VERSION
        );
    }

    let metadata_line = lines
        .next()
        .ok_or_else(|| anyhow::anyhow!("Missing model metadata"))?;
    let metadata: ModelMetadata =
        serde_json::from_str(metadata_line).with_context(|| "Failed to parse model metadata")?;

    debug!("Loaded model metadata: {}", metadata.format);

    let layers_line = lines
        .next()
        .ok_or_else(|| anyhow::anyhow!("Missing model layers"))?;
    let layers: Vec<LayerInfo> =
        serde_json::from_str(layers_line).with_context(|| "Failed to parse model layers")?;

    debug!("Loaded {} layers", layers.len());

    let tensors_line = lines
        .next()
        .ok_or_else(|| anyhow::anyhow!("Missing tensor metadata"))?;
    let serialized_tensors: Vec<SerializedTensor> =
        serde_json::from_str(tensors_line).with_context(|| "Failed to parse tensor metadata")?;

    debug!("Loaded metadata for {} tensors", serialized_tensors.len());

    let mut weights = HashMap::new();

    for serialized_tensor in serialized_tensors {
        let dtype = parse_dtype(&serialized_tensor.dtype)?;
        let device = parse_device(&serialized_tensor.device)?;

        let weight_info = TensorInfo {
            name: serialized_tensor.name.clone(),
            shape: serialized_tensor.shape.clone(),
            dtype,
            requires_grad: serialized_tensor.requires_grad,
            device,
        };

        debug!(
            "Loaded tensor: {} with shape {:?} and dtype {:?}",
            weight_info.name, weight_info.shape, weight_info.dtype
        );

        weights.insert(serialized_tensor.name.clone(), weight_info);
    }

    let model = TorshModel {
        layers,
        weights,
        metadata,
    };

    verify_model(&model)?;

    let file_size_mb = file_content.len() as f64 / (1024.0 * 1024.0);
    info!(
        "Successfully loaded model with {} layers, {} tensors ({:.2} MB)",
        model.layers.len(),
        model.weights.len(),
        file_size_mb
    );

    Ok(model)
}

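/// Parses a dtype name string, matching the lowercase names that
/// `save_model` writes via `DType::name()`.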
fn parse_dtype(s: &str) -> Result<DType> {
    match s {
        "f32" => Ok(DType::F32),
        "f64" => Ok(DType::F64),
        "f16" => Ok(DType::F16),
        "bf16" => Ok(DType::BF16),
        "i8" => Ok(DType::I8),
        "i16" => Ok(DType::I16),
        "i32" => Ok(DType::I32),
        "i64" => Ok(DType::I64),
        "u8" => Ok(DType::U8),
        "bool" => Ok(DType::Bool),
        _ => anyhow::bail!("Unknown dtype: {}", s),
    }
}

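/// Parses a device string: "cpu", "cuda:<id>", "metal:<id>", or "vulkan".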
fn parse_device(s: &str) -> Result<Device> {
    if s == "cpu" {
        return Ok(Device::Cpu);
    }
    if let Some(id) = s.strip_prefix("cuda:") {
        let id: usize = id
            .parse()
            .with_context(|| format!("Invalid CUDA device ID in: {}", s))?;
        return Ok(Device::Cuda(id));
    }
    if let Some(id) = s.strip_prefix("metal:") {
        let id: usize = id
            .parse()
            .with_context(|| format!("Invalid Metal device ID in: {}", s))?;
        return Ok(Device::Metal(id));
    }
    if s == "vulkan" {
        return Ok(Device::Vulkan);
    }

    anyhow::bail!("Unknown device: {}", s)
}

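/// Exports the model in a simplified SafeTensors-like layout. Note that this
/// is not the actual SafeTensors spec (which prefixes a JSON header with its
/// little-endian u64 length); a naive line-based form with zeroed placeholder
/// data is written instead.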
pub async fn export_safetensors(model: &TorshModel, path: &Path) -> Result<()> {
    info!("Exporting model to SafeTensors format: {}", path.display());

    let mut metadata = HashMap::new();
    metadata.insert("format".to_string(), "torsh".to_string());
    metadata.insert("version".to_string(), model.metadata.version.clone());

    let mut tensor_data = Vec::new();

    // Write the global metadata as the leading JSON line of the export.
    let metadata_json = serde_json::to_string(&metadata)?;
    tensor_data.extend_from_slice(metadata_json.as_bytes());
    tensor_data.push(b'\n');

    for (name, tensor_info) in &model.weights {
        let elements: usize = tensor_info.shape.iter().product();
        let data_size = elements * tensor_info.dtype.size_bytes();

        tensor_data.extend_from_slice(name.as_bytes());
        tensor_data.push(b'\n');

        let shape_json = serde_json::to_string(&tensor_info.shape)?;
        tensor_data.extend_from_slice(shape_json.as_bytes());
        tensor_data.push(b'\n');

        // Zeroed placeholder data of the correct byte length.
        let dummy_data = vec![0u8; data_size];
        tensor_data.extend_from_slice(&dummy_data);
    }

    tokio::fs::write(path, tensor_data)
        .await
        .with_context(|| format!("Failed to write SafeTensors file: {}", path.display()))?;

    info!("Successfully exported to SafeTensors format");
    Ok(())
}

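/// Builds a small fully-connected sample model: 784 inputs, hidden Linear
/// layers whose width halves at each step, and a 10-way output layer, with
/// f32 weight and bias tensors on the CPU.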
pub fn create_sample_model(name: &str, num_layers: usize) -> TorshModel {
    debug!("Creating sample model: {} with {} layers", name, num_layers);

    let mut layers = Vec::new();
    let mut weights = HashMap::new();

    let mut input_dim = 784;
    let mut output_dim = 512;

    for i in 0..num_layers {
        let layer_name = format!("layer_{}", i);
        let is_last = i == num_layers - 1;

        if is_last {
            output_dim = 10;
        }

        let layer = LayerInfo {
            name: layer_name.clone(),
            layer_type: "Linear".to_string(),
            input_shape: vec![input_dim],
            output_shape: vec![output_dim],
            parameters: (input_dim * output_dim + output_dim) as u64,
            trainable: true,
            config: HashMap::new(),
        };

        let weight_name = format!("{}.weight", layer_name);
        let weight_info = TensorInfo {
            name: weight_name.clone(),
            shape: vec![output_dim, input_dim],
            dtype: DType::F32,
            requires_grad: true,
            device: Device::Cpu,
        };

        let bias_name = format!("{}.bias", layer_name);
        let bias_info = TensorInfo {
            name: bias_name.clone(),
            shape: vec![output_dim],
            dtype: DType::F32,
            requires_grad: true,
            device: Device::Cpu,
        };

        layers.push(layer);
        weights.insert(weight_name, weight_info);
        weights.insert(bias_name, bias_info);

        input_dim = output_dim;
        // Halve the width for the next hidden layer; on the final iteration
        // this value is never read.
        output_dim = if is_last { 10 } else { output_dim / 2 };
    }

    let metadata = ModelMetadata {
        format: "torsh".to_string(),
        version: TORSH_FORMAT_VERSION.to_string(),
        description: Some(format!("Sample {} layer model", num_layers)),
        tags: vec!["sample".to_string(), "test".to_string()],
        ..Default::default()
    };

    TorshModel {
        layers,
        weights,
        metadata,
    }
}

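/// Checks basic structural invariants: every layer has non-empty input and
/// output shapes, and every tensor has a non-empty shape with a non-zero
/// element count.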
pub fn verify_model(model: &TorshModel) -> Result<()> {
    debug!("Verifying model integrity");

    for layer in &model.layers {
        if layer.input_shape.is_empty() || layer.output_shape.is_empty() {
            anyhow::bail!("Layer {} has invalid shape", layer.name);
        }
    }

    for (name, tensor) in &model.weights {
        if tensor.shape.is_empty() {
            anyhow::bail!("Tensor {} has invalid shape", name);
        }

        let elements: usize = tensor.shape.iter().product();
        if elements == 0 {
            anyhow::bail!("Tensor {} has zero elements", name);
        }
    }

    info!("Model verification passed");
    Ok(())
}

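/// Summarizes the model as JSON values: parameter counts (total, trainable,
/// non-trainable), memory footprint in bytes and MB, layer and tensor counts,
/// and a histogram of layer types.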
pub fn get_model_stats(model: &TorshModel) -> HashMap<String, serde_json::Value> {
    use serde_json::json;

    let total_params: u64 = model.layers.iter().map(|l| l.parameters).sum();
    let trainable_params: u64 = model
        .layers
        .iter()
        .filter(|l| l.trainable)
        .map(|l| l.parameters)
        .sum();

    let memory_footprint: u64 = model
        .weights
        .values()
        .map(|t| {
            let elements: usize = t.shape.iter().product();
            (elements * t.dtype.size_bytes()) as u64
        })
        .sum();

    let layer_types: HashMap<String, usize> =
        model.layers.iter().fold(HashMap::new(), |mut acc, layer| {
            *acc.entry(layer.layer_type.clone()).or_insert(0) += 1;
            acc
        });

    let mut stats = HashMap::new();
    stats.insert("total_parameters".to_string(), json!(total_params));
    stats.insert("trainable_parameters".to_string(), json!(trainable_params));
    stats.insert(
        "non_trainable_parameters".to_string(),
        json!(total_params - trainable_params),
    );
    stats.insert(
        "memory_footprint_bytes".to_string(),
        json!(memory_footprint),
    );
    stats.insert(
        "memory_footprint_mb".to_string(),
        json!(memory_footprint as f64 / (1024.0 * 1024.0)),
    );
    stats.insert("num_layers".to_string(), json!(model.layers.len()));
    stats.insert("num_tensors".to_string(), json!(model.weights.len()));
    stats.insert("layer_types".to_string(), json!(layer_types));

    stats
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_save_load_model() {
        let model = create_sample_model("test_model", 3);
        let temp_dir = std::env::temp_dir();
        let model_path = temp_dir.join("test_model.torsh");

        save_model(&model, &model_path)
            .await
            .expect("operation should succeed");

        assert!(model_path.exists());

        let loaded_model = load_model(&model_path)
            .await
            .expect("operation should succeed");

        assert_eq!(loaded_model.metadata.format, "torsh");

        let _ = tokio::fs::remove_file(model_path).await;
    }

    #[test]
    fn test_model_verification() {
        let model = create_sample_model("test", 2);
        assert!(verify_model(&model).is_ok());
    }

    #[test]
    fn test_model_stats() {
        let model = create_sample_model("test", 3);
        let stats = get_model_stats(&model);

        assert!(stats.contains_key("total_parameters"));
        assert!(stats.contains_key("memory_footprint_mb"));
        assert!(stats.contains_key("num_layers"));
    }
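
    // Sketches of round-trip checks for the string parsers, matching the
    // names save_model writes. `matches!` is used so no PartialEq impl on
    // DType or Device is assumed.
    #[test]
    fn test_parse_dtype() {
        assert!(matches!(parse_dtype("f32"), Ok(DType::F32)));
        assert!(matches!(parse_dtype("bf16"), Ok(DType::BF16)));
        assert!(parse_dtype("complex64").is_err());
    }

    #[test]
    fn test_parse_device() {
        assert!(matches!(parse_device("cpu"), Ok(Device::Cpu)));
        assert!(matches!(parse_device("cuda:1"), Ok(Device::Cuda(1))));
        assert!(matches!(parse_device("metal:0"), Ok(Device::Metal(0))));
        assert!(parse_device("tpu").is_err());
    }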
}