// tensorlogic_infer/tensor_view.rs

//! Zero-copy tensor views and slicing operations.
//!
//! This module provides descriptors and traits for zero-copy tensor operations,
//! so that sub-regions of a tensor can be described and accessed without
//! duplicating the underlying data.

use std::ops::Range;

8/// Tensor view descriptor for zero-copy operations
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct TensorView {
11    /// Base tensor identifier
12    pub base_tensor_id: usize,
13    /// Slice specification for each dimension
14    pub slices: Vec<SliceSpec>,
15    /// Strides for each dimension (for strided access)
16    pub strides: Vec<isize>,
17    /// Offset from the base tensor
18    pub offset: usize,
19}
20
21impl TensorView {
22    /// Create a new tensor view
23    pub fn new(base_tensor_id: usize, slices: Vec<SliceSpec>) -> Self {
24        TensorView {
25            base_tensor_id,
26            slices,
27            strides: vec![],
28            offset: 0,
29        }
30    }
31
32    /// Create a full view of a tensor (no slicing)
33    pub fn full(base_tensor_id: usize, rank: usize) -> Self {
34        TensorView {
35            base_tensor_id,
36            slices: vec![SliceSpec::Full; rank],
37            strides: vec![],
38            offset: 0,
39        }
40    }
41
42    /// Create a view with specific offset and strides
43    pub fn with_strides(mut self, strides: Vec<isize>) -> Self {
44        self.strides = strides;
45        self
46    }
47
48    /// Create a view with specific offset
49    pub fn with_offset(mut self, offset: usize) -> Self {
50        self.offset = offset;
51        self
52    }
53
54    /// Check if this view represents a contiguous slice
55    pub fn is_contiguous(&self) -> bool {
56        self.slices
57            .iter()
58            .all(|s| matches!(s, SliceSpec::Full | SliceSpec::Range(_)))
59            && self.strides.is_empty()
60    }
61
62    /// Check if this view is a complete view (no slicing)
63    pub fn is_full_view(&self) -> bool {
64        self.slices.iter().all(|s| matches!(s, SliceSpec::Full)) && self.offset == 0
65    }
66
67    /// Get the rank of the view
68    pub fn rank(&self) -> usize {
69        self.slices.len()
70    }
71
72    /// Compose two views (create a view of a view)
73    pub fn compose(&self, other: &TensorView) -> Result<TensorView, String> {
74        if self.base_tensor_id != other.base_tensor_id {
75            return Err("Cannot compose views from different base tensors".to_string());
76        }
77
78        if self.rank() != other.rank() {
79            return Err(format!(
80                "Rank mismatch: {} vs {}",
81                self.rank(),
82                other.rank()
83            ));
84        }
85
86        // Compose slices
87        let mut composed_slices = Vec::new();
88        for (s1, s2) in self.slices.iter().zip(other.slices.iter()) {
89            composed_slices.push(s1.compose(s2)?);
90        }
91
92        // Compute composed offset
93        let composed_offset = self.offset + other.offset;
94
95        Ok(TensorView {
96            base_tensor_id: self.base_tensor_id,
97            slices: composed_slices,
98            strides: if other.strides.is_empty() {
99                self.strides.clone()
100            } else {
101                other.strides.clone()
102            },
103            offset: composed_offset,
104        })
105    }
106}
107
/// How a single dimension of a tensor is selected by a view.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SliceSpec {
    /// Keep the whole dimension unchanged.
    Full,
    /// Keep the half-open interval `[start, end)`.
    Range(Range<usize>),
    /// Select the single position at this index (reduces the dimension).
    Index(usize),
    /// Keep every `stride`-th position of `[start, end)`, beginning at `start`.
    Strided {
        /// First selected position.
        start: usize,
        /// Exclusive upper bound of the selection.
        end: usize,
        /// Distance between consecutive selected positions.
        stride: usize,
    },
    /// Keep the whole dimension, traversed back to front.
    Reverse,
}

