train_station/tensor/init/
data.rs

1//! Data-based tensor initialization methods
2//!
3//! This module provides methods to create tensors from existing data sources.
4//! All methods validate data compatibility and perform efficient memory copying
5//! with optimized performance characteristics.
6//!
7//! # Key Features
8//!
9//! - **`from_slice`**: Create tensors from slices of f32 data
10//! - **Data validation**: Automatic size and compatibility checking
11//! - **Efficient copying**: Optimized memory operations for performance
12//! - **Error handling**: Clear error messages for validation failures
13//! - **Multi-dimensional support**: Support for 1D, 2D, 3D, and higher-dimensional tensors
14//! - **Zero-sized handling**: Proper handling of empty tensors
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Copy**: Efficient non-overlapping copy using SIMD when possible
19//! - **Validation**: Fast size validation before allocation
20//! - **Alignment**: Proper memory alignment for optimal performance
21//! - **Large Data**: Optimized handling of large datasets
22//! - **Zero Overhead**: Minimal validation overhead for correct data
23//!
24//! # Examples
25//!
26//! ## Basic Data Initialization
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensor from slice data
32//! let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
33//! let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
34//!
35//! assert_eq!(tensor.size(), 6);
36//! assert_eq!(tensor.shape().dims, vec![2, 3]);
37//!
38//! // Verify data was copied correctly
39//! assert_eq!(tensor.get(&[0, 0]), 1.0);
40//! assert_eq!(tensor.get(&[1, 2]), 6.0);
41//! ```
42//!
43//! ## Multi-Dimensional Tensors
44//!
45//! ```
46//! use train_station::Tensor;
47//!
48//! // 1D tensor
49//! let data_1d = [1.0, 2.0, 3.0];
50//! let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
51//! assert_eq!(tensor_1d.shape().dims, vec![3]);
52//! assert_eq!(tensor_1d.get(&[1]), 2.0);
53//!
54//! // 3D tensor
55//! let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
56//! let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
57//! assert_eq!(tensor_3d.shape().dims, vec![2, 2, 2]);
58//! assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
59//! assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
60//! ```
61//!
62//! ## Error Handling
63//!
64//! ```
65//! use train_station::Tensor;
66//!
67//! // Size mismatch error
68//! let data = [1.0, 2.0, 3.0];
69//! let result = Tensor::from_slice(&data, vec![2, 2]);
70//! assert!(result.is_err());
71//! let err = result.unwrap_err();
72//! assert!(err.contains("Data size 3 doesn't match shape size 4"));
73//! ```
74//!
75//! ## Zero-Sized Tensors
76//!
77//! ```
78//! use train_station::Tensor;
79//!
80//! // Handle empty tensors gracefully
81//! let data: [f32; 0] = [];
82//! let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
83//! assert_eq!(tensor.size(), 0);
84//! assert_eq!(tensor.shape().dims, vec![0]);
85//! ```
86//!
87//! ## Large Data Sets
88//!
89//! ```
90//! use train_station::Tensor;
91//!
92//! // Efficient handling of large datasets
93//! let size = 1000;
94//! let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
95//! let tensor = Tensor::from_slice(&data, vec![size]).unwrap();
96//!
97//! assert_eq!(tensor.size(), size);
98//! assert_eq!(tensor.get(&[0]), 0.0);
99//! assert_eq!(tensor.get(&[100]), 100.0);
100//! assert_eq!(tensor.get(&[999]), 999.0);
101//! ```
102//!
103//! # Design Principles
104//!
105//! - **Data Safety**: Comprehensive validation of data compatibility
106//! - **Performance First**: Optimized memory operations for maximum speed
107//! - **Error Clarity**: Clear and descriptive error messages
108//! - **Memory Efficiency**: Efficient copying with minimal overhead
109//! - **Type Safety**: Strong typing for all data operations
110//! - **Zero-Cost Validation**: Minimal overhead for correct data
111
112use crate::tensor::core::Tensor;
113
114impl Tensor {
115    /// Creates a tensor from a slice of data
116    ///
117    /// Creates a new tensor with the specified shape and copies data from the
118    /// provided slice. Validates that the data size matches the tensor shape
119    /// before performing the copy operation.
120    ///
121    /// This method provides an efficient way to create tensors from existing
122    /// data sources while ensuring data integrity and proper memory management.
123    ///
124    /// # Arguments
125    ///
126    /// * `data` - Slice of f32 values to copy into the tensor
127    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
128    ///
129    /// # Returns
130    ///
131    /// * `Ok(Tensor)` - Successfully created tensor with copied data
132    /// * `Err(String)` - Error if data size doesn't match shape
133    ///
134    /// # Performance
135    ///
136    /// - **Memory Copy**: Efficient non-overlapping copy using SIMD when possible
137    /// - **Validation**: Fast size validation before allocation
138    /// - **Alignment**: Proper memory alignment for optimal performance
139    /// - **Large Data**: Optimized handling of large datasets
140    ///
141    /// # Examples
142    ///
143    /// ## Basic Usage
144    ///
145    /// ```
146    /// use train_station::Tensor;
147    ///
148    /// let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
149    /// let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
150    /// assert_eq!(tensor.size(), 6);
151    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
152    /// assert_eq!(tensor.get(&[1, 2]), 6.0);
153    /// ```
154    ///
155    /// ## Multi-Dimensional Data
156    ///
157    /// ```
158    /// use train_station::Tensor;
159    ///
160    /// // 1D tensor
161    /// let data_1d = [1.0, 2.0, 3.0];
162    /// let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
163    /// assert_eq!(tensor_1d.shape().dims, vec![3]);
164    /// assert_eq!(tensor_1d.get(&[1]), 2.0);
165    ///
166    /// // 3D tensor
167    /// let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
168    /// let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
169    /// assert_eq!(tensor_3d.shape().dims, vec![2, 2, 2]);
170    /// assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
171    /// assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
172    /// ```
173    ///
174    /// ## Error Handling
175    ///
176    /// ```
177    /// use train_station::Tensor;
178    ///
179    /// // Size mismatch error
180    /// let data = [1.0, 2.0, 3.0];
181    /// let result = Tensor::from_slice(&data, vec![2, 2]);
182    /// assert!(result.is_err());
183    /// let err = result.unwrap_err();
184    /// assert!(err.contains("Data size 3 doesn't match shape size 4"));
185    /// ```
186    ///
187    /// ## Zero-Sized Tensors
188    ///
189    /// ```
190    /// use train_station::Tensor;
191    ///
192    /// // Handle empty tensors gracefully
193    /// let data: [f32; 0] = [];
194    /// let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
195    /// assert_eq!(tensor.size(), 0);
196    /// assert_eq!(tensor.shape().dims, vec![0]);
197    /// ```
198    ///
199    /// ## Large Data Sets
200    ///
201    /// ```
202    /// use train_station::Tensor;
203    ///
204    /// // Efficient handling of large datasets
205    /// let size = 1000;
206    /// let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
207    /// let tensor = Tensor::from_slice(&data, vec![size]).unwrap();
208    ///
209    /// assert_eq!(tensor.size(), size);
210    /// assert_eq!(tensor.get(&[0]), 0.0);
211    /// assert_eq!(tensor.get(&[100]), 100.0);
212    /// assert_eq!(tensor.get(&[999]), 999.0);
213    /// ```
214    ///
215    /// # Implementation Details
216    ///
217    /// This method performs the following steps:
218    /// 1. **Shape Validation**: Creates a Shape object and validates dimensions
219    /// 2. **Size Check**: Ensures data length matches the calculated tensor size
220    /// 3. **Memory Allocation**: Allocates tensor memory with proper alignment
221    /// 4. **Data Copy**: Uses efficient non-overlapping memory copy operation
222    /// 5. **Return**: Returns the created tensor or descriptive error message
223    ///
224    /// The memory copy operation uses `std::ptr::copy_nonoverlapping` for
225    /// maximum performance and safety, ensuring no data corruption occurs
226    /// during the copy process.
227    #[track_caller]
228    pub fn from_slice(data: &[f32], shape_dims: Vec<usize>) -> Result<Self, String> {
229        let shape = crate::tensor::Shape::new(shape_dims);
230
231        if data.len() != shape.size {
232            return Err(format!(
233                "Data size {} doesn't match shape size {}",
234                data.len(),
235                shape.size
236            ));
237        }
238
239        let mut tensor = Self::new(shape.dims.clone());
240
241        // Copy data into tensor using efficient non-overlapping copy
242        unsafe {
243            let dst = tensor.as_mut_ptr();
244            std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
245        }
246
247        Ok(tensor)
248    }
249}
250
#[cfg(test)]
mod tests {
    //! Unit tests for `Tensor::from_slice`.
    //!
    //! Element values are kept distinct within each test so that every index
    //! has a unique expected value; a read from the wrong position then fails
    //! the assertion instead of passing by coincidence.

