Skip to main content

tenflowers_core/memory/
views.rs

1//! Zero-copy tensor operations and memory views
2//!
3//! This module provides strided tensor views for efficient reshape and transpose
4//! operations, along with memory aliasing detection for safe zero-copy operations.
5
6use crate::{Result, TensorError};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Strided tensor view for zero-copy reshape and transpose operations
///
/// A view is pure metadata: where the data starts, how many elements each
/// dimension has, and how far to step to advance along each dimension. It
/// holds no reference to the underlying buffer.
#[derive(Clone, Debug)]
pub struct StridedView {
    /// Position of the view's first element, in the same unit as `strides`.
    pub offset: usize,
    /// Number of elements along each dimension, outermost first.
    pub shape: Vec<usize>,
    /// Step between consecutive elements along each dimension. Strides
    /// produced by `compute_strides` are multiples of `element_size`.
    pub strides: Vec<usize>,
    /// Size of a single element in bytes.
    pub element_size: usize,
}
18
19impl StridedView {
20    /// Create a new strided view
21    pub fn new(offset: usize, shape: Vec<usize>, strides: Vec<usize>, element_size: usize) -> Self {
22        Self {
23            offset,
24            shape,
25            strides,
26            element_size,
27        }
28    }
29
30    /// Create a strided view for transpose operation
31    pub fn transpose(&self, axes: &[usize]) -> Result<StridedView> {
32        if axes.len() != self.shape.len() {
33            return Err(TensorError::invalid_argument(
34                "Transpose axes must match tensor dimensions".to_string(),
35            ));
36        }
37
38        let mut new_shape = Vec::new();
39        let mut new_strides = Vec::new();
40
41        for &axis in axes {
42            if axis >= self.shape.len() {
43                return Err(TensorError::invalid_argument(format!(
44                    "Axis {} out of bounds for tensor with {} dimensions",
45                    axis,
46                    self.shape.len()
47                )));
48            }
49            new_shape.push(self.shape[axis]);
50            new_strides.push(self.strides[axis]);
51        }
52
53        Ok(StridedView {
54            offset: self.offset,
55            shape: new_shape,
56            strides: new_strides,
57            element_size: self.element_size,
58        })
59    }
60
61    /// Create a strided view for reshape operation (zero-copy when possible)
62    pub fn reshape(&self, new_shape: &[usize]) -> Result<StridedView> {
63        // Check if reshape is possible without data copy
64        let total_elements: usize = self.shape.iter().product();
65        let new_total_elements: usize = new_shape.iter().product();
66
67        if total_elements != new_total_elements {
68            return Err(TensorError::invalid_argument(
69                "Cannot reshape tensor: element count mismatch".to_string(),
70            ));
71        }
72
73        // Check if tensor is contiguous
74        if self.is_contiguous() {
75            // Can reshape without copy
76            let new_strides = compute_strides(new_shape, self.element_size);
77            Ok(StridedView {
78                offset: self.offset,
79                shape: new_shape.to_vec(),
80                strides: new_strides,
81                element_size: self.element_size,
82            })
83        } else {
84            // Non-contiguous tensor requires copy for reshape
85            Err(TensorError::unsupported_operation_simple(
86                "Reshape requires data copy for non-contiguous tensor".to_string(),
87            ))
88        }
89    }
90
91    /// Check if the tensor is contiguous in memory
92    pub fn is_contiguous(&self) -> bool {
93        let expected_strides = compute_strides(&self.shape, self.element_size);
94        self.strides == expected_strides
95    }
96
97    /// Get the total size in bytes
98    pub fn size_bytes(&self) -> usize {
99        self.shape.iter().product::<usize>() * self.element_size
100    }
101
102    /// Create a slice view
103    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<StridedView> {
104        if ranges.len() != self.shape.len() {
105            return Err(TensorError::invalid_argument(
106                "Slice ranges must match tensor dimensions".to_string(),
107            ));
108        }
109
110        let mut new_offset = self.offset;
111        let mut new_shape = Vec::new();
112        let mut new_strides = Vec::new();
113
114        for (i, &(start, end)) in ranges.iter().enumerate() {
115            if start >= end || end > self.shape[i] {
116                return Err(TensorError::invalid_argument(format!(
117                    "Invalid slice range [{}, {}) for dimension {} of size {}",
118                    start, end, i, self.shape[i]
119                )));
120            }
121
122            new_offset += start * self.strides[i];
123            new_shape.push(end - start);
124            new_strides.push(self.strides[i]);
125        }
126
127        Ok(StridedView {
128            offset: new_offset,
129            shape: new_shape,
130            strides: new_strides,
131            element_size: self.element_size,
132        })
133    }
134}
135
/// Memory aliasing detector for safe zero-copy operations
///
/// Keeps, per buffer id, the list of `(offset, size)` regions that have been
/// registered and not yet unregistered, and answers overlap queries against
/// them. Regions are treated as half-open ranges `[offset, offset + size)`,
/// so two regions that merely touch do not overlap.
#[derive(Debug)]
pub struct MemoryAliasDetector {
    // The nested generic type trips clippy's complexity lint.
    #[allow(clippy::type_complexity)]
    active_views: Arc<Mutex<HashMap<usize, Vec<(usize, usize)>>>>, // buffer_id -> [(offset, size)]
}