127impl SliceSpec {
128    /// Create a range slice
129    pub fn range(start: usize, end: usize) -> Self {
130        SliceSpec::Range(start..end)
131    }
132
133    /// Create a strided slice
134    pub fn strided(start: usize, end: usize, stride: usize) -> Self {
135        SliceSpec::Strided { start, end, stride }
136    }
137
138    /// Get the size of this slice given the dimension size
139    pub fn size(&self, dim_size: usize) -> Result<usize, String> {
140        match self {
141            SliceSpec::Full => Ok(dim_size),
142            SliceSpec::Range(r) => {
143                if r.end > dim_size {
144                    Err(format!(
145                        "Range end {} exceeds dimension size {}",
146                        r.end, dim_size
147                    ))
148                } else if r.start >= r.end {
149                    Err(format!("Invalid range: {}..{}", r.start, r.end))
150                } else {
151                    Ok(r.end - r.start)
152                }
153            }
154            SliceSpec::Index(_) => Ok(1), // Single element
155            SliceSpec::Strided { start, end, stride } => {
156                if *end > dim_size {
157                    Err(format!(
158                        "Strided end {} exceeds dimension size {}",
159                        end, dim_size
160                    ))
161                } else if start >= end {
162                    Err(format!("Invalid strided range: {}..{}", start, end))
163                } else if *stride == 0 {
164                    Err("Stride cannot be zero".to_string())
165                } else {
166                    Ok((end - start).div_ceil(*stride))
167                }
168            }
169            SliceSpec::Reverse => Ok(dim_size),
170        }
171    }
172
173    /// Compose two slice specs
174    pub fn compose(&self, other: &SliceSpec) -> Result<SliceSpec, String> {
175        match (self, other) {
176            (SliceSpec::Full, s) => Ok(s.clone()),
177            (s, SliceSpec::Full) => Ok(s.clone()),
178            (SliceSpec::Range(r1), SliceSpec::Range(r2)) => {
179                let start = r1.start + r2.start;
180                let end = r1.start + r2.end;
181                if end > r1.end {
182                    Err(format!(
183                        "Composed range end {} exceeds first range end {}",
184                        end, r1.end
185                    ))
186                } else {
187                    Ok(SliceSpec::Range(start..end))
188                }
189            }
190            (SliceSpec::Range(r), SliceSpec::Index(i)) => {
191                if *i >= r.len() {
192                    Err(format!("Index {} out of range 0..{}", i, r.len()))
193                } else {
194                    Ok(SliceSpec::Index(r.start + i))
195                }
196            }
197            _ => Err("Cannot compose these slice types".to_string()),
198        }
199    }
200}
201
202/// Trait for tensors that support zero-copy views
203pub trait TensorViewable {
204    /// Create a view of this tensor
205    fn view(&self, slices: Vec<SliceSpec>) -> Result<TensorView, String>;
206
207    /// Create a slice of this tensor
208    fn slice(&self, ranges: &[Range<usize>]) -> Result<TensorView, String> {
209        let slices = ranges.iter().map(|r| SliceSpec::Range(r.clone())).collect();
210        self.view(slices)
211    }
212
213    /// Create a strided view
214    fn stride(&self, strides: Vec<isize>) -> Result<TensorView, String>;
215
216    /// Get a single element view
217    fn at(&self, indices: &[usize]) -> Result<TensorView, String> {
218        let slices = indices.iter().map(|&i| SliceSpec::Index(i)).collect();
219        self.view(slices)
220    }
221
222    /// Reshape view (if possible without copying)
223    fn reshape_view(&self, new_shape: Vec<usize>) -> Result<TensorView, String>;
224}
225
226/// View builder for ergonomic API
227pub struct ViewBuilder {
228    base_tensor_id: usize,
229    slices: Vec<SliceSpec>,
230    strides: Vec<isize>,
231    offset: usize,
232}
233
234impl ViewBuilder {
235    /// Create a new view builder
236    pub fn new(base_tensor_id: usize, rank: usize) -> Self {
237        ViewBuilder {
238            base_tensor_id,
239            slices: vec![SliceSpec::Full; rank],
240            strides: vec![],
241            offset: 0,
242        }
243    }
244
245    /// Set slice for a dimension
246    pub fn slice_dim(mut self, dim: usize, slice: SliceSpec) -> Self {
247        if dim < self.slices.len() {
248            self.slices[dim] = slice;
249        }
250        self
251    }
252
253    /// Set range for a dimension
254    pub fn range_dim(mut self, dim: usize, start: usize, end: usize) -> Self {
255        if dim < self.slices.len() {
256            self.slices[dim] = SliceSpec::Range(start..end);
257        }
258        self
259    }
260
261    /// Set index for a dimension
262    pub fn index_dim(mut self, dim: usize, index: usize) -> Self {
263        if dim < self.slices.len() {
264            self.slices[dim] = SliceSpec::Index(index);
265        }
266        self
267    }
268
269    /// Set strides
270    pub fn with_strides(mut self, strides: Vec<isize>) -> Self {
271        self.strides = strides;
272        self
273    }
274
275    /// Set offset
276    pub fn with_offset(mut self, offset: usize) -> Self {
277        self.offset = offset;
278        self
279    }
280
281    /// Build the tensor view
282    pub fn build(self) -> TensorView {
283        TensorView {
284            base_tensor_id: self.base_tensor_id,
285            slices: self.slices,
286            strides: self.strides,
287            offset: self.offset,
288        }
289    }
290}
291
/// Whether, and how, an operation may write its result in place.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InPlaceMode {
    /// In-place execution with no aliasing between output and inputs.
    Safe,
    /// In-place execution that may alias; the caller is responsible for correctness.
    Unsafe,
    /// Do not execute in place.
    None,
}

