// torsh_distributed/communication/serialization.rs
#![allow(dead_code)]
8use crate::{TorshDistributedError, TorshResult};
9use serde::{Deserialize, Serialize};
10use torsh_core::device::DeviceType;
11use torsh_core::dtype::TensorElement;
12use torsh_tensor::Tensor;
13
/// Marker trait for values that can travel over the communication layer:
/// serde-(de)serializable and safe to send/share across threads.
pub trait CommunicationMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {}

/// Blanket impl: any type satisfying the bounds is automatically a message.
impl<T> CommunicationMessage for T where T: Serialize + for<'de> Deserialize<'de> + Send + Sync {}
19
20pub fn serialize_message<T: CommunicationMessage>(msg: &T) -> TorshResult<Vec<u8>> {
22 oxicode::serde::encode_to_vec(msg, oxicode::config::standard()).map_err(|e| {
23 TorshDistributedError::SerializationError(format!("Message serialization failed: {}", e))
24 })
25}
26
27pub fn deserialize_message<T: CommunicationMessage>(data: &[u8]) -> TorshResult<T> {
29 let (value, _): (T, usize) =
30 oxicode::serde::decode_from_slice(data, oxicode::config::standard()).map_err(|e| {
31 TorshDistributedError::SerializationError(format!(
32 "Message deserialization failed: {}",
33 e
34 ))
35 })?;
36 Ok(value)
37}
38
/// Wire representation of a tensor: raw element bytes plus the metadata
/// needed to validate and reconstruct it on the receiving side.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SerializableTensor {
    // Element data copied verbatim from the tensor's in-memory (host)
    // representation by `serialize_tensor`.
    pub data: Vec<u8>,
    // Dimension sizes; checked against the caller's expected shape on
    // deserialization.
    pub shape: Vec<usize>,
    // Rust type name of the element type (`std::any::type_name`), recorded
    // for diagnostics; NOTE: not validated by `deserialize_tensor`.
    pub dtype: String,
    // Device the tensor lived on when serialized.
    pub device: DeviceType,
    // `size_of::<T>()` at serialization time; validated against the target
    // element type on deserialization.
    pub element_size: usize,
}
53
54pub fn serialize_tensor<T>(tensor: &Tensor<T>) -> TorshResult<Vec<u8>>
56where
57 T: Clone + Send + Sync + 'static + TensorElement + Copy,
58{
59 let data = tensor.to_vec().map_err(|e| {
61 TorshDistributedError::SerializationError(format!("Failed to extract tensor data: {}", e))
62 })?;
63
64 let element_size = std::mem::size_of::<T>();
66 let byte_data = unsafe {
67 std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * element_size).to_vec()
68 };
69
70 let serializable = SerializableTensor {
71 data: byte_data,
72 shape: tensor.shape().dims().to_vec(),
73 dtype: std::any::type_name::<T>().to_string(),
74 device: tensor.device(),
75 element_size,
76 };
77
78 serialize_message(&serializable)
79}
80
81pub fn deserialize_tensor<T>(data: &[u8], expected_shape: &[usize]) -> TorshResult<Tensor<T>>
83where
84 T: Clone + Send + Sync + 'static + torsh_core::dtype::TensorElement,
85{
86 let serializable: SerializableTensor = deserialize_message(data)?;
87
88 if serializable.shape != expected_shape {
90 return Err(TorshDistributedError::TensorShapeMismatch {
91 expected: expected_shape.to_vec(),
92 actual: serializable.shape,
93 });
94 }
95
96 let expected_element_size = std::mem::size_of::<T>();
98 if serializable.element_size != expected_element_size {
99 return Err(TorshDistributedError::SerializationError(format!(
100 "Element size mismatch: expected {}, got {}",
101 expected_element_size, serializable.element_size
102 )));
103 }
104
105 let element_count = serializable.data.len() / serializable.element_size;
107 let typed_data = unsafe {
108 std::slice::from_raw_parts(serializable.data.as_ptr() as *const T, element_count).to_vec()
109 };
110
111 Tensor::from_data(typed_data, serializable.shape, serializable.device).map_err(|e| {
113 TorshDistributedError::SerializationError(format!(
114 "Failed to create tensor from data: {}",
115 e
116 ))
117 })
118}
119
120pub fn estimate_tensor_serialized_size<T>(tensor: &Tensor<T>) -> usize
122where
123 T: 'static + TensorElement + Copy,
124{
125 let element_size = std::mem::size_of::<T>();
126 let data_size = tensor.numel() * element_size;
127 let metadata_overhead = 256; data_size + metadata_overhead
129}
130
/// Gzip-compresses serialized tensor bytes at compression level 1
/// (fastest; favors speed over ratio for network transfer).
///
/// # Errors
///
/// Returns `SerializationError` when compression fails.
#[cfg(feature = "compression")]
pub fn compress_tensor_data(data: Vec<u8>) -> TorshResult<Vec<u8>> {
    use oxiarc_deflate::gzip::gzip_compress;

    match gzip_compress(&data, 1) {
        Ok(compressed) => Ok(compressed),
        Err(e) => Err(TorshDistributedError::SerializationError(format!(
            "Compression failed: {}",
            e
        ))),
    }
}
141
/// Gzip-decompresses bytes produced by [`compress_tensor_data`].
///
/// # Errors
///
/// Returns `SerializationError` when decompression fails.
#[cfg(feature = "compression")]
pub fn decompress_tensor_data(compressed_data: &[u8]) -> TorshResult<Vec<u8>> {
    use oxiarc_deflate::gzip::gzip_decompress;

    match gzip_decompress(compressed_data) {
        Ok(raw) => Ok(raw),
        Err(e) => Err(TorshDistributedError::SerializationError(format!(
            "Decompression failed: {}",
            e
        ))),
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use torsh_tensor::creation::zeros;

    /// Minimal serde payload used to exercise the message helpers.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// A message survives an encode/decode round trip unchanged.
    #[test]
    fn test_message_serialization() {
        let original = TestMessage {
            id: 42,
            content: "test message".to_string(),
        };

        let bytes = serialize_message(&original).expect("serialize message should succeed");
        let round_tripped: TestMessage =
            deserialize_message(&bytes).expect("deserialize message should succeed");

        assert_eq!(original, round_tripped);
    }

    /// A tensor round trip preserves shape and device.
    #[test]
    fn test_tensor_serialization() {
        let source = zeros::<f32>(&[2, 3]).expect("operation should succeed");
        let shape_handle = source.shape();
        let dims = shape_handle.dims();

        let bytes = serialize_tensor(&source).expect("serialize tensor should succeed");
        let restored: Tensor<f32> =
            deserialize_tensor(&bytes, dims).expect("deserialize tensor should succeed");

        assert_eq!(source.shape().dims(), restored.shape().dims());
        assert_eq!(source.device(), restored.device());
    }

    /// Requesting a different shape than was serialized is rejected.
    #[test]
    fn test_tensor_shape_mismatch() {
        let source = zeros::<f32>(&[2, 3]).expect("operation should succeed");
        let bytes = serialize_tensor(&source).expect("serialize tensor should succeed");

        let outcome: Result<Tensor<f32>, _> = deserialize_tensor(&bytes, &[3, 2]);
        assert!(outcome.is_err());
    }

    /// The size estimate covers at least the raw element payload.
    #[test]
    fn test_estimate_tensor_size() {
        let tensor = zeros::<f32>(&[10, 10]).expect("operation should succeed");
        let estimate = estimate_tensor_serialized_size(&tensor);

        let raw_payload = 100 * std::mem::size_of::<f32>();
        assert!(estimate >= raw_payload);
    }
}