Skip to main content

rust_tensors/
tensor.rs

1use crate::address_iterator::AddressIterator;
2use crate::adressable::Addressable;
3use std::ops::{Add, Index, IndexMut, Sub};
4
5pub trait Tensor<
6    T,
7    V: Copy + From<u8> + Add<Output = V> + Sub<Output = V> + PartialOrd,
8    A: Addressable<V, DIMENSION>,
9    const DIMENSION: usize,
10>: Index<A, Output = T> + IndexMut<A, Output = T>
11{
12    fn smallest_contained_address(&self) -> A;
13    fn largest_contained_address(&self) -> A;
14    /// Attempts to get a reference of the value at the given address. Will return Err if the address
15    /// is not contained in the matrix.
16    ///
17    /// # Arguments
18    ///
19    /// * `address`: The address of the value to be retrieved
20    ///
21    /// Returns: Result<&T, String>, the Result-wrapped reference to the value.
22    fn get(&self, address: A) -> Result<&T, String> {
23        if self.contains_address(address) {
24            Ok(&self[address])
25        } else {
26            Err(format!(
27                "Cannot retrieve value at {address:?}, index out of bounds."
28            ))
29        }
30    }
31    /// Attempts to get a mutable reference of the value at the given address. Will return Err if the
32    /// address is not contained in the matrix.
33    ///
34    /// # Arguments
35    ///
36    /// * `address`: The address of the value to be retrieved
37    ///
38    /// Returns: Result<&T, String>, the Result-wrapped reference to the value.
39    fn get_mut(&mut self, address: A) -> Result<&mut T, String> {
40        if self.contains_address(address) {
41            Ok(&mut self[address])
42        } else {
43            Err(format!(
44                "Cannot retrieve value at {address:?}, index out of bounds."
45            ))
46        }
47    }
48    /// Evaluates whether an address is valid and has an associated value in the tensor.
49    ///
50    /// # Arguments
51    ///
52    /// * `address`: The address to be evaluated
53    ///
54    /// Returns: bool, A boolean which is true if and only if the address is valid and has an
55    /// associated value.
56    fn contains_address(&self, address: A) -> bool {
57        !(0..DIMENSION).any(|d| {
58            address.get_value_at_dimension_index(d)
59                < self
60                    .smallest_contained_address()
61                    .get_value_at_dimension_index(d)
62                || address.get_value_at_dimension_index(d)
63                    > self
64                        .largest_contained_address()
65                        .get_value_at_dimension_index(d)
66        })
67    }
68    /// Creates an iterator over the addresses within the bounds of the tensor.
69    ///
70    /// The iterator will traverse all addresses starting from the smallest contained address
71    /// and ending at the largest contained address, inclusive.
72    ///
73    /// # Returns
74    ///
75    /// An instance of `AddressIterator<V, A, DIMENSION>`, initialized to iterate between
76    /// the smallest and largest contained addresses of the current object.
77    fn address_iter(&self) -> AddressIterator<V, A, DIMENSION> {
78        AddressIterator::<V, A, DIMENSION>::new(
79            self.smallest_contained_address().into(),
80            self.largest_contained_address().into(),
81        )
82    }
83}