impl Default for MemoryAliasDetector {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryAliasDetector {
    /// Create a new memory alias detector
    pub fn new() -> Self {
        Self {
            active_views: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Lock the view table.
    ///
    /// Panics if the mutex is poisoned, i.e. another thread panicked while
    /// holding it.
    // Same nested type as the field; see the allow there.
    #[allow(clippy::type_complexity)]
    fn lock_views(&self) -> std::sync::MutexGuard<'_, HashMap<usize, Vec<(usize, usize)>>> {
        self.active_views
            .lock()
            .expect("lock should not be poisoned")
    }

    /// Turn `(offset, size)` into a half-open `(start, end)` pair.
    ///
    /// The end saturates at `usize::MAX` so a region reaching past the end of
    /// the address space neither panics nor wraps around to a small value.
    fn span(offset: usize, size: usize) -> (usize, usize) {
        (offset, offset.saturating_add(size))
    }

    /// Whether two half-open ranges share at least one position boundary-exclusively.
    ///
    /// Ranges that only touch (`end1 == start2`) do not overlap.
    fn spans_overlap((start1, end1): (usize, usize), (start2, end2): (usize, usize)) -> bool {
        start1 < end2 && start2 < end1
    }

    /// Check if a new view would create an alias
    ///
    /// Returns `true` when the region overlaps any view currently registered
    /// for `buffer_id`. Views of other buffers are never considered.
    pub fn check_alias(&self, buffer_id: usize, offset: usize, size: usize) -> bool {
        let query = Self::span(offset, size);
        let active_views = self.lock_views();

        match active_views.get(&buffer_id) {
            Some(views) => views.iter().any(|&(view_offset, view_size)| {
                Self::spans_overlap(query, Self::span(view_offset, view_size))
            }),
            None => false,
        }
    }

    /// Register a new view
    ///
    /// Registrations are not deduplicated: registering the same region twice
    /// records two views, each of which needs its own `unregister_view`.
    pub fn register_view(&self, buffer_id: usize, offset: usize, size: usize) {
        self.lock_views()
            .entry(buffer_id)
            .or_default()
            .push((offset, size));
    }

    /// Unregister a view
    ///
    /// Removes a single registration matching `(offset, size)` exactly. Other
    /// registrations of the same region stay active, so dropping one of two
    /// identical views does not hide the one that is still alive. Does nothing
    /// when no matching registration exists. The buffer entry is dropped once
    /// its last view is gone.
    pub fn unregister_view(&self, buffer_id: usize, offset: usize, size: usize) {
        let mut active_views = self.lock_views();
        if let Some(views) = active_views.get_mut(&buffer_id) {
            if let Some(index) = views.iter().position(|&view| view == (offset, size)) {
                views.remove(index);
            }
            if views.is_empty() {
                active_views.remove(&buffer_id);
            }
        }
    }

    /// Get detailed information about potential aliases for a memory region
    ///
    /// Returns one `(overlap_start, overlap_size, view_size)` tuple for every
    /// registered view of `buffer_id` that overlaps the region, in
    /// registration order.
    pub fn get_alias_info(
        &self,
        buffer_id: usize,
        offset: usize,
        size: usize,
    ) -> Vec<(usize, usize, usize)> {
        let query = Self::span(offset, size);
        let active_views = self.lock_views();
        let mut aliases = Vec::new();

        if let Some(views) = active_views.get(&buffer_id) {
            for &(view_offset, view_size) in views {
                let view = Self::span(view_offset, view_size);
                if Self::spans_overlap(query, view) {
                    let overlap_start = std::cmp::max(query.0, view.0);
                    let overlap_end = std::cmp::min(query.1, view.1);
                    aliases.push((overlap_start, overlap_end - overlap_start, view_size));
                }
            }
        }

        aliases
    }

    /// Check if a memory region would create partial aliases (useful for optimization decisions)
    ///
    /// A partial alias is an overlap in which neither range fully contains
    /// the other. Identical ranges and nested ranges therefore return `false`.
    pub fn check_partial_alias(&self, buffer_id: usize, offset: usize, size: usize) -> bool {
        let (start1, end1) = Self::span(offset, size);
        let active_views = self.lock_views();

        match active_views.get(&buffer_id) {
            Some(views) => views.iter().any(|&(view_offset, view_size)| {
                let (start2, end2) = Self::span(view_offset, view_size);
                let inside_existing = start1 >= start2 && end1 <= end2;
                let covers_existing = start2 >= start1 && end2 <= end1;

                Self::spans_overlap((start1, end1), (start2, end2))
                    && !inside_existing
                    && !covers_existing
            }),
            None => false,
        }
    }

    /// Get statistics about active memory views
    ///
    /// Returns `(buffers_with_views, total_views)`.
    pub fn get_alias_statistics(&self) -> (usize, usize) {
        let active_views = self.lock_views();
        let total_buffers = active_views.len();
        let total_views: usize = active_views.values().map(Vec::len).sum();
        (total_buffers, total_views)
    }
}
283
/// Compute strides for a given shape
///
/// The innermost (last) dimension gets a stride of `element_size`; every
/// dimension further out gets the stride of the dimension inside it
/// multiplied by that dimension's extent. An empty shape yields an empty
/// vector.
pub fn compute_strides(shape: &[usize], element_size: usize) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut running = element_size;

