1use std::ops::Range;
7
8#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct TensorView {
11 pub base_tensor_id: usize,
13 pub slices: Vec<SliceSpec>,
15 pub strides: Vec<isize>,
17 pub offset: usize,
19}
20
21impl TensorView {
22 pub fn new(base_tensor_id: usize, slices: Vec<SliceSpec>) -> Self {
24 TensorView {
25 base_tensor_id,
26 slices,
27 strides: vec![],
28 offset: 0,
29 }
30 }
31
32 pub fn full(base_tensor_id: usize, rank: usize) -> Self {
34 TensorView {
35 base_tensor_id,
36 slices: vec![SliceSpec::Full; rank],
37 strides: vec![],
38 offset: 0,
39 }
40 }
41
42 pub fn with_strides(mut self, strides: Vec<isize>) -> Self {
44 self.strides = strides;
45 self
46 }
47
48 pub fn with_offset(mut self, offset: usize) -> Self {
50 self.offset = offset;
51 self
52 }
53
54 pub fn is_contiguous(&self) -> bool {
56 self.slices
57 .iter()
58 .all(|s| matches!(s, SliceSpec::Full | SliceSpec::Range(_)))
59 && self.strides.is_empty()
60 }
61
62 pub fn is_full_view(&self) -> bool {
64 self.slices.iter().all(|s| matches!(s, SliceSpec::Full)) && self.offset == 0
65 }
66
67 pub fn rank(&self) -> usize {
69 self.slices.len()
70 }
71
72 pub fn compose(&self, other: &TensorView) -> Result<TensorView, String> {
74 if self.base_tensor_id != other.base_tensor_id {
75 return Err("Cannot compose views from different base tensors".to_string());
76 }
77
78 if self.rank() != other.rank() {
79 return Err(format!(
80 "Rank mismatch: {} vs {}",
81 self.rank(),
82 other.rank()
83 ));
84 }
85
86 let mut composed_slices = Vec::new();
88 for (s1, s2) in self.slices.iter().zip(other.slices.iter()) {
89 composed_slices.push(s1.compose(s2)?);
90 }
91
92 let composed_offset = self.offset + other.offset;
94
95 Ok(TensorView {
96 base_tensor_id: self.base_tensor_id,
97 slices: composed_slices,
98 strides: if other.strides.is_empty() {
99 self.strides.clone()
100 } else {
101 other.strides.clone()
102 },
103 offset: composed_offset,
104 })
105 }
106}
107
/// How a single dimension is selected by a [`TensorView`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SliceSpec {
    /// The whole dimension.
    Full,
    /// The half-open index range `start..end`.
    Range(Range<usize>),
    /// A single position along the dimension.
    Index(usize),
    /// Every `stride`-th position of the half-open range `start..end`.
    Strided {
        /// First selected position (inclusive).
        start: usize,
        /// Exclusive upper bound.
        end: usize,
        /// Step between selected positions; must be non-zero.
        stride: usize,
    },
    /// The whole dimension. NOTE(review): presumably traversed in reverse
    /// order; this file only handles its size — confirm with consumers.
    Reverse,
}

impl SliceSpec {
    /// Shorthand for `SliceSpec::Range(start..end)`.
    pub fn range(start: usize, end: usize) -> Self {
        SliceSpec::Range(start..end)
    }

    /// Shorthand for `SliceSpec::Strided { start, end, stride }`.
    pub fn strided(start: usize, end: usize, stride: usize) -> Self {
        SliceSpec::Strided { start, end, stride }
    }

    /// Number of positions this spec selects from a dimension of length
    /// `dim_size`. An `Index` counts as one position.
    ///
    /// # Errors
    /// Returns `Err` when the spec falls outside `0..dim_size`, when a
    /// `Range`/`Strided` is empty or inverted, or when a stride is zero.
    pub fn size(&self, dim_size: usize) -> Result<usize, String> {
        match self {
            SliceSpec::Full | SliceSpec::Reverse => Ok(dim_size),
            SliceSpec::Range(r) => {
                if r.end > dim_size {
                    Err(format!(
                        "Range end {} exceeds dimension size {}",
                        r.end, dim_size
                    ))
                } else if r.start >= r.end {
                    Err(format!("Invalid range: {}..{}", r.start, r.end))
                } else {
                    Ok(r.end - r.start)
                }
            }
            // Bounds-checked like `Range`/`Strided`; previously any index
            // was accepted regardless of `dim_size`.
            SliceSpec::Index(i) => {
                if *i >= dim_size {
                    Err(format!(
                        "Index {} out of bounds for dimension size {}",
                        i, dim_size
                    ))
                } else {
                    Ok(1)
                }
            }
            SliceSpec::Strided { start, end, stride } => {
                if *end > dim_size {
                    Err(format!(
                        "Strided end {} exceeds dimension size {}",
                        end, dim_size
                    ))
                } else if start >= end {
                    Err(format!("Invalid strided range: {}..{}", start, end))
                } else if *stride == 0 {
                    Err("Stride cannot be zero".to_string())
                } else {
                    Ok((end - start).div_ceil(*stride))
                }
            }
        }
    }

