1use std::marker::PhantomData;
7
/// A natural number encoded in a type.
///
/// Implementors carry their value purely at the type level (see [`Z`] and
/// [`S`]) and expose it at runtime through [`Nat::to_usize`].
pub trait Nat {
    /// Returns the natural number this type encodes.
    fn to_usize() -> usize;
}
12
13pub struct Z;
15impl Nat for Z {
16 fn to_usize() -> usize {
17 0
18 }
19}
20
21pub struct S<N: Nat>(PhantomData<N>);
23impl<N: Nat> Nat for S<N> {
24 fn to_usize() -> usize {
25 N::to_usize() + 1
26 }
27}
28
29pub type D1 = S<Z>;
31pub type D2 = S<D1>;
32pub type D3 = S<D2>;
33pub type D4 = S<D3>;
34pub type D5 = S<D4>;
35pub type D6 = S<D5>;
36
/// A dimension extent described at the type level.
pub trait DimSize {
    /// Returns the extent this type encodes.
    fn size() -> usize;
}
41
42pub struct Dyn;
44impl DimSize for Dyn {
45 fn size() -> usize {
46 0 }
48}
49
50pub struct Static<const N: usize>;
52impl<const N: usize> DimSize for Static<N> {
53 fn size() -> usize {
54 N
55 }
56}
57
58pub struct TypedTensor<T, R: Nat> {
60 inner: T,
61 shape: Vec<usize>,
62 _rank: PhantomData<R>,
63}
64
65impl<T, R: Nat> TypedTensor<T, R> {
66 pub fn new(inner: T, shape: Vec<usize>) -> Result<Self, String> {
68 if shape.len() != R::to_usize() {
69 return Err(format!(
70 "Shape length {} does not match rank {}",
71 shape.len(),
72 R::to_usize()
73 ));
74 }
75
76 Ok(TypedTensor {
77 inner,
78 shape,
79 _rank: PhantomData,
80 })
81 }
82
83 pub fn new_unchecked(inner: T, shape: Vec<usize>) -> Self {
85 TypedTensor {
86 inner,
87 shape,
88 _rank: PhantomData,
89 }
90 }
91
92 pub fn inner(&self) -> &T {
94 &self.inner
95 }
96
97 pub fn inner_mut(&mut self) -> &mut T {
99 &mut self.inner
100 }
101
102 pub fn into_inner(self) -> T {
104 self.inner
105 }
106
107 pub fn shape(&self) -> &[usize] {
109 &self.shape
110 }
111
112 pub fn rank() -> usize {
114 R::to_usize()
115 }
116
117 pub fn validate_shape(&self, expected: &[usize]) -> bool {
119 self.shape == expected
120 }
121}
122
123pub type Scalar<T> = TypedTensor<T, Z>;
125
126pub type Vector<T> = TypedTensor<T, D1>;
128
129pub type Matrix<T> = TypedTensor<T, D2>;
131
132pub type Tensor3D<T> = TypedTensor<T, D3>;
134
135pub type Tensor4D<T> = TypedTensor<T, D4>;
137
138pub struct ShapedTensor<T, R: Nat, S: DimSize> {
140 inner: T,
141 _rank: PhantomData<R>,
142 _shape: PhantomData<S>,
143}
144
145impl<T, R: Nat, S: DimSize> ShapedTensor<T, R, S> {
146 pub fn new(inner: T) -> Self {
147 ShapedTensor {
148 inner,
149 _rank: PhantomData,
150 _shape: PhantomData,
151 }
152 }
153
154 pub fn inner(&self) -> &T {
155 &self.inner
156 }
157
158 pub fn inner_mut(&mut self) -> &mut T {
159 &mut self.inner
160 }
161
162 pub fn into_inner(self) -> T {
163 self.inner
164 }
165
166 pub fn rank() -> usize {
167 R::to_usize()
168 }
169
170 pub fn size() -> usize {
171 S::size()
172 }
173}
174
175pub trait TypedTensorOps<T, R: Nat> {
177 fn add(&self, other: &TypedTensor<T, R>) -> TypedTensor<T, R>;
179
180 fn mul(&self, other: &TypedTensor<T, R>) -> TypedTensor<T, R>;
182
183 fn scale(&self, scalar: f64) -> TypedTensor<T, R>;
185}
186
187pub trait MatrixOps<T> {
189 fn matmul(&self, other: &Matrix<T>) -> Result<Matrix<T>, String>;
191
192 fn transpose(&self) -> Matrix<T>;
194}
195
/// An einsum specification string tagged with phantom input and output types.
///
/// The string is stored verbatim; it is neither parsed nor validated here.
pub struct EinsumSpec<Input, Output> {
    spec_string: String,
    _input: PhantomData<Input>,
    _output: PhantomData<Output>,
}

