// torsh_distributed/communication/serialization.rs
1//! Unified serialization utilities for communication
2//!
3//! This module consolidates tensor and message serialization patterns
4//! used across RPC, parameter server, and collective operations.
5
6// Framework infrastructure - components designed for future use
7#![allow(dead_code)]
8use crate::{TorshDistributedError, TorshResult};
9use serde::{Deserialize, Serialize};
10use torsh_core::device::DeviceType;
11use torsh_core::dtype::TensorElement;
12use torsh_tensor::Tensor;
13
/// Trait for messages that can be serialized for communication.
///
/// This is a marker trait: the blanket impl below makes every type that is
/// `Serialize + DeserializeOwned + Send + Sync` a `CommunicationMessage`
/// automatically, so callers never implement it by hand.
pub trait CommunicationMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {}

// Blanket impl: any serde-capable, thread-safe type qualifies as a message.
impl<T> CommunicationMessage for T where T: Serialize + for<'de> Deserialize<'de> + Send + Sync {}
19
20/// Serialize a message for communication
21pub fn serialize_message<T: CommunicationMessage>(msg: &T) -> TorshResult<Vec<u8>> {
22    oxicode::serde::encode_to_vec(msg, oxicode::config::standard()).map_err(|e| {
23        TorshDistributedError::SerializationError(format!("Message serialization failed: {}", e))
24    })
25}
26
27/// Deserialize a message from communication
28pub fn deserialize_message<T: CommunicationMessage>(data: &[u8]) -> TorshResult<T> {
29    let (value, _): (T, usize) =
30        oxicode::serde::decode_from_slice(data, oxicode::config::standard()).map_err(|e| {
31            TorshDistributedError::SerializationError(format!(
32                "Message deserialization failed: {}",
33                e
34            ))
35        })?;
36    Ok(value)
37}
38
/// Serializable tensor representation for communication.
///
/// A self-describing wire format: raw element bytes plus enough metadata
/// (shape, dtype name, device, element size) for the receiver to validate
/// and reconstruct the tensor. Produced by `serialize_tensor` and consumed
/// by `deserialize_tensor`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SerializableTensor {
    /// Raw tensor data as bytes (element values reinterpreted byte-wise)
    pub data: Vec<u8>,
    /// Tensor shape (dimension sizes, outermost first)
    pub dtype: String, // placeholder removed — see below
    pub shape: Vec<usize>,
    /// Data type identifier (populated from `std::any::type_name::<T>()`)
    pub device: DeviceType,
    /// Element size in bytes (`size_of::<T>()` on the sending side)
    pub element_size: usize,
}
53
54/// Serialize a tensor for communication
55pub fn serialize_tensor<T>(tensor: &Tensor<T>) -> TorshResult<Vec<u8>>
56where
57    T: Clone + Send + Sync + 'static + TensorElement + Copy,
58{
59    // Get tensor data
60    let data = tensor.to_vec().map_err(|e| {
61        TorshDistributedError::SerializationError(format!("Failed to extract tensor data: {}", e))
62    })?;
63
64    // Convert to bytes
65    let element_size = std::mem::size_of::<T>();
66    let byte_data = unsafe {
67        std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * element_size).to_vec()
68    };
69
70    let serializable = SerializableTensor {
71        data: byte_data,
72        shape: tensor.shape().dims().to_vec(),
73        dtype: std::any::type_name::<T>().to_string(),
74        device: tensor.device(),
75        element_size,
76    };
77
78    serialize_message(&serializable)
79}
80
81/// Deserialize a tensor from communication data
82pub fn deserialize_tensor<T>(data: &[u8], expected_shape: &[usize]) -> TorshResult<Tensor<T>>
83where
84    T: Clone + Send + Sync + 'static + torsh_core::dtype::TensorElement,
85{
86    let serializable: SerializableTensor = deserialize_message(data)?;
87
88    // Validate shape matches expectation
89    if serializable.shape != expected_shape {
90        return Err(TorshDistributedError::TensorShapeMismatch {
91            expected: expected_shape.to_vec(),
92            actual: serializable.shape,
93        });
94    }
95
96    // Validate element size
97    let expected_element_size = std::mem::size_of::<T>();
98    if serializable.element_size != expected_element_size {
99        return Err(TorshDistributedError::SerializationError(format!(
100            "Element size mismatch: expected {}, got {}",
101            expected_element_size, serializable.element_size
102        )));
103    }
104
105    // Convert bytes back to typed data
106    let element_count = serializable.data.len() / serializable.element_size;
107    let typed_data = unsafe {
108        std::slice::from_raw_parts(serializable.data.as_ptr() as *const T, element_count).to_vec()
109    };
110
111    // Create tensor
112    Tensor::from_data(typed_data, serializable.shape, serializable.device).map_err(|e| {
113        TorshDistributedError::SerializationError(format!(
114            "Failed to create tensor from data: {}",
115            e
116        ))
117    })
118}
119
120/// Estimate serialized size of a tensor without actually serializing
121pub fn estimate_tensor_serialized_size<T>(tensor: &Tensor<T>) -> usize
122where
123    T: 'static + TensorElement + Copy,
124{
125    let element_size = std::mem::size_of::<T>();
126    let data_size = tensor.numel() * element_size;
127    let metadata_overhead = 256; // Rough estimate for shape, dtype, etc.
128    data_size + metadata_overhead
129}
130
/// Compress tensor data for communication (optional optimization).
///
/// Gzip-compresses the payload with the fastest compression level; both the
/// write and the stream finalization report failures as `SerializationError`.
#[cfg(feature = "compression")]
pub fn compress_tensor_data(data: Vec<u8>) -> TorshResult<Vec<u8>> {
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write;

    let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
    if let Err(e) = encoder.write_all(&data) {
        return Err(TorshDistributedError::SerializationError(format!(
            "Compression failed: {}",
            e
        )));
    }

    encoder.finish().map_err(|e| {
        TorshDistributedError::SerializationError(format!("Compression finalization failed: {}", e))
    })
}
147
/// Decompress tensor data from communication.
///
/// Inflates a gzip stream produced by [`compress_tensor_data`] into a new
/// byte buffer; any I/O or format error surfaces as `SerializationError`.
#[cfg(feature = "compression")]
pub fn decompress_tensor_data(compressed_data: &[u8]) -> TorshResult<Vec<u8>> {
    use flate2::read::GzDecoder;
    use std::io::Read;

    let mut output = Vec::new();
    GzDecoder::new(compressed_data)
        .read_to_end(&mut output)
        .map_err(|e| {
            TorshDistributedError::SerializationError(format!("Decompression failed: {}", e))
        })?;

    Ok(output)
}
162
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use torsh_tensor::creation::zeros;

    /// Minimal serde-capable message type for round-trip checks.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    #[test]
    fn test_message_serialization() {
        let original = TestMessage {
            id: 42,
            content: String::from("test message"),
        };

        let bytes = serialize_message(&original).unwrap();
        let round_tripped: TestMessage = deserialize_message(&bytes).unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_tensor_serialization() {
        let original = zeros::<f32>(&[2, 3]).unwrap();
        let shape_handle = original.shape();

        let bytes = serialize_tensor(&original).unwrap();
        let restored: Tensor<f32> = deserialize_tensor(&bytes, shape_handle.dims()).unwrap();

        // Shape and device must survive the round trip.
        assert_eq!(original.shape().dims(), restored.shape().dims());
        assert_eq!(original.device(), restored.device());
    }

    #[test]
    fn test_tensor_shape_mismatch() {
        let tensor = zeros::<f32>(&[2, 3]).unwrap();
        let bytes = serialize_tensor(&tensor).unwrap();

        // A transposed shape must be rejected rather than silently accepted.
        let outcome: Result<Tensor<f32>, _> = deserialize_tensor(&bytes, &[3, 2]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_estimate_tensor_size() {
        let tensor = zeros::<f32>(&[10, 10]).unwrap();
        let estimate = estimate_tensor_serialized_size(&tensor);

        // The estimate must cover at least the raw element payload.
        let payload_floor = 100 * std::mem::size_of::<f32>();
        assert!(estimate >= payload_floor);
    }
}