train_station/tensor/init/data.rs
1//! Data-based tensor initialization methods
2//!
3//! This module provides methods to create tensors from existing data sources.
4//! All methods validate data compatibility and perform efficient memory copying
5//! with optimized performance characteristics.
6//!
7//! # Key Features
8//!
9//! - **`from_slice`**: Create tensors from slices of f32 data
10//! - **Data validation**: Automatic size and compatibility checking
11//! - **Efficient copying**: Optimized memory operations for performance
12//! - **Error handling**: Clear error messages for validation failures
13//! - **Multi-dimensional support**: Support for 1D, 2D, 3D, and higher-dimensional tensors
14//! - **Zero-sized handling**: Proper handling of empty tensors
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Copy**: Efficient non-overlapping copy using SIMD when possible
19//! - **Validation**: Fast size validation before allocation
20//! - **Alignment**: Proper memory alignment for optimal performance
21//! - **Large Data**: Optimized handling of large datasets
22//! - **Zero Overhead**: Minimal validation overhead for correct data
23//!
24//! # Examples
25//!
26//! ## Basic Data Initialization
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensor from slice data
32//! let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
33//! let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
34//!
35//! assert_eq!(tensor.size(), 6);
36//! assert_eq!(tensor.shape().dims, vec![2, 3]);
37//!
38//! // Verify data was copied correctly
39//! assert_eq!(tensor.get(&[0, 0]), 1.0);
40//! assert_eq!(tensor.get(&[1, 2]), 6.0);
41//! ```
42//!
43//! ## Multi-Dimensional Tensors
44//!
45//! ```
46//! use train_station::Tensor;
47//!
48//! // 1D tensor
49//! let data_1d = [1.0, 2.0, 3.0];
50//! let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
51//! assert_eq!(tensor_1d.shape().dims, vec![3]);
52//! assert_eq!(tensor_1d.get(&[1]), 2.0);
53//!
54//! // 3D tensor
55//! let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
56//! let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
57//! assert_eq!(tensor_3d.shape().dims, vec![2, 2, 2]);
58//! assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
59//! assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
60//! ```
61//!
62//! ## Error Handling
63//!
64//! ```
65//! use train_station::Tensor;
66//!
67//! // Size mismatch error
68//! let data = [1.0, 2.0, 3.0];
69//! let result = Tensor::from_slice(&data, vec![2, 2]);
70//! assert!(result.is_err());
71//! let err = result.unwrap_err();
72//! assert!(err.contains("Data size 3 doesn't match shape size 4"));
73//! ```
74//!
75//! ## Zero-Sized Tensors
76//!
77//! ```
78//! use train_station::Tensor;
79//!
80//! // Handle empty tensors gracefully
81//! let data: [f32; 0] = [];
82//! let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
83//! assert_eq!(tensor.size(), 0);
84//! assert_eq!(tensor.shape().dims, vec![0]);
85//! ```
86//!
87//! ## Large Data Sets
88//!
89//! ```
90//! use train_station::Tensor;
91//!
92//! // Efficient handling of large datasets
93//! let size = 1000;
94//! let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
95//! let tensor = Tensor::from_slice(&data, vec![size]).unwrap();
96//!
97//! assert_eq!(tensor.size(), size);
98//! assert_eq!(tensor.get(&[0]), 0.0);
99//! assert_eq!(tensor.get(&[100]), 100.0);
100//! assert_eq!(tensor.get(&[999]), 999.0);
101//! ```
102//!
103//! # Design Principles
104//!
105//! - **Data Safety**: Comprehensive validation of data compatibility
106//! - **Performance First**: Optimized memory operations for maximum speed
107//! - **Error Clarity**: Clear and descriptive error messages
108//! - **Memory Efficiency**: Efficient copying with minimal overhead
109//! - **Type Safety**: Strong typing for all data operations
110//! - **Zero-Cost Validation**: Minimal overhead for correct data
111
112use crate::tensor::core::Tensor;
113
114impl Tensor {
115 /// Creates a tensor from a slice of data
116 ///
117 /// Creates a new tensor with the specified shape and copies data from the
118 /// provided slice. Validates that the data size matches the tensor shape
119 /// before performing the copy operation.
120 ///
121 /// This method provides an efficient way to create tensors from existing
122 /// data sources while ensuring data integrity and proper memory management.
123 ///
124 /// # Arguments
125 ///
126 /// * `data` - Slice of f32 values to copy into the tensor
127 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
128 ///
129 /// # Returns
130 ///
131 /// * `Ok(Tensor)` - Successfully created tensor with copied data
132 /// * `Err(String)` - Error if data size doesn't match shape
133 ///
134 /// # Performance
135 ///
136 /// - **Memory Copy**: Efficient non-overlapping copy using SIMD when possible
137 /// - **Validation**: Fast size validation before allocation
138 /// - **Alignment**: Proper memory alignment for optimal performance
139 /// - **Large Data**: Optimized handling of large datasets
140 ///
141 /// # Examples
142 ///
143 /// ## Basic Usage
144 ///
145 /// ```
146 /// use train_station::Tensor;
147 ///
148 /// let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
149 /// let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
150 /// assert_eq!(tensor.size(), 6);
151 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
152 /// assert_eq!(tensor.get(&[1, 2]), 6.0);
153 /// ```
154 ///
155 /// ## Multi-Dimensional Data
156 ///
157 /// ```
158 /// use train_station::Tensor;
159 ///
160 /// // 1D tensor
161 /// let data_1d = [1.0, 2.0, 3.0];
162 /// let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
163 /// assert_eq!(tensor_1d.shape().dims, vec![3]);
164 /// assert_eq!(tensor_1d.get(&[1]), 2.0);
165 ///
166 /// // 3D tensor
167 /// let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
168 /// let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
169 /// assert_eq!(tensor_3d.shape().dims, vec![2, 2, 2]);
170 /// assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
171 /// assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
172 /// ```
173 ///
174 /// ## Error Handling
175 ///
176 /// ```
177 /// use train_station::Tensor;
178 ///
179 /// // Size mismatch error
180 /// let data = [1.0, 2.0, 3.0];
181 /// let result = Tensor::from_slice(&data, vec![2, 2]);
182 /// assert!(result.is_err());
183 /// let err = result.unwrap_err();
184 /// assert!(err.contains("Data size 3 doesn't match shape size 4"));
185 /// ```
186 ///
187 /// ## Zero-Sized Tensors
188 ///
189 /// ```
190 /// use train_station::Tensor;
191 ///
192 /// // Handle empty tensors gracefully
193 /// let data: [f32; 0] = [];
194 /// let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
195 /// assert_eq!(tensor.size(), 0);
196 /// assert_eq!(tensor.shape().dims, vec![0]);
197 /// ```
198 ///
199 /// ## Large Data Sets
200 ///
201 /// ```
202 /// use train_station::Tensor;
203 ///
204 /// // Efficient handling of large datasets
205 /// let size = 1000;
206 /// let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
207 /// let tensor = Tensor::from_slice(&data, vec![size]).unwrap();
208 ///
209 /// assert_eq!(tensor.size(), size);
210 /// assert_eq!(tensor.get(&[0]), 0.0);
211 /// assert_eq!(tensor.get(&[100]), 100.0);
212 /// assert_eq!(tensor.get(&[999]), 999.0);
213 /// ```
214 ///
215 /// # Implementation Details
216 ///
217 /// This method performs the following steps:
218 /// 1. **Shape Validation**: Creates a Shape object and validates dimensions
219 /// 2. **Size Check**: Ensures data length matches the calculated tensor size
220 /// 3. **Memory Allocation**: Allocates tensor memory with proper alignment
221 /// 4. **Data Copy**: Uses efficient non-overlapping memory copy operation
222 /// 5. **Return**: Returns the created tensor or descriptive error message
223 ///
224 /// The memory copy operation uses `std::ptr::copy_nonoverlapping` for
225 /// maximum performance and safety, ensuring no data corruption occurs
226 /// during the copy process.
227 pub fn from_slice(data: &[f32], shape_dims: Vec<usize>) -> Result<Self, String> {
228 let shape = crate::tensor::Shape::new(shape_dims);
229
230 if data.len() != shape.size {
231 return Err(format!(
232 "Data size {} doesn't match shape size {}",
233 data.len(),
234 shape.size
235 ));
236 }
237
238 let mut tensor = Self::new(shape.dims.clone());
239
240 // Copy data into tensor using efficient non-overlapping copy
241 unsafe {
242 let dst = tensor.as_mut_ptr();
243 std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
244 }
245
246 Ok(tensor)
247 }
248}
249
#[cfg(test)]
mod tests {
    //! Unit tests for `Tensor::from_slice`: element placement for 1D/2D/3D
    //! shapes, size-mismatch errors in both directions, empty input, and a
    //! larger 1D buffer.