    /// Applies `other` on top of `self`, interpreting `other`'s positions
    /// relative to the region `self` selects.
    ///
    /// Supported combinations: `Full` with anything (either side),
    /// `Range` then `Range`, and `Range` then `Index`.
    ///
    /// # Errors
    /// Returns `Err` for unsupported combinations, for an inverted inner
    /// range, when the result would fall outside `self`, or when the
    /// arithmetic would overflow `usize`.
    pub fn compose(&self, other: &SliceSpec) -> Result<SliceSpec, String> {
        match (self, other) {
            (SliceSpec::Full, s) => Ok(s.clone()),
            (s, SliceSpec::Full) => Ok(s.clone()),
            (SliceSpec::Range(outer), SliceSpec::Range(inner)) => {
                if inner.start > inner.end {
                    return Err(format!("Invalid range: {}..{}", inner.start, inner.end));
                }
                // Checked: a plain `+` panics in debug builds and wraps in release.
                match outer.start.checked_add(inner.end) {
                    // `outer.start + inner.start <= end`, so this sum cannot overflow.
                    Some(end) if end <= outer.end => {
                        Ok(SliceSpec::Range((outer.start + inner.start)..end))
                    }
                    Some(end) => Err(format!(
                        "Composed range end {} exceeds first range end {}",
                        end, outer.end
                    )),
                    None => Err(format!(
                        "Composed range end {} + {} overflows usize",
                        outer.start, inner.end
                    )),
                }
            }
            (SliceSpec::Range(r), SliceSpec::Index(i)) => {
                if *i >= r.len() {
                    Err(format!("Index {} out of range 0..{}", i, r.len()))
                } else {
                    // `i < r.len()` keeps `r.start + i < r.end`: no overflow.
                    Ok(SliceSpec::Index(r.start + i))
                }
            }
            _ => Err("Cannot compose these slice types".to_string()),
        }
    }
}
201
202pub trait TensorViewable {
204 fn view(&self, slices: Vec<SliceSpec>) -> Result<TensorView, String>;
206
207 fn slice(&self, ranges: &[Range<usize>]) -> Result<TensorView, String> {
209 let slices = ranges.iter().map(|r| SliceSpec::Range(r.clone())).collect();
210 self.view(slices)
211 }
212
213 fn stride(&self, strides: Vec<isize>) -> Result<TensorView, String>;
215
216 fn at(&self, indices: &[usize]) -> Result<TensorView, String> {
218 let slices = indices.iter().map(|&i| SliceSpec::Index(i)).collect();
219 self.view(slices)
220 }
221
222 fn reshape_view(&self, new_shape: Vec<usize>) -> Result<TensorView, String>;
224}
225
226pub struct ViewBuilder {
228 base_tensor_id: usize,
229 slices: Vec<SliceSpec>,
230 strides: Vec<isize>,
231 offset: usize,
232}
233
234impl ViewBuilder {
235 pub fn new(base_tensor_id: usize, rank: usize) -> Self {
237 ViewBuilder {
238 base_tensor_id,
239 slices: vec![SliceSpec::Full; rank],
240 strides: vec![],
241 offset: 0,
242 }
243 }
244
245 pub fn slice_dim(mut self, dim: usize, slice: SliceSpec) -> Self {
247 if dim < self.slices.len() {
248 self.slices[dim] = slice;
249 }
250 self
251 }
252
253 pub fn range_dim(mut self, dim: usize, start: usize, end: usize) -> Self {
255 if dim < self.slices.len() {
256 self.slices[dim] = SliceSpec::Range(start..end);
257 }
258 self
259 }
260
261 pub fn index_dim(mut self, dim: usize, index: usize) -> Self {
263 if dim < self.slices.len() {
264 self.slices[dim] = SliceSpec::Index(index);
265 }
266 self
267 }
268
269 pub fn with_strides(mut self, strides: Vec<isize>) -> Self {
271 self.strides = strides;
272 self
273 }
274
275 pub fn with_offset(mut self, offset: usize) -> Self {
277 self.offset = offset;
278 self
279 }
280
281 pub fn build(self) -> TensorView {
283 TensorView {
284 base_tensor_id: self.base_tensor_id,
285 slices: self.slices,
286 strides: self.strides,
287 offset: self.offset,
288 }
289 }
290}
291
/// Mode argument passed to [`InPlaceOps::execute_inplace`].
///
/// NOTE(review): the variants' exact semantics are not defined in this file;
/// the per-variant notes below are inferred from the names — confirm against
/// `InPlaceOps` implementors.
///
/// NOTE(review): the `None` variant shadows `Option::None` if this enum's
/// variants are glob-imported (`use InPlaceMode::*`); refer to it as
/// `InPlaceMode::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InPlaceMode {
    /// Presumably: run in place only when it is known to be safe.
    Safe,
    /// Presumably: run in place without safety checks.
    Unsafe,
    /// Presumably: do not run in place.
    None,
}
302
/// Interface for operations that may execute in place on [`TensorView`]s.
///
/// This file only declares the interface; all behavior is implementation-defined.
pub trait InPlaceOps {
    /// Error type returned by [`InPlaceOps::execute_inplace`].
    type Error;