    use super::*;

    /// A 2x3 tensor returns the source values in the order they were given.
    #[test]
    fn test_from_slice_basic() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();

        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // Verify data was copied correctly
        assert_eq!(tensor.get(&[0, 0]), 1.0);
        assert_eq!(tensor.get(&[0, 1]), 2.0);
        assert_eq!(tensor.get(&[0, 2]), 3.0);
        assert_eq!(tensor.get(&[1, 0]), 4.0);
        assert_eq!(tensor.get(&[1, 1]), 5.0);
        assert_eq!(tensor.get(&[1, 2]), 6.0);
    }

    /// A 1D tensor exposes each element at its own index.
    #[test]
    fn test_from_slice_1d() {
        let data = [1.0, 2.0, 3.0];
        let tensor = Tensor::from_slice(&data, vec![3]).unwrap();

        assert_eq!(tensor.size(), 3);
        assert_eq!(tensor.shape().dims, vec![3]);

        assert_eq!(tensor.get(&[0]), 1.0);
        assert_eq!(tensor.get(&[1]), 2.0);
        assert_eq!(tensor.get(&[2]), 3.0);
    }

    /// A 2x2x2 tensor maps every 3D index to the matching source element.
    #[test]
    fn test_from_slice_3d() {
        // All eight values are distinct so that [1, 0, 1] cannot be confused
        // with [1, 1, 1] (or any other position).
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        assert_eq!(tensor.size(), 8);
        assert_eq!(tensor.shape().dims, vec![2, 2, 2]);

        // Verify 3D indexing
        assert_eq!(tensor.get(&[0, 0, 0]), 1.0);
        assert_eq!(tensor.get(&[0, 0, 1]), 2.0);
        assert_eq!(tensor.get(&[0, 1, 0]), 3.0);
        assert_eq!(tensor.get(&[0, 1, 1]), 4.0);
        assert_eq!(tensor.get(&[1, 0, 0]), 5.0);
        assert_eq!(tensor.get(&[1, 0, 1]), 6.0);
        assert_eq!(tensor.get(&[1, 1, 0]), 7.0);
        assert_eq!(tensor.get(&[1, 1, 1]), 8.0);
    }

    /// Too few elements for the requested shape yields a descriptive error.
    #[test]
    fn test_from_slice_size_mismatch() {
        let data = [1.0, 2.0, 3.0];
        let result = Tensor::from_slice(&data, vec![2, 2]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Data size 3 doesn't match shape size 4"));
    }

    /// Too many elements for the requested shape is rejected the same way.
    #[test]
    fn test_from_slice_size_mismatch_too_much_data() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result = Tensor::from_slice(&data, vec![2, 2]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Data size 5 doesn't match shape size 4"));
    }

    /// An empty slice with a zero-length dimension produces an empty tensor.
    #[test]
    fn test_from_slice_empty() {
        let data: [f32; 0] = [];
        let tensor = Tensor::from_slice(&data, vec![0]).unwrap();

        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    /// A 1000-element input is copied intact (spot-checked at both ends).
    #[test]
    fn test_from_slice_large_data() {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![size]).unwrap();

        assert_eq!(tensor.size(), size);

        // Verify a few values
        assert_eq!(tensor.get(&[0]), 0.0);
        assert_eq!(tensor.get(&[100]), 100.0);
        assert_eq!(tensor.get(&[999]), 999.0);
    }
}