1use crate::{Result, TensorError};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// A zero-copy, strided view into an externally owned byte buffer.
///
/// All offsets and strides are expressed in **bytes**, not elements
/// (see `compute_strides`, which multiplies by `element_size`).
#[derive(Debug, Clone)]
pub struct StridedView {
    /// Byte offset of the first element within the underlying buffer.
    pub offset: usize,
    /// Extent of each dimension, in elements.
    pub shape: Vec<usize>,
    /// Byte step between consecutive elements along each dimension.
    pub strides: Vec<usize>,
    /// Size of a single element, in bytes.
    pub element_size: usize,
}
18
19impl StridedView {
20 pub fn new(offset: usize, shape: Vec<usize>, strides: Vec<usize>, element_size: usize) -> Self {
22 Self {
23 offset,
24 shape,
25 strides,
26 element_size,
27 }
28 }
29
30 pub fn transpose(&self, axes: &[usize]) -> Result<StridedView> {
32 if axes.len() != self.shape.len() {
33 return Err(TensorError::invalid_argument(
34 "Transpose axes must match tensor dimensions".to_string(),
35 ));
36 }
37
38 let mut new_shape = Vec::new();
39 let mut new_strides = Vec::new();
40
41 for &axis in axes {
42 if axis >= self.shape.len() {
43 return Err(TensorError::invalid_argument(format!(
44 "Axis {} out of bounds for tensor with {} dimensions",
45 axis,
46 self.shape.len()
47 )));
48 }
49 new_shape.push(self.shape[axis]);
50 new_strides.push(self.strides[axis]);
51 }
52
53 Ok(StridedView {
54 offset: self.offset,
55 shape: new_shape,
56 strides: new_strides,
57 element_size: self.element_size,
58 })
59 }
60
61 pub fn reshape(&self, new_shape: &[usize]) -> Result<StridedView> {
63 let total_elements: usize = self.shape.iter().product();
65 let new_total_elements: usize = new_shape.iter().product();
66
67 if total_elements != new_total_elements {
68 return Err(TensorError::invalid_argument(
69 "Cannot reshape tensor: element count mismatch".to_string(),
70 ));
71 }
72
73 if self.is_contiguous() {
75 let new_strides = compute_strides(new_shape, self.element_size);
77 Ok(StridedView {
78 offset: self.offset,
79 shape: new_shape.to_vec(),
80 strides: new_strides,
81 element_size: self.element_size,
82 })
83 } else {
84 Err(TensorError::unsupported_operation_simple(
86 "Reshape requires data copy for non-contiguous tensor".to_string(),
87 ))
88 }
89 }
90
91 pub fn is_contiguous(&self) -> bool {
93 let expected_strides = compute_strides(&self.shape, self.element_size);
94 self.strides == expected_strides
95 }
96
97 pub fn size_bytes(&self) -> usize {
99 self.shape.iter().product::<usize>() * self.element_size
100 }
101
102 pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<StridedView> {
104 if ranges.len() != self.shape.len() {
105 return Err(TensorError::invalid_argument(
106 "Slice ranges must match tensor dimensions".to_string(),
107 ));
108 }
109
110 let mut new_offset = self.offset;
111 let mut new_shape = Vec::new();
112 let mut new_strides = Vec::new();
113
114 for (i, &(start, end)) in ranges.iter().enumerate() {
115 if start >= end || end > self.shape[i] {
116 return Err(TensorError::invalid_argument(format!(
117 "Invalid slice range [{}, {}) for dimension {} of size {}",
118 start, end, i, self.shape[i]
119 )));
120 }
121
122 new_offset += start * self.strides[i];
123 new_shape.push(end - start);
124 new_strides.push(self.strides[i]);
125 }
126
127 Ok(StridedView {
128 offset: new_offset,
129 shape: new_shape,
130 strides: new_strides,
131 element_size: self.element_size,
132 })
133 }
134}
135
/// Tracks the live views registered on each buffer so that overlapping
/// (aliasing) byte ranges can be detected before they cause data races.
#[derive(Debug)]
pub struct MemoryAliasDetector {
    /// Maps a buffer id to the `(offset, size)` byte ranges of every view
    /// currently registered on that buffer. Guarded by a mutex so the
    /// detector can be shared across threads via the inner `Arc`.
    #[allow(clippy::type_complexity)]
    active_views: Arc<Mutex<HashMap<usize, Vec<(usize, usize)>>>>,
}
142
143impl Default for MemoryAliasDetector {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl MemoryAliasDetector {
150 pub fn new() -> Self {
152 Self {
153 active_views: Arc::new(Mutex::new(HashMap::new())),
154 }
155 }
156
157 pub fn check_alias(&self, buffer_id: usize, offset: usize, size: usize) -> bool {
159 let active_views = self
160 .active_views
161 .lock()
162 .expect("lock should not be poisoned");
163
164 if let Some(views) = active_views.get(&buffer_id) {
165 for &(view_offset, view_size) in views {
166 let start1 = offset;
168 let end1 = offset + size;
169 let start2 = view_offset;
170 let end2 = view_offset + view_size;
171
172 if start1 < end2 && start2 < end1 {
174 return true; }
176 }
177 }
178
179 false
180 }
181
182 pub fn register_view(&self, buffer_id: usize, offset: usize, size: usize) {
184 let mut active_views = self
185 .active_views
186 .lock()
187 .expect("lock should not be poisoned");
188 active_views
189 .entry(buffer_id)
190 .or_default()
191 .push((offset, size));
192 }
193
194 pub fn unregister_view(&self, buffer_id: usize, offset: usize, size: usize) {
196 let mut active_views = self
197 .active_views
198 .lock()
199 .expect("lock should not be poisoned");
200 if let Some(views) = active_views.get_mut(&buffer_id) {
201 views.retain(|&(view_offset, view_size)| view_offset != offset || view_size != size);
202 if views.is_empty() {
203 active_views.remove(&buffer_id);
204 }
205 }
206 }
207
208 pub fn get_alias_info(
210 &self,
211 buffer_id: usize,
212 offset: usize,
213 size: usize,
214 ) -> Vec<(usize, usize, usize)> {
215 let active_views = self
216 .active_views
217 .lock()
218 .expect("lock should not be poisoned");
219 let mut aliases = Vec::new();
220
221 if let Some(views) = active_views.get(&buffer_id) {
222 for &(view_offset, view_size) in views {
223 let start1 = offset;
224 let end1 = offset + size;
225 let start2 = view_offset;
226 let end2 = view_offset + view_size;
227
228 if start1 < end2 && start2 < end1 {
230 let overlap_start = std::cmp::max(start1, start2);
231 let overlap_end = std::cmp::min(end1, end2);
232 let overlap_size = overlap_end - overlap_start;
233 aliases.push((overlap_start, overlap_size, view_size));
234 }
235 }
236 }
237
238 aliases
239 }
240
241 pub fn check_partial_alias(&self, buffer_id: usize, offset: usize, size: usize) -> bool {
243 let active_views = self
244 .active_views
245 .lock()
246 .expect("lock should not be poisoned");
247
248 if let Some(views) = active_views.get(&buffer_id) {
249 for &(view_offset, view_size) in views {
250 let start1 = offset;
251 let end1 = offset + size;
252 let start2 = view_offset;
253 let end2 = view_offset + view_size;
254
255 if start1 < end2 && start2 < end1 {
257 let not_contained_in_existing = !(start1 >= start2 && end1 <= end2);
259 let not_containing_existing = !(start2 >= start1 && end2 <= end1);
260
261 if not_contained_in_existing && not_containing_existing {
263 return true;
264 }
265 }
266 }
267 }
268
269 false
270 }
271
272 pub fn get_alias_statistics(&self) -> (usize, usize) {
274 let active_views = self
275 .active_views
276 .lock()
277 .expect("lock should not be poisoned");
278 let total_buffers = active_views.len();
279 let total_views: usize = active_views.values().map(|v| v.len()).sum();
280 (total_buffers, total_views)
281 }
282}
283
/// Computes row-major (C-order) **byte** strides for `shape`: the last
/// dimension is contiguous with stride `element_size`, and each earlier
/// dimension strides over the full extent of everything after it.
/// An empty shape yields an empty stride vector.
pub fn compute_strides(shape: &[usize], element_size: usize) -> Vec<usize> {
    let mut strides = vec![0usize; shape.len()];
    let mut running = element_size;

    // Walk both slices back-to-front, filling strides in place.
    for (slot, &dim) in strides.iter_mut().zip(shape.iter()).rev() {
        *slot = running;
        running *= dim;
    }

    strides
}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strided_view_transpose() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        let transposed = view
            .transpose(&[2, 0, 1])
            .expect("test: transpose should succeed");

        assert_eq!(transposed.shape, vec![4, 2, 3]);
        assert_eq!(transposed.strides, vec![4, 48, 16]);
    }

    #[test]
    fn test_strided_view_reshape() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        let reshaped = view.reshape(&[6, 4]).expect("test: reshape should succeed");

        assert_eq!(reshaped.shape, vec![6, 4]);
        assert_eq!(reshaped.strides, vec![16, 4]);
    }

    #[test]
    fn test_strided_view_slice() {
        let view = StridedView::new(0, vec![4, 4], vec![16, 4], 4);
        let sliced = view
            .slice(&[(1, 3), (0, 2)])
            .expect("test: operation should succeed");

        assert_eq!(sliced.shape, vec![2, 2]);
        assert_eq!(sliced.strides, vec![16, 4]);
        // Offset advances by start * stride = 1 * 16 along dimension 0.
        assert_eq!(sliced.offset, 16);
    }

    #[test]
    fn test_memory_alias_detector() {
        let detector = MemoryAliasDetector::new();

        detector.register_view(0, 0, 100);

        assert!(detector.check_alias(0, 50, 100));
        // Ranges are half-open: [100, 150) does not touch [0, 100).
        assert!(!detector.check_alias(0, 100, 50));

        detector.unregister_view(0, 0, 100);
        assert!(!detector.check_alias(0, 50, 100));
    }

    #[test]
    fn test_compute_strides() {
        let strides = compute_strides(&[2, 3, 4], 4);
        assert_eq!(strides, vec![48, 16, 4]);
    }

    #[test]
    fn test_is_contiguous() {
        let contiguous_view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        assert!(contiguous_view.is_contiguous());

        let non_contiguous_view = StridedView::new(0, vec![2, 3, 4], vec![32, 16, 4], 4);
        assert!(!non_contiguous_view.is_contiguous());
    }

    #[test]
    fn test_size_bytes() {
        let view = StridedView::new(0, vec![2, 3, 4], vec![48, 16, 4], 4);
        assert_eq!(view.size_bytes(), 96);
    }

    #[test]
    fn test_invalid_transpose() {
        let view = StridedView::new(0, vec![2, 3], vec![12, 4], 4);

        // Wrong number of axes.
        let result = view.transpose(&[1, 0, 2]);
        assert!(result.is_err());

        // Axis out of bounds.
        let result = view.transpose(&[0, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_reshape() {
        let view = StridedView::new(0, vec![2, 3], vec![12, 4], 4);

        // Element count mismatch: 6 vs 8.
        let result = view.reshape(&[2, 4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_slice() {
        let view = StridedView::new(0, vec![4, 4], vec![16, 4], 4);

        // Too few ranges.
        let result = view.slice(&[(1, 3)]);
        assert!(result.is_err());

        // Empty range.
        let result = view.slice(&[(1, 1), (0, 2)]);
        assert!(result.is_err());

        // End past the dimension extent.
        let result = view.slice(&[(0, 5), (0, 2)]);
        assert!(result.is_err());
    }

    #[test]
    fn test_alias_detection_edge_cases() {
        let detector = MemoryAliasDetector::new();

        detector.register_view(0, 0, 100);
        assert!(!detector.check_alias(0, 100, 50));

        detector.register_view(1, 10, 80);
        assert!(detector.check_alias(1, 20, 30));
        assert!(detector.check_alias(1, 0, 100));
        // [50, 130) straddles the right edge of [10, 90): partial alias.
        assert!(detector.check_partial_alias(1, 50, 80));
        // [15, 65) is fully contained in [10, 90): not partial.
        assert!(!detector.check_partial_alias(1, 15, 50));
    }

    #[test]
    fn test_alias_info() {
        let detector = MemoryAliasDetector::new();
        detector.register_view(0, 10, 50);
        detector.register_view(0, 40, 30);

        let aliases = detector.get_alias_info(0, 35, 20);
        assert_eq!(aliases.len(), 2);
        // Overlap with the [10, 60) view is [35, 55).
        assert!(aliases
            .iter()
            .any(|&(start, size, _)| start == 35 && size == 20));
        // Overlap with the [40, 70) view is [40, 55).
        assert!(aliases
            .iter()
            .any(|&(start, size, _)| start == 40 && size == 15));
    }

    #[test]
    fn test_alias_statistics() {
        let detector = MemoryAliasDetector::new();

        let (buffers, views) = detector.get_alias_statistics();
        assert_eq!(buffers, 0);
        assert_eq!(views, 0);

        detector.register_view(0, 0, 100);
        detector.register_view(0, 100, 100);
        detector.register_view(1, 0, 50);

        let (buffers, views) = detector.get_alias_statistics();
        assert_eq!(buffers, 2);
        assert_eq!(views, 3);
    }
}