    /// Reports whether the operation can run in place for `output_view`
    /// given `input_views`.
    ///
    /// NOTE(review): the criteria (presumably aliasing/overlap between the
    /// output and the inputs) are implementation-defined — confirm with implementors.
    fn can_do_inplace(&self, output_view: &TensorView, input_views: &[TensorView]) -> bool;

    /// Runs the operation on `input_views` with `output_view` as its output,
    /// under the given `mode`. Takes `&mut self`, so an implementation may
    /// mutate its own state.
    ///
    /// # Errors
    /// Returns `Self::Error` on failure; the conditions are implementation-defined.
    fn execute_inplace(
        &mut self,
        output_view: &TensorView,
        input_views: &[TensorView],
        mode: InPlaceMode,
    ) -> Result<(), Self::Error>;
}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_view_creation() {
        let view = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Range(10..20)]);
        assert_eq!(view.base_tensor_id, 0);
        assert_eq!(view.rank(), 2);
        assert!(!view.is_full_view());
    }

    #[test]
    fn test_full_view() {
        let view = TensorView::full(0, 3);
        assert_eq!(view.rank(), 3);
        assert!(view.is_full_view());
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_slice_spec_size() {
        assert_eq!(SliceSpec::Full.size(100).unwrap(), 100);
        assert_eq!(SliceSpec::Range(10..20).size(100).unwrap(), 10);
        assert_eq!(SliceSpec::Index(5).size(100).unwrap(), 1);
        assert_eq!(
            SliceSpec::Strided {
                start: 0,
                end: 100,
                stride: 10
            }
            .size(100)
            .unwrap(),
            10
        );
    }

    #[test]
    fn test_slice_spec_compose() {
        let s1 = SliceSpec::Range(10..30);
        let s2 = SliceSpec::Range(5..15);
        let composed = s1.compose(&s2).unwrap();
        assert_eq!(composed, SliceSpec::Range(15..25));
    }

    #[test]
    fn test_view_compose() {
        let view1 = TensorView::new(0, vec![SliceSpec::Range(0..100), SliceSpec::Full]);
        let view2 = TensorView::new(0, vec![SliceSpec::Range(10..50), SliceSpec::Range(0..64)]);
        let composed = view1.compose(&view2).unwrap();
        assert_eq!(composed.base_tensor_id, 0);
        assert_eq!(composed.rank(), 2);
        // Inner ranges are re-based onto the outer ranges; `Full` is transparent.
        assert_eq!(
            composed.slices,
            vec![SliceSpec::Range(10..50), SliceSpec::Range(0..64)]
        );
    }

    #[test]
    fn test_view_builder() {
        let view = ViewBuilder::new(0, 3)
            .range_dim(0, 10, 20)
            .index_dim(1, 5)
            .with_offset(100)
            .build();

        assert_eq!(view.base_tensor_id, 0);
        assert_eq!(view.offset, 100);
        assert_eq!(view.slices[0], SliceSpec::Range(10..20));
        assert_eq!(view.slices[1], SliceSpec::Index(5));
    }

    #[test]
    fn test_contiguous_check() {
        // Full and Range slices without explicit strides count as contiguous.
        let view1 = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Range(0..10)]);
        assert!(view1.is_contiguous());

        // Strided and Index slices are never treated as contiguous.
        let view2 = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::strided(0, 10, 2)]);
        assert!(!view2.is_contiguous());
        let view_index = TensorView::new(0, vec![SliceSpec::Index(3), SliceSpec::Full]);
        assert!(!view_index.is_contiguous());

        // Any explicit stride vector marks the view as non-contiguous.
        let view3 =
            TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Full]).with_strides(vec![128, 1]);
        assert!(!view3.is_contiguous());
    }

    #[test]
    fn test_strided_slice() {
        let spec = SliceSpec::strided(0, 100, 10);
        assert_eq!(spec.size(100).unwrap(), 10);

        let spec2 = SliceSpec::strided(5, 50, 5);
        assert_eq!(spec2.size(100).unwrap(), 9);
    }

    #[test]
    fn test_invalid_slices() {
        assert!(SliceSpec::Range(10..200).size(100).is_err());

        // The inverted range is intentional: it is the invalid input under test.
        #[allow(clippy::reversed_empty_ranges)]
        let inverted = SliceSpec::Range(20..10);
        assert!(inverted.size(100).is_err());

        assert!(SliceSpec::Strided {
            start: 0,
            end: 10,
            stride: 0
        }
        .size(100)
        .is_err());
    }

    #[test]
    fn test_view_with_strides() {
        let view = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Full])
            .with_strides(vec![128, 1])
            .with_offset(0);

        assert_eq!(view.strides, vec![128, 1]);
        assert!(!view.is_contiguous());
    }
}