Skip to main content

tensorlogic_infer/
typesafe.rs

//! Type-safe tensor wrappers with compile-time rank checking.
//!
//! This module provides strongly-typed tensor wrappers that encode
//! rank (and, for `ShapedTensor`, a per-dimension size) in the type system.
//! Concrete runtime shapes are validated when a tensor is constructed.

6use std::marker::PhantomData;
7
/// Type-level natural number used to encode tensor rank in the type system.
///
/// Implementors are zero-sized marker types ([`Z`] and [`S`]); the numeric
/// value is recovered at runtime through [`Nat::to_usize`].
pub trait Nat {
    /// Returns the natural number this type represents.
    fn to_usize() -> usize;
}

/// Type-level zero (rank 0).
// Derives match the sibling marker type `Dim` so markers are debuggable/copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Z;

impl Nat for Z {
    fn to_usize() -> usize {
        0
    }
}

/// Type-level successor: `S<N>` represents `N + 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct S<N: Nat>(PhantomData<N>);

impl<N: Nat> Nat for S<N> {
    fn to_usize() -> usize {
        // Peel one successor layer and count it.
        N::to_usize() + 1
    }
}

/// Type aliases for common ranks (`Dk` represents the number `k`).
pub type D1 = S<Z>;
pub type D2 = S<D1>;
pub type D3 = S<D2>;
pub type D4 = S<D3>;
pub type D5 = S<D4>;
pub type D6 = S<D5>;

/// Type-level description of a single dimension's extent.
pub trait DimSize {
    /// Returns the statically known extent, or `0` when the extent is only
    /// known at runtime (see [`Dyn`]).
    fn size() -> usize;
}

/// Dimension whose extent is only known at runtime.
///
/// [`DimSize::size`] returns `0` for this type as a placeholder, not as a real
/// extent; the actual size has to come from elsewhere (e.g. a runtime shape).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Dyn;

impl DimSize for Dyn {
    fn size() -> usize {
        0 // Placeholder: runtime determined
    }
}

/// Dimension whose extent `N` is fixed at compile time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Static<const N: usize>;

impl<const N: usize> DimSize for Static<N> {
    fn size() -> usize {
        N
    }
}

58/// Type-safe tensor with compile-time rank
59pub struct TypedTensor<T, R: Nat> {
60    inner: T,
61    shape: Vec<usize>,
62    _rank: PhantomData<R>,
63}
64
65impl<T, R: Nat> TypedTensor<T, R> {
66    /// Create a typed tensor with shape validation
67    pub fn new(inner: T, shape: Vec<usize>) -> Result<Self, String> {
68        if shape.len() != R::to_usize() {
69            return Err(format!(
70                "Shape length {} does not match rank {}",
71                shape.len(),
72                R::to_usize()
73            ));
74        }
75
76        Ok(TypedTensor {
77            inner,
78            shape,
79            _rank: PhantomData,
80        })
81    }
82
83    /// Create without validation (unsafe)
84    pub fn new_unchecked(inner: T, shape: Vec<usize>) -> Self {
85        TypedTensor {
86            inner,
87            shape,
88            _rank: PhantomData,
89        }
90    }
91
92    /// Get inner tensor
93    pub fn inner(&self) -> &T {
94        &self.inner
95    }
96
97    /// Get mutable inner tensor
98    pub fn inner_mut(&mut self) -> &mut T {
99        &mut self.inner
100    }
101
102    /// Consume and get inner tensor
103    pub fn into_inner(self) -> T {
104        self.inner
105    }
106
107    /// Get shape
108    pub fn shape(&self) -> &[usize] {
109        &self.shape
110    }
111
112    /// Get rank (compile-time known)
113    pub fn rank() -> usize {
114        R::to_usize()
115    }
116
117    /// Check if shape matches expected
118    pub fn validate_shape(&self, expected: &[usize]) -> bool {
119        self.shape == expected
120    }
121}
122
123/// Scalar (rank 0)
124pub type Scalar<T> = TypedTensor<T, Z>;
125
126/// Vector (rank 1)
127pub type Vector<T> = TypedTensor<T, D1>;
128
129/// Matrix (rank 2)
130pub type Matrix<T> = TypedTensor<T, D2>;
131
132/// 3D Tensor (rank 3)
133pub type Tensor3D<T> = TypedTensor<T, D3>;
134
135/// 4D Tensor (rank 4)
136pub type Tensor4D<T> = TypedTensor<T, D4>;
137
138/// Type-safe tensor with both rank and shape
139pub struct ShapedTensor<T, R: Nat, S: DimSize> {
140    inner: T,
141    _rank: PhantomData<R>,
142    _shape: PhantomData<S>,
143}
144
145impl<T, R: Nat, S: DimSize> ShapedTensor<T, R, S> {
146    pub fn new(inner: T) -> Self {
147        ShapedTensor {
148            inner,
149            _rank: PhantomData,
150            _shape: PhantomData,
151        }
152    }
153
154    pub fn inner(&self) -> &T {
155        &self.inner
156    }
157
158    pub fn inner_mut(&mut self) -> &mut T {
159        &mut self.inner
160    }
161
162    pub fn into_inner(self) -> T {
163        self.inner
164    }
165
166    pub fn rank() -> usize {
167        R::to_usize()
168    }
169
170    pub fn size() -> usize {
171        S::size()
172    }
173}
174
175/// Trait for type-safe tensor operations
176pub trait TypedTensorOps<T, R: Nat> {
177    /// Element-wise addition (same shape)
178    fn add(&self, other: &TypedTensor<T, R>) -> TypedTensor<T, R>;
179
180    /// Element-wise multiplication (same shape)
181    fn mul(&self, other: &TypedTensor<T, R>) -> TypedTensor<T, R>;
182
183    /// Scalar multiplication
184    fn scale(&self, scalar: f64) -> TypedTensor<T, R>;
185}
186
187/// Matrix operations (rank 2 specific)
188pub trait MatrixOps<T> {
189    /// Matrix multiplication (M x N) * (N x K) -> (M x K)
190    fn matmul(&self, other: &Matrix<T>) -> Result<Matrix<T>, String>;
191
192    /// Transpose (M x N) -> (N x M)
193    fn transpose(&self) -> Matrix<T>;
194}
195
/// Einsum equation string tagged with phantom input/output types.
///
/// The equation is stored verbatim; it is not parsed or validated here.
pub struct EinsumSpec<Input, Output> {
    spec_string: String,
    _input: PhantomData<Input>,
    _output: PhantomData<Output>,
}