303/// Trait for in-place operations
304pub trait InPlaceOps {
305    type Error;
306
307    /// Check if in-place operation is safe
308    fn can_do_inplace(&self, output_view: &TensorView, input_views: &[TensorView]) -> bool;
309
310    /// Execute operation in-place if possible
311    fn execute_inplace(
312        &mut self,
313        output_view: &TensorView,
314        input_views: &[TensorView],
315        mode: InPlaceMode,
316    ) -> Result<(), Self::Error>;
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_view_creation() {
        let view = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Range(10..20)]);
        assert_eq!(view.base_tensor_id, 0);
        assert_eq!(view.rank(), 2);
        assert!(!view.is_full_view());
    }

    #[test]
    fn test_full_view() {
        let view = TensorView::full(0, 3);
        assert_eq!(view.rank(), 3);
        assert!(view.is_full_view());
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_slice_spec_size() {
        assert_eq!(SliceSpec::Full.size(100).unwrap(), 100);
        assert_eq!(SliceSpec::Range(10..20).size(100).unwrap(), 10);
        assert_eq!(SliceSpec::Index(5).size(100).unwrap(), 1);
        assert_eq!(
            SliceSpec::Strided {
                start: 0,
                end: 100,
                stride: 10
            }
            .size(100)
            .unwrap(),
            10
        );
    }

    #[test]
    fn test_slice_spec_compose() {
        let s1 = SliceSpec::Range(10..30);
        let s2 = SliceSpec::Range(5..15);
        let composed = s1.compose(&s2).unwrap();
        assert_eq!(composed, SliceSpec::Range(15..25));
    }

    #[test]
    fn test_view_compose() {
        let view1 = TensorView::new(0, vec![SliceSpec::Range(0..100), SliceSpec::Full]);
        let view2 = TensorView::new(0, vec![SliceSpec::Range(10..50), SliceSpec::Range(0..64)]);
        let composed = view1.compose(&view2).unwrap();
        assert_eq!(composed.base_tensor_id, 0);
        assert_eq!(composed.rank(), 2);
        assert_eq!(
            composed.slices,
            vec![SliceSpec::Range(10..50), SliceSpec::Range(0..64)]
        );
        assert_eq!(composed.offset, 0);
    }

    #[test]
    fn test_view_compose_errors() {
        let view = TensorView::full(0, 2);
        // Views of different base tensors cannot be composed.
        assert!(view.compose(&TensorView::full(1, 2)).is_err());
        // Neither can views of different rank.
        assert!(view.compose(&TensorView::full(0, 3)).is_err());
    }

    #[test]
    fn test_view_builder() {
        let view = ViewBuilder::new(0, 3)
            .range_dim(0, 10, 20)
            .index_dim(1, 5)
            .with_offset(100)
            .build();

        assert_eq!(view.base_tensor_id, 0);
        assert_eq!(view.offset, 100);
        assert_eq!(view.slices[0], SliceSpec::Range(10..20));
        assert_eq!(view.slices[1], SliceSpec::Index(5));
    }

    #[test]
    fn test_view_builder_ignores_out_of_range_dim() {
        let view = ViewBuilder::new(0, 2).range_dim(5, 0, 1).build();
        assert_eq!(view.slices, vec![SliceSpec::Full, SliceSpec::Full]);
    }

    #[test]
    fn test_contiguous_check() {
        let view1 = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Range(0..10)]);
        assert!(view1.is_contiguous());

        // Index slices are never treated as contiguous, even without explicit strides
        let view2 = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Index(3)]);
        assert!(!view2.is_contiguous());

        // View with explicit strides is not contiguous
        let view3 =
            TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Full]).with_strides(vec![128, 1]);
        assert!(!view3.is_contiguous());
    }

    #[test]
    fn test_strided_slice() {
        let spec = SliceSpec::strided(0, 100, 10);
        assert_eq!(spec.size(100).unwrap(), 10);

        let spec2 = SliceSpec::strided(5, 50, 5);
        assert_eq!(spec2.size(100).unwrap(), 9);
    }

    #[test]
    fn test_invalid_slices() {
        // Range exceeds dimension
        assert!(SliceSpec::Range(10..200).size(100).is_err());

        // Invalid range (intentionally reversed to test error handling)
        #[allow(clippy::reversed_empty_ranges)]
        {
            assert!(SliceSpec::Range(20..10).size(100).is_err());
        }

        // Zero stride
        assert!(SliceSpec::Strided {
            start: 0,
            end: 10,
            stride: 0
        }
        .size(100)
        .is_err());
    }

    #[test]
    fn test_view_with_strides() {
        let view = TensorView::new(0, vec![SliceSpec::Full, SliceSpec::Full])
            .with_strides(vec![128, 1])
            .with_offset(0);

        assert_eq!(view.strides, vec![128, 1]);
        assert!(!view.is_contiguous()); // Has explicit strides
    }
}