    // Walk from the innermost dimension outwards, writing each stride
    // directly into its final slot.
    for (slot, &extent) in strides.iter_mut().zip(shape).rev() {
        *slot = running;
        running *= extent;
    }

    strides
}
297
#[cfg(test)]
mod tests {
    //! Unit tests for strided views, stride computation and alias detection.

    use super::*;

    #[test]
    fn test_strided_view_transpose() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        let transposed = view
            .transpose(&[2, 0, 1])
            .expect("test: transpose should succeed");

        assert_eq!(transposed.shape, vec![4, 2, 3]);
        assert_eq!(transposed.strides, vec![4, 48, 16]);
    }

    #[test]
    fn test_strided_view_reshape() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        let reshaped = view.reshape(&[6, 4]).expect("test: reshape should succeed");

        assert_eq!(reshaped.shape, vec![6, 4]);
        assert_eq!(reshaped.strides, vec![16, 4]);
    }

    #[test]
    fn test_strided_view_slice() {
        let view = StridedView::new(0, vec![4, 4], vec![16, 4], 4);
        let sliced = view
            .slice(&[(1, 3), (0, 2)])
            .expect("test: operation should succeed");

        assert_eq!(sliced.shape, vec![2, 2]);
        assert_eq!(sliced.strides, vec![16, 4]);
        assert_eq!(sliced.offset, 16); // 1 * 16 + 0 * 4
    }

    #[test]
    fn test_memory_alias_detector() {
        let detector = MemoryAliasDetector::new();

        // Register a view
        detector.register_view(0, 0, 100);

        // Check for alias
        assert!(detector.check_alias(0, 50, 100)); // Overlaps
        assert!(!detector.check_alias(0, 100, 50)); // No overlap

        // Unregister
        detector.unregister_view(0, 0, 100);
        assert!(!detector.check_alias(0, 50, 100)); // No alias after unregister
    }

    #[test]
    fn test_compute_strides() {
        let strides = compute_strides(&[2, 3, 4], 4);
        assert_eq!(strides, vec![48, 16, 4]);
    }

    #[test]
    fn test_is_contiguous() {
        let contiguous_view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        assert!(contiguous_view.is_contiguous());

        let non_contiguous_view = StridedView::new(0, vec![2, 3, 4], vec![32, 16, 4], 4);
        assert!(!non_contiguous_view.is_contiguous());
    }

    #[test]
    fn test_size_bytes() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        assert_eq!(view.size_bytes(), 96); // 2 * 3 * 4 * 4 = 96
    }

    #[test]
    fn test_invalid_transpose() {
        let view = StridedView::new(0, vec![2, 3], vec![12, 4], 4);

        // Wrong number of axes
        let result = view.transpose(&[1, 0, 2]);
        assert!(result.is_err());

        // Axis out of bounds
        let result = view.transpose(&[0, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_reshape() {
        let view = StridedView::new(0, vec![2, 3], vec![12, 4], 4);

        // Element count mismatch
        let result = view.reshape(&[2, 4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_slice() {
        let view = StridedView::new(0, vec![4, 4], vec![16, 4], 4);

        // Wrong number of dimensions
        let result = view.slice(&[(1, 3)]);
        assert!(result.is_err());

        // Invalid range
        let result = view.slice(&[(1, 1), (0, 2)]); // start >= end
        assert!(result.is_err());

        // Out of bounds
        let result = view.slice(&[(0, 5), (0, 2)]); // end > shape
        assert!(result.is_err());
    }

    #[test]
    fn test_alias_detection_edge_cases() {
        let detector = MemoryAliasDetector::new();

        // Test touching boundaries
        detector.register_view(0, 0, 100);
        assert!(!detector.check_alias(0, 100, 50)); // Adjacent, no overlap

        // Test complete containment
        detector.register_view(1, 10, 80);
        assert!(detector.check_alias(1, 20, 30)); // Contained within
        assert!(detector.check_alias(1, 0, 100)); // Contains the view

        // Test partial overlap
        assert!(detector.check_partial_alias(1, 50, 80)); // Partial overlap
        assert!(!detector.check_partial_alias(1, 15, 50)); // Complete containment
    }

    #[test]
    fn test_alias_info() {
        let detector = MemoryAliasDetector::new();
        detector.register_view(0, 10, 50); // covers [10, 60)
        detector.register_view(0, 40, 30); // covers [40, 70)

        // Query covers [35, 55)
        let aliases = detector.get_alias_info(0, 35, 20);
        assert_eq!(aliases.len(), 2); // Overlaps with both views

        // Check the overlap details; the view size tells the two entries apart.
        assert!(aliases
            .iter()
            .any(|&(start, size, view_size)| start == 35 && size == 20 && view_size == 50)); // Overlap with first view: [35, 55)
        assert!(aliases
            .iter()
            .any(|&(start, size, view_size)| start == 40 && size == 15 && view_size == 30)); // Overlap with second view: [40, 55)
    }

    #[test]
    fn test_alias_statistics() {
        let detector = MemoryAliasDetector::new();

        let (buffers, views) = detector.get_alias_statistics();
        assert_eq!(buffers, 0);
        assert_eq!(views, 0);

        detector.register_view(0, 0, 100);
        detector.register_view(0, 100, 100);
        detector.register_view(1, 0, 50);

        let (buffers, views) = detector.get_alias_statistics();
        assert_eq!(buffers, 2); // 2 different buffer IDs
        assert_eq!(views, 3); // 3 total views
    }
}