scirs2_neural/serialization/traits.rs
1//! Generic model serialization traits
2//!
3//! This module provides the `ModelSerialize` and `ModelDeserialize` traits
4//! that allow any neural network architecture to be saved to and loaded from disk.
5//! These traits work with multiple formats (JSON, SafeTensors, etc.) and handle
6//! nested layers, attention heads, and normalization parameters.
7
8use crate::error::Result;
9use std::path::Path;
10
/// On-disk / in-memory encodings a model can be persisted in.
///
/// The enum is a plain `Copy` tag. Besides the equality traits it also
/// derives `Hash`, so a format can be used directly as a `HashMap` or
/// `HashSet` key (for example to look up a per-format handler).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelFormat {
    /// JSON format - human-readable, larger files
    Json,
    /// SafeTensors format - binary, HuggingFace-compatible
    SafeTensors,
    /// CBOR format - binary, compact
    Cbor,
    /// MessagePack format - binary, compact
    MessagePack,
}
23
24/// Trait for serializing a model to disk
25///
26/// Any neural network architecture that implements this trait can be saved
27/// to a file in one of the supported formats. The serialization captures
28/// both the model configuration (architecture) and the learned parameters (weights).
29///
30/// # Example
31///
32/// ```rust
33/// use scirs2_neural::serialization::traits::{ModelSerialize, ModelFormat};
34///
35/// // ModelSerialize is a trait implemented by model architectures.
36/// // Example usage (with a model that implements ModelSerialize):
37/// let format = ModelFormat::SafeTensors;
38/// assert_eq!(format, ModelFormat::SafeTensors);
39/// ```
40pub trait ModelSerialize {
41 /// Save the model to the specified path in the given format
42 ///
43 /// This method serializes both the model architecture (configuration)
44 /// and all learned parameters (weights, biases, normalization stats, etc.)
45 fn save(&self, path: &Path, format: ModelFormat) -> Result<()>;
46
47 /// Serialize the model to bytes in the given format
48 ///
49 /// This is useful when you want to store the serialized model in memory
50 /// or send it over a network rather than writing to disk.
51 fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>>;
52
53 /// Get the architecture name for this model (e.g., "ResNet", "BERT", "GPT")
54 fn architecture_name(&self) -> &str;
55
56 /// Get the model version string
57 fn model_version(&self) -> String {
58 "0.1.0".to_string()
59 }
60}
61
62/// Trait for deserializing a model from disk
63///
64/// Any neural network architecture that implements this trait can be loaded
65/// from a file that was previously saved with `ModelSerialize`.
66///
67/// # Example
68///
69/// ```rust
70/// use scirs2_neural::serialization::traits::{ModelDeserialize, ModelFormat};
71///
72/// // ModelDeserialize is a trait implemented by model architectures.
73/// // Example usage (with a model that implements ModelDeserialize):
74/// let format = ModelFormat::Json;
75/// assert_eq!(format, ModelFormat::Json);
76/// ```
77pub trait ModelDeserialize: Sized {
78 /// Load the model from the specified path in the given format
79 ///
80 /// This method deserializes both the model architecture and all
81 /// learned parameters, reconstructing a fully functional model.
82 fn load(path: &Path, format: ModelFormat) -> Result<Self>;
83
84 /// Deserialize the model from bytes in the given format
85 ///
86 /// This is useful when loading from a network stream or in-memory buffer.
87 fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self>;
88}
89
90/// Metadata about a serialized model, stored alongside the weights
91#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
92pub struct ModelMetadata {
93 /// Architecture name (e.g., "ResNet", "BERT", "GPT")
94 pub architecture: String,
95 /// Model version
96 pub version: String,
97 /// Framework version that produced this file
98 pub framework_version: String,
99 /// Number of parameters in the model
100 pub num_parameters: usize,
101 /// Data type used for parameters (e.g., "f32", "f64")
102 pub dtype: String,
103 /// Additional key-value metadata
104 pub extra: std::collections::HashMap<String, String>,
105}
106
107impl ModelMetadata {
108 /// Create new metadata for a model
109 pub fn new(architecture: &str, dtype: &str, num_parameters: usize) -> Self {
110 Self {
111 architecture: architecture.to_string(),
112 version: "0.1.0".to_string(),
113 framework_version: env!("CARGO_PKG_VERSION").to_string(),
114 num_parameters,
115 dtype: dtype.to_string(),
116 extra: std::collections::HashMap::new(),
117 }
118 }
119
120 /// Add an extra metadata key-value pair
121 pub fn with_extra(mut self, key: &str, value: &str) -> Self {
122 self.extra.insert(key.to_string(), value.to_string());
123 self
124 }
125}
126
127/// Information about a single tensor in a serialized model
128#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
129pub struct TensorInfo {
130 /// Name of the tensor (e.g., "layer1.weight", "encoder.attention.query")
131 pub name: String,
132 /// Data type (e.g., "F32", "F64")
133 pub dtype: String,
134 /// Shape of the tensor
135 pub shape: Vec<usize>,
136 /// Byte offset in the data section
137 pub data_offset: usize,
138 /// Number of bytes for this tensor
139 pub byte_length: usize,
140}
141
142impl TensorInfo {
143 /// Create a new TensorInfo
144 pub fn new(
145 name: &str,
146 dtype: &str,
147 shape: Vec<usize>,
148 data_offset: usize,
149 byte_length: usize,
150 ) -> Self {
151 Self {
152 name: name.to_string(),
153 dtype: dtype.to_string(),
154 shape,
155 data_offset,
156 byte_length,
157 }
158 }
159
160 /// Get the total number of elements in this tensor
161 pub fn num_elements(&self) -> usize {
162 if self.shape.is_empty() {
163 0
164 } else {
165 self.shape.iter().product()
166 }
167 }
168}
169
/// An ordered collection of named parameter tensors extracted from a model.
///
/// It gives callers one uniform view of a model's parameters regardless of
/// the underlying architecture. `PartialEq` is derived so two collections
/// (for example before and after a save/load round trip) can be compared
/// directly; comparison is element-wise and order-sensitive, and follows
/// `f64` equality (so a `NaN` value never compares equal).
#[derive(Debug, Clone, PartialEq)]
pub struct NamedParameters {
    /// Ordered list of (name, flattened_f64_values, shape) tuples
    pub parameters: Vec<(String, Vec<f64>, Vec<usize>)>,
}
179
180impl NamedParameters {
181 /// Create a new empty NamedParameters collection
182 pub fn new() -> Self {
183 Self {
184 parameters: Vec::new(),
185 }
186 }
187
188 /// Add a parameter tensor
189 pub fn add(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
190 self.parameters.push((name.to_string(), values, shape));
191 }
192
193 /// Get the total number of scalar parameters
194 pub fn total_parameters(&self) -> usize {
195 self.parameters.iter().map(|(_, v, _)| v.len()).sum()
196 }
197
198 /// Find a parameter by name
199 pub fn get(&self, name: &str) -> Option<&(String, Vec<f64>, Vec<usize>)> {
200 self.parameters.iter().find(|(n, _, _)| n == name)
201 }
202
203 /// Get the number of named parameter groups
204 pub fn len(&self) -> usize {
205 self.parameters.len()
206 }
207
208 /// Check if empty
209 pub fn is_empty(&self) -> bool {
210 self.parameters.is_empty()
211 }
212}
213
214impl Default for NamedParameters {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220/// Trait for extracting named parameters from a model
221///
222/// This trait provides a standardized way to extract all named parameters
223/// from any model architecture, enabling format-agnostic serialization.
224pub trait ExtractParameters {
225 /// Extract all named parameters from the model
226 ///
227 /// Parameters are returned as named `(String, Vec<f64>, Vec<usize>)` tuples
228 /// where the first element is the name (e.g., "encoder.layer.0.attention.query.weight"),
229 /// the second is the flattened parameter values, and the third is the shape.
230 fn extract_named_parameters(&self) -> Result<NamedParameters>;
231
232 /// Load named parameters into the model
233 ///
234 /// This method takes a NamedParameters collection and sets the model's
235 /// parameters accordingly. Parameter names must match those returned
236 /// by `extract_named_parameters`.
237 fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()>;
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// `ModelMetadata::new` stores the architecture, dtype and parameter count.
    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new("ResNet", "f32", 11_000_000);
        assert_eq!(meta.architecture, "ResNet");
        assert_eq!(meta.dtype, "f32");
        assert_eq!(meta.num_parameters, 11_000_000);
    }

    /// Chained `with_extra` calls each leave their pair in `extra`.
    #[test]
    fn test_model_metadata_with_extra() {
        let meta = ModelMetadata::new("BERT", "f32", 110_000_000)
            .with_extra("variant", "base-uncased")
            .with_extra("vocab_size", "30522");

        let variant = meta.extra.get("variant").map(String::as_str);
        assert_eq!(variant, Some("base-uncased"));

        let vocab_size = meta.extra.get("vocab_size").map(String::as_str);
        assert_eq!(vocab_size, Some("30522"));
    }

    /// Element count is the product of the dimensions; byte length is kept as given.
    #[test]
    fn test_tensor_info() {
        let expected_bytes = 768 * 3072 * 4;
        let info = TensorInfo::new("layer1.weight", "F32", vec![768, 3072], 0, expected_bytes);
        assert_eq!(info.num_elements(), 768 * 3072);
        assert_eq!(info.byte_length, expected_bytes);
    }

    /// A tensor without dimensions reports zero elements.
    #[test]
    fn test_tensor_info_empty_shape() {
        let info = TensorInfo::new("empty", "F32", Vec::new(), 0, 0);
        assert_eq!(info.num_elements(), 0);
    }

    /// Exercises add / len / is_empty / total_parameters / get on one collection.
    #[test]
    fn test_named_parameters() {
        let mut collection = NamedParameters::new();
        assert!(collection.is_empty());
        assert_eq!(collection.len(), 0);

        collection.add("layer1.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        collection.add("layer1.bias", vec![0.1, 0.2], vec![2]);

        assert_eq!(collection.len(), 2);
        assert!(!collection.is_empty());
        assert_eq!(collection.total_parameters(), 6);

        let lookup = collection.get("layer1.weight");
        assert!(lookup.is_some());
        let (name, values, shape) = lookup.expect("parameter should exist");
        assert_eq!(name, "layer1.weight");
        assert_eq!(values, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[2, 2]);

        assert!(collection.get("nonexistent").is_none());
    }

    /// Variants compare by identity, and all four variants can be named.
    #[test]
    fn test_model_format_enum() {
        let chosen = ModelFormat::SafeTensors;
        assert_eq!(chosen, ModelFormat::SafeTensors);
        assert_ne!(chosen, ModelFormat::Json);

        // Naming every variant makes this test fail to compile if one is removed.
        let _all_variants = [
            ModelFormat::Json,
            ModelFormat::SafeTensors,
            ModelFormat::Cbor,
            ModelFormat::MessagePack,
        ];
    }
}