impl<Input, Output> EinsumSpec<Input, Output> {
    /// Stores `spec` unchanged.
    pub fn new(spec: String) -> Self {
        Self {
            _input: PhantomData,
            _output: PhantomData,
            spec_string: spec,
        }
    }

    /// Returns the stored specification string.
    pub fn spec(&self) -> &str {
        self.spec_string.as_str()
    }
}
216
/// An ordered collection of input tensors, assembled builder-style.
#[derive(Debug, Clone, PartialEq)]
pub struct TypedInputs<T> {
    tensors: Vec<T>,
}

impl<T> TypedInputs<T> {
    /// Creates an empty input set.
    pub fn new() -> Self {
        TypedInputs {
            tensors: Vec::new(),
        }
    }

    /// Appends `tensor` and returns the updated set.
    ///
    /// Takes `self` by value, so discarding the result drops every input.
    #[must_use]
    pub fn with(mut self, tensor: T) -> Self {
        self.tensors.push(tensor);
        self
    }

    /// Borrows the inputs in insertion order.
    pub fn tensors(&self) -> &[T] {
        &self.tensors
    }

    /// Returns the number of inputs. Mirrors `TypedOutputs::len`.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when no inputs have been added.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Consumes the set and returns the inputs in insertion order.
    pub fn into_vec(self) -> Vec<T> {
        self.tensors
    }
}

// Implemented by hand: `#[derive(Default)]` would needlessly require `T: Default`.
impl<T> Default for TypedInputs<T> {
    fn default() -> Self {
        Self::new()
    }
}
248
/// An ordered collection of output tensors.
#[derive(Debug, Clone, PartialEq)]
pub struct TypedOutputs<T> {
    tensors: Vec<T>,
}

impl<T> TypedOutputs<T> {
    /// Wraps `tensors`, keeping their order.
    pub fn new(tensors: Vec<T>) -> Self {
        TypedOutputs { tensors }
    }

    /// Returns the output at `index`, or `None` if it is out of range.
    pub fn get(&self, index: usize) -> Option<&T> {
        self.tensors.get(index)
    }

    /// Borrows all outputs in order. Mirrors `TypedInputs::tensors`.
    pub fn tensors(&self) -> &[T] {
        &self.tensors
    }