impl<Input, Output> EinsumSpec<Input, Output> {
    /// Wraps an einsum equation such as `"ij,jk->ik"`.
    pub fn new(spec: String) -> Self {
        let (_input, _output) = (PhantomData, PhantomData);
        Self {
            spec_string: spec,
            _input,
            _output,
        }
    }

    /// Borrows the stored equation string.
    pub fn spec(&self) -> &str {
        self.spec_string.as_str()
    }
}

/// Ordered list of input tensors, built fluently with [`TypedInputs::with`].
pub struct TypedInputs<T> {
    tensors: Vec<T>,
}

impl<T> TypedInputs<T> {
    /// Creates an empty input list.
    pub fn new() -> Self {
        Self { tensors: Vec::new() }
    }

    /// Appends `tensor` and hands back the grown list (builder style).
    pub fn with(self, tensor: T) -> Self {
        let Self { mut tensors } = self;
        tensors.push(tensor);
        Self { tensors }
    }

    /// Borrows the inputs in insertion order.
    pub fn tensors(&self) -> &[T] {
        self.tensors.as_slice()
    }

    /// Consumes the container, returning the inputs in insertion order.
    pub fn into_vec(self) -> Vec<T> {
        let Self { tensors } = self;
        tensors
    }
}

impl<T> Default for TypedInputs<T> {
    /// Same as [`TypedInputs::new`]: an empty list.
    fn default() -> Self {
        Self::new()
    }
}

/// Ordered collection of tensors produced by an execution.
pub struct TypedOutputs<T> {
    tensors: Vec<T>,
}

impl<T> TypedOutputs<T> {
    /// Wraps an already-collected list of outputs.
    pub fn new(tensors: Vec<T>) -> Self {
        Self { tensors }
    }

    /// Returns the output at `index`, or `None` when out of range.
    pub fn get(&self, index: usize) -> Option<&T> {
        self.tensors.as_slice().get(index)
    }

    /// Number of outputs held.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// `true` when no outputs are held.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consumes the container, returning the outputs in order.
    pub fn into_vec(self) -> Vec<T> {
        let Self { tensors } = self;
        tensors
    }
}