    use super::*;

    /// A 2x3 tensor exposes the slice values at the expected 2D indices.
    #[test]
    fn test_from_slice_basic() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();

        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // Verify data was copied correctly
        assert_eq!(tensor.get(&[0, 0]), 1.0);
        assert_eq!(tensor.get(&[0, 1]), 2.0);
        assert_eq!(tensor.get(&[0, 2]), 3.0);
        assert_eq!(tensor.get(&[1, 0]), 4.0);
        assert_eq!(tensor.get(&[1, 1]), 5.0);
        assert_eq!(tensor.get(&[1, 2]), 6.0);
    }

    /// A 1D tensor preserves slice order.
    #[test]
    fn test_from_slice_1d() {
        let data = [1.0, 2.0, 3.0];
        let tensor = Tensor::from_slice(&data, vec![3]).unwrap();

        assert_eq!(tensor.size(), 3);
        assert_eq!(tensor.shape().dims, vec![3]);

        assert_eq!(tensor.get(&[0]), 1.0);
        assert_eq!(tensor.get(&[1]), 2.0);
        assert_eq!(tensor.get(&[2]), 3.0);
    }

    /// A 2x2x2 tensor exposes the slice values at the expected 3D indices.
    #[test]
    fn test_from_slice_3d() {
        // Every element is distinct so that each assertion below can only be
        // satisfied by reading the one intended position.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        assert_eq!(tensor.size(), 8);
        assert_eq!(tensor.shape().dims, vec![2, 2, 2]);

        // Verify 3D indexing
        assert_eq!(tensor.get(&[0, 0, 0]), 1.0);
        assert_eq!(tensor.get(&[0, 0, 1]), 2.0);
        assert_eq!(tensor.get(&[0, 1, 0]), 3.0);
        assert_eq!(tensor.get(&[0, 1, 1]), 4.0);
        assert_eq!(tensor.get(&[1, 0, 0]), 5.0);
        assert_eq!(tensor.get(&[1, 0, 1]), 6.0);
        assert_eq!(tensor.get(&[1, 1, 0]), 7.0);
        assert_eq!(tensor.get(&[1, 1, 1]), 8.0);
    }

    /// Fewer elements than the shape requires yields a descriptive error.
    #[test]
    fn test_from_slice_size_mismatch() {
        let data = [1.0, 2.0, 3.0];
        let result = Tensor::from_slice(&data, vec![2, 2]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Data size 3 doesn't match shape size 4"));
    }

    /// More elements than the shape requires is rejected as well.
    #[test]
    fn test_from_slice_size_mismatch_too_many() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result = Tensor::from_slice(&data, vec![2, 2]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Data size 5 doesn't match shape size 4"));
    }

    /// An empty slice with a zero-length dimension produces an empty tensor.
    #[test]
    fn test_from_slice_empty() {
        let data: [f32; 0] = [];
        let tensor = Tensor::from_slice(&data, vec![0]).unwrap();

        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    /// A 1000-element buffer is copied intact; spot-check both ends and the interior.
    #[test]
    fn test_from_slice_large_data() {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![size]).unwrap();

        assert_eq!(tensor.size(), size);

        // Verify a few values
        assert_eq!(tensor.get(&[0]), 0.0);
        assert_eq!(tensor.get(&[100]), 100.0);
        assert_eq!(tensor.get(&[999]), 999.0);
    }
}
335}