    /// Iterates over the outputs in order. Mirrors `TypedBatch::iter`.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.tensors.iter()
    }

    /// Returns the number of outputs.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when there are no outputs.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Consumes the collection and returns the outputs in order.
    pub fn into_vec(self) -> Vec<T> {
        self.tensors
    }
}
275
276pub trait ShapeConstraint<R: Nat> {
278 fn check_shape(shape: &[usize]) -> bool;
279}
280
281pub struct FixedShape<const N: usize>;
283
284impl<const N: usize, R: Nat> ShapeConstraint<R> for FixedShape<N> {
285 fn check_shape(shape: &[usize]) -> bool {
286 shape.len() == R::to_usize() && shape.iter().all(|&d| d == N)
287 }
288}
289
290pub struct BroadcastShape;
292
293impl<R: Nat> ShapeConstraint<R> for BroadcastShape {
294 fn check_shape(shape: &[usize]) -> bool {
295 shape.len() == R::to_usize()
296 }
297}
298
299pub struct TypedBatch<T, R: Nat> {
301 tensors: Vec<TypedTensor<T, R>>,
302}
303
304impl<T, R: Nat> TypedBatch<T, R> {
305 pub fn new() -> Self {
306 TypedBatch {
307 tensors: Vec::new(),
308 }
309 }
310
311 pub fn with(mut self, tensor: TypedTensor<T, R>) -> Self {
312 self.tensors.push(tensor);
313 self
314 }
315
316 pub fn len(&self) -> usize {
317 self.tensors.len()
318 }
319
320 pub fn is_empty(&self) -> bool {
321 self.tensors.is_empty()
322 }
323
324 pub fn get(&self, index: usize) -> Option<&TypedTensor<T, R>> {
325 self.tensors.get(index)
326 }
327
328 pub fn iter(&self) -> impl Iterator<Item = &TypedTensor<T, R>> {
329 self.tensors.iter()
330 }
331}
332
333impl<T, R: Nat> Default for TypedBatch<T, R> {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339pub struct TensorBuilder<T> {
341 inner: Option<T>,
342 shape: Vec<usize>,
343}
344
345impl<T> TensorBuilder<T> {
346 pub fn new(inner: T) -> Self {
347 TensorBuilder {
348 inner: Some(inner),
349 shape: Vec::new(),
350 }
351 }
352
353 pub fn with_shape(mut self, shape: Vec<usize>) -> Self {
354 self.shape = shape;
355 self
356 }
357
358 pub fn build_scalar(self) -> Result<Scalar<T>, String> {
359 let inner = self.inner.ok_or("Missing inner tensor")?;
360 if !self.shape.is_empty() {
361 return Err("Scalar must have empty shape".to_string());
362 }
363 Scalar::new(inner, vec![])
364 }
365
366 pub fn build_vector(self) -> Result<Vector<T>, String> {
367 let inner = self.inner.ok_or("Missing inner tensor")?;
368 if self.shape.len() != 1 {
369 return Err("Vector must have rank 1".to_string());
370 }
371 Vector::new(inner, self.shape)
372 }
373
374 pub fn build_matrix(self) -> Result<Matrix<T>, String> {
375 let inner = self.inner.ok_or("Missing inner tensor")?;
376 if self.shape.len() != 2 {
377 return Err("Matrix must have rank 2".to_string());
378 }
379 Matrix::new(inner, self.shape)
380 }
381
382 pub fn build<R: Nat>(self) -> Result<TypedTensor<T, R>, String> {
383 let inner = self.inner.ok_or("Missing inner tensor")?;
384 TypedTensor::new(inner, self.shape)
385 }
386}
387
/// A dimension whose extent `N` is a const generic parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Dim<const N: usize>;

impl<const N: usize> Dim<N> {
    /// Returns `N`. Usable in `const` contexts.
    pub const fn size() -> usize {
        N
    }

    /// Returns `true` when the runtime extent `actual` equals `N`.
    pub fn matches(actual: usize) -> bool {
        Self::size() == actual
    }
}
401
/// Marker trait for type-level dimension operations.
///
/// NOTE(review): it declares no items and nothing in this file implements it.
/// It is presumably meant for combinators such as `DimMul`; confirm the
/// intended contract before relying on it.
pub trait DimOp {}
407
/// Zero-sized type-level pairing of two dimension types `A` and `B`.
///
/// NOTE(review): the name suggests it represents `A * B`, but no operations
/// are defined on it in this file. Confirm the intent.
pub struct DimMul<A, B>(PhantomData<(A, B)>);
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nat_types() {
        assert_eq!(Z::to_usize(), 0);
        assert_eq!(D1::to_usize(), 1);
        assert_eq!(D2::to_usize(), 2);
        assert_eq!(D3::to_usize(), 3);
        assert_eq!(D4::to_usize(), 4);
        assert_eq!(D5::to_usize(), 5);
        assert_eq!(D6::to_usize(), 6);
    }

    #[test]
    fn test_dim_size() {
        assert_eq!(Static::<10>::size(), 10);
        assert_eq!(Static::<256>::size(), 256);
        // `Dyn` reports 0 as its placeholder extent.
        assert_eq!(Dyn::size(), 0);
    }

    #[test]
    fn test_typed_tensor_creation() {
        let tensor: Vector<f64> = TypedTensor::new(1.0, vec![10]).unwrap();
        assert_eq!(tensor.shape(), &[10]);
        assert_eq!(Vector::<f64>::rank(), 1);

        let matrix: Matrix<f64> = TypedTensor::new(2.0, vec![10, 20]).unwrap();
        assert_eq!(matrix.shape(), &[10, 20]);
        assert_eq!(Matrix::<f64>::rank(), 2);
    }

    #[test]
    fn test_typed_tensor_validation() {
        // Rank-1 type given a rank-2 shape.
        let result: Result<Vector<f64>, _> = TypedTensor::new(1.0, vec![10, 20]);
        assert!(result.is_err());
        assert_eq!(
            result.err().as_deref(),
            Some("Shape length 2 does not match rank 1")
        );

        // Rank-2 type given a rank-1 shape.
        let result: Result<Matrix<f64>, _> = TypedTensor::new(2.0, vec![10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_typed_tensor_inner() {
        let tensor: Vector<i32> = TypedTensor::new(42, vec![5]).unwrap();
        assert_eq!(*tensor.inner(), 42);

        let inner = tensor.into_inner();
        assert_eq!(inner, 42);
    }

    #[test]
    fn test_typed_tensor_inner_mut() {
        let mut tensor: Scalar<i32> = TypedTensor::new(1, vec![]).unwrap();
        *tensor.inner_mut() += 1;
        assert_eq!(*tensor.inner(), 2);
    }

    #[test]
    fn test_new_unchecked_skips_rank_check() {
        // A rank-2 type deliberately given a rank-1 shape.
        let unchecked: Matrix<u8> = TypedTensor::new_unchecked(1, vec![5]);
        assert_eq!(unchecked.shape(), &[5]);
    }

    #[test]
    fn test_shaped_tensor() {
        let tensor: ShapedTensor<f64, D2, Static<10>> = ShapedTensor::new(2.5);
        assert_eq!(ShapedTensor::<f64, D2, Static<10>>::rank(), 2);
        assert_eq!(ShapedTensor::<f64, D2, Static<10>>::size(), 10);
        assert_eq!(*tensor.inner(), 2.5);
    }

    #[test]
    fn test_typed_inputs() {
        let inputs: TypedInputs<i32> = TypedInputs::new().with(1).with(2).with(3);

        assert_eq!(inputs.tensors().len(), 3);
        assert_eq!(inputs.tensors(), &[1, 2, 3]);
    }

    #[test]
    fn test_typed_outputs() {
        let outputs: TypedOutputs<i32> = TypedOutputs::new(vec![1, 2, 3]);

        assert_eq!(outputs.len(), 3);
        assert!(!outputs.is_empty());
        assert_eq!(outputs.get(0), Some(&1));
        assert_eq!(outputs.get(1), Some(&2));
        assert_eq!(outputs.get(2), Some(&3));
        assert_eq!(outputs.get(3), None);
    }

    #[test]
    fn test_einsum_spec() {
        let spec: EinsumSpec<(Matrix<f64>, Matrix<f64>), Matrix<f64>> =
            EinsumSpec::new("ij,jk->ik".to_string());
        assert_eq!(spec.spec(), "ij,jk->ik");
    }

    #[test]
    fn test_typed_batch() {
        let mut batch: TypedBatch<i32, D1> = TypedBatch::new();
        assert!(batch.is_empty());

        let tensor1: Vector<i32> = TypedTensor::new(1, vec![5]).unwrap();
        let tensor2: Vector<i32> = TypedTensor::new(2, vec![5]).unwrap();

        batch = batch.with(tensor1).with(tensor2);

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());

        let first = batch.get(0).unwrap();
        assert_eq!(*first.inner(), 1);
    }

    #[test]
    fn test_typed_batch_iter() {
        let batch: TypedBatch<i32, D1> = TypedBatch::default()
            .with(TypedTensor::new(1, vec![2]).unwrap())
            .with(TypedTensor::new(2, vec![2]).unwrap());
        let values: Vec<i32> = batch.iter().map(|t| *t.inner()).collect();
        assert_eq!(values, vec![1, 2]);
        assert!(batch.get(2).is_none());
    }

    #[test]
    fn test_tensor_builder() {
        let scalar: Scalar<f64> = TensorBuilder::new(2.5)
            .with_shape(vec![])
            .build_scalar()
            .unwrap();
        assert_eq!(*scalar.inner(), 2.5);

        let vector: Vector<f64> = TensorBuilder::new(2.71)
            .with_shape(vec![10])
            .build_vector()
            .unwrap();
        assert_eq!(vector.shape(), &[10]);

        let matrix: Matrix<f64> = TensorBuilder::new(1.41)
            .with_shape(vec![3, 4])
            .build_matrix()
            .unwrap();
        assert_eq!(matrix.shape(), &[3, 4]);
    }

    #[test]
    fn test_tensor_builder_generic_build() {
        let built: Tensor3D<u8> = TensorBuilder::new(7)
            .with_shape(vec![1, 2, 3])
            .build()
            .unwrap();
        assert_eq!(built.shape(), &[1, 2, 3]);
        assert_eq!(Tensor3D::<u8>::rank(), 3);

        let err = TensorBuilder::new(7u8)
            .with_shape(vec![1])
            .build::<D3>()
            .err();
        assert_eq!(err.as_deref(), Some("Shape length 1 does not match rank 3"));
    }

    #[test]
    fn test_tensor_builder_errors() {
        // A scalar must have an empty shape.
        let result = TensorBuilder::new(1.0).with_shape(vec![10]).build_scalar();
        assert!(result.is_err());

        // A vector must have exactly one dimension.
        let result = TensorBuilder::new(1.0)
            .with_shape(vec![10, 20])
            .build_vector();
        assert!(result.is_err());

        // A matrix must have exactly two dimensions.
        let result = TensorBuilder::new(1.0).with_shape(vec![10]).build_matrix();
        assert!(result.is_err());
    }

    #[test]
    fn test_dim() {
        assert_eq!(Dim::<10>::size(), 10);
        assert_eq!(Dim::<256>::size(), 256);

        assert!(Dim::<10>::matches(10));
        assert!(!Dim::<10>::matches(20));
    }

    #[test]
    fn test_shape_validation() {
        let tensor: Vector<i32> = TypedTensor::new(42, vec![10]).unwrap();
        assert!(tensor.validate_shape(&[10]));
        assert!(!tensor.validate_shape(&[20]));
        assert!(!tensor.validate_shape(&[10, 10]));
    }

    #[test]
    fn test_fixed_shape_constraint() {
        assert!(<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3, 3]));
        assert!(!<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3, 4]));
        assert!(!<FixedShape<3> as ShapeConstraint<D2>>::check_shape(&[3]));
        // Rank 0: the empty shape passes vacuously.
        assert!(<FixedShape<3> as ShapeConstraint<Z>>::check_shape(&[]));
    }

    #[test]
    fn test_broadcast_shape_constraint() {
        assert!(<BroadcastShape as ShapeConstraint<D3>>::check_shape(&[1, 5, 7]));
        assert!(!<BroadcastShape as ShapeConstraint<D3>>::check_shape(&[1, 5]));
    }
}
563}