276/// Shape constraint for compile-time checking
277pub trait ShapeConstraint<R: Nat> {
278    fn check_shape(shape: &[usize]) -> bool;
279}
280
281/// Fixed shape constraint
282pub struct FixedShape<const N: usize>;
283
284impl<const N: usize, R: Nat> ShapeConstraint<R> for FixedShape<N> {
285    fn check_shape(shape: &[usize]) -> bool {
286        shape.len() == R::to_usize() && shape.iter().all(|&d| d == N)
287    }
288}
289
290/// Broadcasting-compatible shape constraint
291pub struct BroadcastShape;
292
293impl<R: Nat> ShapeConstraint<R> for BroadcastShape {
294    fn check_shape(shape: &[usize]) -> bool {
295        shape.len() == R::to_usize()
296    }
297}
298
299/// Type-safe batch of tensors
300pub struct TypedBatch<T, R: Nat> {
301    tensors: Vec<TypedTensor<T, R>>,
302}
303
304impl<T, R: Nat> TypedBatch<T, R> {
305    pub fn new() -> Self {
306        TypedBatch {
307            tensors: Vec::new(),
308        }
309    }
310
311    pub fn with(mut self, tensor: TypedTensor<T, R>) -> Self {
312        self.tensors.push(tensor);
313        self
314    }
315
316    pub fn len(&self) -> usize {
317        self.tensors.len()
318    }
319
320    pub fn is_empty(&self) -> bool {
321        self.tensors.is_empty()
322    }
323
324    pub fn get(&self, index: usize) -> Option<&TypedTensor<T, R>> {
325        self.tensors.get(index)
326    }
327
328    pub fn iter(&self) -> impl Iterator<Item = &TypedTensor<T, R>> {
329        self.tensors.iter()
330    }
331}
332
333impl<T, R: Nat> Default for TypedBatch<T, R> {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
339/// Builder for type-safe tensor construction
340pub struct TensorBuilder<T> {
341    inner: Option<T>,
342    shape: Vec<usize>,
343}
344
345impl<T> TensorBuilder<T> {
346    pub fn new(inner: T) -> Self {
347        TensorBuilder {
348            inner: Some(inner),
349            shape: Vec::new(),
350        }
351    }
352
353    pub fn with_shape(mut self, shape: Vec<usize>) -> Self {
354        self.shape = shape;
355        self
356    }
357
358    pub fn build_scalar(self) -> Result<Scalar<T>, String> {
359        let inner = self.inner.ok_or("Missing inner tensor")?;
360        if !self.shape.is_empty() {
361            return Err("Scalar must have empty shape".to_string());
362        }
363        Scalar::new(inner, vec![])
364    }
365
366    pub fn build_vector(self) -> Result<Vector<T>, String> {
367        let inner = self.inner.ok_or("Missing inner tensor")?;
368        if self.shape.len() != 1 {
369            return Err("Vector must have rank 1".to_string());
370        }
371        Vector::new(inner, self.shape)
372    }
373
374    pub fn build_matrix(self) -> Result<Matrix<T>, String> {
375        let inner = self.inner.ok_or("Missing inner tensor")?;
376        if self.shape.len() != 2 {
377            return Err("Matrix must have rank 2".to_string());
378        }
379        Matrix::new(inner, self.shape)
380    }
381
382    pub fn build<R: Nat>(self) -> Result<TypedTensor<T, R>, String> {
383        let inner = self.inner.ok_or("Missing inner tensor")?;
384        TypedTensor::new(inner, self.shape)
385    }
386}
387
/// Compile-time dimension extent `N`, usable as a zero-sized value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Dim<const N: usize>;

impl<const N: usize> Dim<N> {
    /// Returns the extent `N`.
    pub const fn size() -> usize {
        N
    }

    /// Returns `true` if `actual` equals the extent `N`.
    ///
    /// `const` like [`Dim::size`], so it can be used in constant contexts.
    pub const fn matches(actual: usize) -> bool {
        actual == N
    }
}

/// Marker trait reserved for type-level dimension arithmetic.
///
/// It intentionally has no items: expressing the arithmetic itself would
/// require unstable const generics features.
pub trait DimOp {}

/// Placeholder standing for the product of dimensions `A` and `B`.
///
/// Carries no behavior yet; the `PhantomData` only records the operands.
pub struct DimMul<A, B>(PhantomData<(A, B)>);

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nat_types() {
        assert_eq!(Z::to_usize(), 0);
        assert_eq!(D1::to_usize(), 1);
        assert_eq!(D2::to_usize(), 2);
        assert_eq!(D3::to_usize(), 3);
        assert_eq!(D4::to_usize(), 4);
        assert_eq!(D5::to_usize(), 5);
        assert_eq!(D6::to_usize(), 6);
    }

    #[test]
    fn test_dim_size() {
        assert_eq!(Static::<10>::size(), 10);
        assert_eq!(Static::<256>::size(), 256);
        assert_eq!(Dyn::size(), 0);
    }

    #[test]
    fn test_typed_tensor_creation() {
        let tensor: Vector<f64> = TypedTensor::new(1.0, vec![10]).unwrap();
        assert_eq!(tensor.shape(), &[10]);
        assert_eq!(Vector::<f64>::rank(), 1);

        let matrix: Matrix<f64> = TypedTensor::new(2.0, vec![10, 20]).unwrap();
        assert_eq!(matrix.shape(), &[10, 20]);
        assert_eq!(Matrix::<f64>::rank(), 2);
    }

    #[test]
    fn test_typed_tensor_validation() {
        let result: Result<Vector<f64>, _> = TypedTensor::new(1.0, vec![10, 20]);
        assert!(result.is_err()); // Wrong rank

        let result: Result<Matrix<f64>, _> = TypedTensor::new(2.0, vec![10]);
        assert!(result.is_err()); // Wrong rank
    }

    #[test]
    fn test_typed_tensor_validation_message() {
        let result: Result<Matrix<f64>, String> = TypedTensor::new(0.0, vec![3]);
        // Use `err()` rather than `unwrap_err()` so no `Debug` bound on the tensor is needed.
        assert_eq!(
            result.err().as_deref(),
            Some("Shape length 1 does not match rank 2")
        );
    }

    #[test]
    fn test_new_unchecked_skips_validation() {
        let tensor: Vector<i32> = TypedTensor::new_unchecked(7, vec![1, 2, 3]);
        assert_eq!(tensor.shape(), &[1, 2, 3]);
        assert_eq!(*tensor.inner(), 7);
    }

    #[test]
    fn test_typed_tensor_inner() {
        let tensor: Vector<i32> = TypedTensor::new(42, vec![5]).unwrap();
        assert_eq!(*tensor.inner(), 42);

        let inner = tensor.into_inner();
        assert_eq!(inner, 42);
    }

    #[test]
    fn test_typed_tensor_inner_mut() {
        let mut tensor: Vector<i32> = TypedTensor::new(1, vec![4]).unwrap();
        *tensor.inner_mut() = 9;
        assert_eq!(tensor.into_inner(), 9);
    }

    #[test]
    fn test_shaped_tensor() {
        let tensor: ShapedTensor<f64, D2, Static<10>> = ShapedTensor::new(2.5);
        assert_eq!(ShapedTensor::<f64, D2, Static<10>>::rank(), 2);
        assert_eq!(ShapedTensor::<f64, D2, Static<10>>::size(), 10);
        assert_eq!(*tensor.inner(), 2.5);
    }

    #[test]
    fn test_shaped_tensor_dyn() {
        let mut tensor: ShapedTensor<i32, D1, Dyn> = ShapedTensor::new(1);
        *tensor.inner_mut() += 1;
        assert_eq!(ShapedTensor::<i32, D1, Dyn>::rank(), 1);
        assert_eq!(ShapedTensor::<i32, D1, Dyn>::size(), 0);
        assert_eq!(tensor.into_inner(), 2);
    }

    #[test]
    fn test_typed_inputs() {
        let inputs: TypedInputs<i32> = TypedInputs::new().with(1).with(2).with(3);

        assert_eq!(inputs.tensors().len(), 3);
        assert_eq!(inputs.tensors(), &[1, 2, 3]);
        assert_eq!(inputs.into_vec(), vec![1, 2, 3]);

        let empty: TypedInputs<i32> = TypedInputs::default();
        assert!(empty.tensors().is_empty());
    }

    #[test]
    fn test_typed_outputs() {
        let outputs: TypedOutputs<i32> = TypedOutputs::new(vec![1, 2, 3]);

        assert_eq!(outputs.len(), 3);
        assert!(!outputs.is_empty());
        assert_eq!(outputs.get(0), Some(&1));
        assert_eq!(outputs.get(1), Some(&2));
        assert_eq!(outputs.get(2), Some(&3));
        assert_eq!(outputs.get(3), None);
        assert_eq!(outputs.into_vec(), vec![1, 2, 3]);
    }

    #[test]
    fn test_typed_outputs_empty() {
        let outputs: TypedOutputs<i32> = TypedOutputs::new(Vec::new());
        assert!(outputs.is_empty());
        assert_eq!(outputs.len(), 0);
        assert_eq!(outputs.get(0), None);
    }

    #[test]
    fn test_einsum_spec() {
        let spec: EinsumSpec<(Matrix<f64>, Matrix<f64>), Matrix<f64>> =
            EinsumSpec::new("ij,jk->ik".to_string());
        assert_eq!(spec.spec(), "ij,jk->ik");
    }

    #[test]
    fn test_typed_batch() {
        let mut batch: TypedBatch<i32, D1> = TypedBatch::new();
        assert!(batch.is_empty());

        let tensor1: Vector<i32> = TypedTensor::new(1, vec![5]).unwrap();
        let tensor2: Vector<i32> = TypedTensor::new(2, vec![5]).unwrap();

        batch = batch.with(tensor1).with(tensor2);

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());

        let first = batch.get(0).unwrap();
        assert_eq!(*first.inner(), 1);
    }

    #[test]
    fn test_typed_batch_iter() {
        let first: Vector<i32> = TypedTensor::new(1, vec![2]).unwrap();
        let second: Vector<i32> = TypedTensor::new(2, vec![3]).unwrap();
        let batch: TypedBatch<i32, D1> = TypedBatch::default().with(first).with(second);

        let inners: Vec<i32> = batch.iter().map(|t| *t.inner()).collect();
        assert_eq!(inners, vec![1, 2]);
        assert!(batch.get(2).is_none());
    }

    #[test]
    fn test_tensor_builder() {
        let scalar: Scalar<f64> = TensorBuilder::new(2.5)
            .with_shape(vec![])
            .build_scalar()
            .unwrap();
        assert_eq!(*scalar.inner(), 2.5);

        let vector: Vector<f64> = TensorBuilder::new(2.71)
            .with_shape(vec![10])
            .build_vector()
            .unwrap();
        assert_eq!(vector.shape(), &[10]);

        let matrix: Matrix<f64> = TensorBuilder::new(1.41)
            .with_shape(vec![3, 4])
            .build_matrix()
            .unwrap();
        assert_eq!(matrix.shape(), &[3, 4]);
    }

    #[test]
    fn test_tensor_builder_errors() {
        let result = TensorBuilder::new(1.0).with_shape(vec![10]).build_scalar();
        assert!(result.is_err()); // Scalar can't have shape

        let result = TensorBuilder::new(1.0)
            .with_shape(vec![10, 20])
            .build_vector();
        assert!(result.is_err()); // Vector must be rank 1

        let result = TensorBuilder::new(1.0).with_shape(vec![10]).build_matrix();
        assert!(result.is_err()); // Matrix must be rank 2
    }

    #[test]
    fn test_tensor_builder_generic_build() {
        let tensor: Tensor3D<u8> = TensorBuilder::new(0u8)
            .with_shape(vec![1, 2, 3])
            .build::<D3>()
            .unwrap();
        assert_eq!(tensor.shape(), &[1, 2, 3]);

        let result = TensorBuilder::new(0u8).with_shape(vec![1]).build::<D3>();
        assert!(result.is_err()); // Rank mismatch
    }

    #[test]
    fn test_dim() {
        assert_eq!(Dim::<10>::size(), 10);
        assert_eq!(Dim::<256>::size(), 256);

        assert!(Dim::<10>::matches(10));
        assert!(!Dim::<10>::matches(20));
    }

    #[test]
    fn test_shape_validation() {
        let tensor: Vector<i32> = TypedTensor::new(42, vec![10]).unwrap();
        assert!(tensor.validate_shape(&[10]));
        assert!(!tensor.validate_shape(&[20]));
        assert!(!tensor.validate_shape(&[10, 10]));
    }

    #[test]
    fn test_shape_constraints() {
        assert!(<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3, 3]));
        assert!(!<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3, 4]));
        assert!(!<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3]));
        assert!(<FixedShape<3> as ShapeConstraint<Z>>::check_shape(&[]));

        assert!(<BroadcastShape as ShapeConstraint<D1>>::check_shape(&[42]));
        assert!(!<BroadcastShape as ShapeConstraint<D1>>::check_shape(&[1, 2]));
    }
}