Skip to main content

tl_ai/
tensor.rs

1// ThinkingLanguage — Tensor type
2// Wraps ndarray::ArrayD<f64> for numerical computing.
3
4use ndarray::{ArrayD, Axis, IxDyn};
5use std::fmt;
6
7/// A dynamically-shaped tensor of f64 values.
8#[derive(Clone)]
9pub struct TlTensor {
10    pub data: ArrayD<f64>,
11    pub name: Option<String>,
12}
13
14impl fmt::Debug for TlTensor {
15    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16        write!(f, "Tensor(shape={:?})", self.data.shape())
17    }
18}
19
20impl fmt::Display for TlTensor {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        let shape = self.data.shape();
23        if shape.len() == 1 && shape[0] <= 10 {
24            write!(f, "tensor([")?;
25            for (i, v) in self.data.iter().enumerate() {
26                if i > 0 {
27                    write!(f, ", ")?;
28                }
29                if v.fract() == 0.0 {
30                    write!(f, "{v:.1}")?;
31                } else {
32                    write!(f, "{v}")?;
33                }
34            }
35            write!(f, "])")
36        } else {
37            write!(f, "tensor(shape={:?})", shape)
38        }
39    }
40}
41
42impl TlTensor {
43    /// Create a tensor filled with zeros.
44    pub fn zeros(shape: &[usize]) -> Self {
45        TlTensor {
46            data: ArrayD::zeros(IxDyn(shape)),
47            name: None,
48        }
49    }
50
51    /// Create a tensor filled with ones.
52    pub fn ones(shape: &[usize]) -> Self {
53        TlTensor {
54            data: ArrayD::ones(IxDyn(shape)),
55            name: None,
56        }
57    }
58
59    /// Create a tensor from a flat Vec and a shape.
60    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, String> {
61        let expected: usize = shape.iter().product();
62        if data.len() != expected {
63            return Err(format!(
64                "Shape {:?} requires {} elements, got {}",
65                shape,
66                expected,
67                data.len()
68            ));
69        }
70        let arr = ArrayD::from_shape_vec(IxDyn(shape), data)
71            .map_err(|e| format!("Failed to create tensor: {e}"))?;
72        Ok(TlTensor {
73            data: arr,
74            name: None,
75        })
76    }
77
78    /// Create a 1D tensor from a list of f64 values.
79    pub fn from_list(data: Vec<f64>) -> Self {
80        let len = data.len();
81        TlTensor {
82            data: ArrayD::from_shape_vec(IxDyn(&[len]), data).unwrap(),
83            name: None,
84        }
85    }
86
87    /// Get the shape as a Vec.
88    pub fn shape(&self) -> Vec<usize> {
89        self.data.shape().to_vec()
90    }
91
92    /// Reshape the tensor.
93    pub fn reshape(&self, new_shape: &[usize]) -> Result<Self, String> {
94        let new_data = self
95            .data
96            .clone()
97            .into_shape(IxDyn(new_shape))
98            .map_err(|e| format!("Reshape failed: {e}"))?;
99        Ok(TlTensor {
100            data: new_data,
101            name: self.name.clone(),
102        })
103    }
104
105    /// Transpose a 2D tensor.
106    pub fn transpose(&self) -> Result<Self, String> {
107        if self.data.ndim() != 2 {
108            return Err(format!(
109                "Transpose requires 2D tensor, got {}D",
110                self.data.ndim()
111            ));
112        }
113        let transposed = self.data.clone().reversed_axes();
114        Ok(TlTensor {
115            data: transposed,
116            name: self.name.clone(),
117        })
118    }
119
120    /// Flatten to 1D.
121    pub fn flatten(&self) -> Self {
122        let flat: Vec<f64> = self.data.iter().cloned().collect();
123        TlTensor::from_list(flat)
124    }
125
126    /// Sum of all elements.
127    pub fn sum(&self) -> f64 {
128        self.data.sum()
129    }
130
131    /// Mean of all elements.
132    pub fn mean(&self) -> f64 {
133        let n = self.data.len() as f64;
134        if n == 0.0 { 0.0 } else { self.data.sum() / n }
135    }
136
137    /// Minimum element.
138    pub fn min(&self) -> f64 {
139        self.data.iter().cloned().fold(f64::INFINITY, f64::min)
140    }
141
142    /// Maximum element.
143    pub fn max(&self) -> f64 {
144        self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
145    }
146
147    /// Get element by flat index for 1D tensors or multi-index.
148    pub fn get(&self, indices: &[usize]) -> Option<f64> {
149        self.data.get(IxDyn(indices)).cloned()
150    }
151
152    /// Slice along first axis.
153    pub fn slice(&self, start: usize, end: usize) -> Result<Self, String> {
154        if self.data.ndim() == 0 {
155            return Err("Cannot slice a scalar tensor".to_string());
156        }
157        let sliced = self
158            .data
159            .slice_axis(Axis(0), ndarray::Slice::from(start..end));
160        Ok(TlTensor {
161            data: sliced.to_owned(),
162            name: self.name.clone(),
163        })
164    }
165
166    /// Convert to a flat Vec<f64>.
167    pub fn to_vec(&self) -> Vec<f64> {
168        self.data.iter().cloned().collect()
169    }
170
171    /// Element-wise addition.
172    pub fn add(&self, other: &TlTensor) -> Result<Self, String> {
173        let result = &self.data + &other.data;
174        Ok(TlTensor {
175            data: result,
176            name: None,
177        })
178    }
179
180    /// Element-wise subtraction.
181    pub fn sub(&self, other: &TlTensor) -> Result<Self, String> {
182        let result = &self.data - &other.data;
183        Ok(TlTensor {
184            data: result,
185            name: None,
186        })
187    }
188
189    /// Element-wise multiplication.
190    pub fn mul(&self, other: &TlTensor) -> Result<Self, String> {
191        let result = &self.data * &other.data;
192        Ok(TlTensor {
193            data: result,
194            name: None,
195        })
196    }
197
198    /// Element-wise division.
199    pub fn div(&self, other: &TlTensor) -> Result<Self, String> {
200        let result = &self.data / &other.data;
201        Ok(TlTensor {
202            data: result,
203            name: None,
204        })
205    }
206
207    /// Matrix multiplication (dot product) for 1D or 2D tensors.
208    pub fn dot(&self, other: &TlTensor) -> Result<Self, String> {
209        // 1D dot 1D → scalar
210        if self.data.ndim() == 1 && other.data.ndim() == 1 {
211            let a = self.data.as_slice().ok_or("Non-contiguous tensor")?;
212            let b = other.data.as_slice().ok_or("Non-contiguous tensor")?;
213            if a.len() != b.len() {
214                return Err(format!(
215                    "Dot product dimension mismatch: {} vs {}",
216                    a.len(),
217                    b.len()
218                ));
219            }
220            let result: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
221            Ok(TlTensor {
222                data: ArrayD::from_elem(IxDyn(&[]), result),
223                name: None,
224            })
225        }
226        // 2D dot 2D → matrix multiply
227        else if self.data.ndim() == 2 && other.data.ndim() == 2 {
228            let a = self
229                .data
230                .view()
231                .into_dimensionality::<ndarray::Ix2>()
232                .map_err(|e| format!("Shape error: {e}"))?;
233            let b = other
234                .data
235                .view()
236                .into_dimensionality::<ndarray::Ix2>()
237                .map_err(|e| format!("Shape error: {e}"))?;
238            let c = a.dot(&b);
239            Ok(TlTensor {
240                data: c.into_dyn(),
241                name: None,
242            })
243        } else {
244            Err(format!(
245                "Dot product not supported for {}D and {}D tensors",
246                self.data.ndim(),
247                other.data.ndim()
248            ))
249        }
250    }
251
252    /// Scalar multiplication.
253    pub fn scale(&self, scalar: f64) -> Self {
254        TlTensor {
255            data: &self.data * scalar,
256            name: self.name.clone(),
257        }
258    }
259
260    /// Cosine similarity between two 1D tensors.
261    pub fn cosine_similarity(&self, other: &TlTensor) -> Result<f64, String> {
262        let a = self.data.as_slice().ok_or("Non-contiguous tensor")?;
263        let b = other.data.as_slice().ok_or("Non-contiguous tensor")?;
264        if a.len() != b.len() {
265            return Err(format!(
266                "Dimension mismatch for cosine similarity: {} vs {}",
267                a.len(),
268                b.len()
269            ));
270        }
271        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
272        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
273        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
274        if norm_a == 0.0 || norm_b == 0.0 {
275            return Ok(0.0);
276        }
277        Ok(dot / (norm_a * norm_b))
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tensor from row-major values and a shape, panicking on mismatch.
    fn tensor(values: &[f64], shape: &[usize]) -> TlTensor {
        TlTensor::from_vec(values.to_vec(), shape).unwrap()
    }

    #[test]
    fn test_zeros_ones() {
        let zeros = TlTensor::zeros(&[2, 3]);
        assert_eq!(zeros.shape(), [2, 3]);
        assert_eq!(zeros.sum(), 0.0);

        assert_eq!(TlTensor::ones(&[2, 3]).sum(), 6.0);
    }

    #[test]
    fn test_from_vec() {
        let grid = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(grid.shape(), [2, 2]);
        assert_eq!(grid.get(&[0, 0]), Some(1.0));
        assert_eq!(grid.get(&[1, 1]), Some(4.0));
    }

    #[test]
    fn test_from_list() {
        let list = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        assert_eq!(list.shape(), [3]);
        assert_eq!(list.sum(), 6.0);
    }

    #[test]
    fn test_arithmetic() {
        let lhs = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let rhs = TlTensor::from_list(vec![4.0, 5.0, 6.0]);

        assert_eq!(lhs.add(&rhs).unwrap().to_vec(), [5.0, 7.0, 9.0]);
        assert_eq!(lhs.sub(&rhs).unwrap().to_vec(), [-3.0, -3.0, -3.0]);
        assert_eq!(lhs.mul(&rhs).unwrap().to_vec(), [4.0, 10.0, 18.0]);
        // Note the operand order: rhs / lhs.
        assert_eq!(rhs.div(&lhs).unwrap().to_vec(), [4.0, 2.5, 2.0]);
    }

    #[test]
    fn test_reshape() {
        let source = tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let reshaped = source.reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), [3, 2]);
        assert_eq!(reshaped.get(&[0, 0]), Some(1.0));
    }

    #[test]
    fn test_transpose() {
        let source = tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let transposed = source.transpose().unwrap();
        assert_eq!(transposed.shape(), [3, 2]);
        assert_eq!(transposed.get(&[0, 0]), Some(1.0));
        assert_eq!(transposed.get(&[0, 1]), Some(4.0));
    }

    #[test]
    fn test_dot_1d() {
        let lhs = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let rhs = TlTensor::from_list(vec![4.0, 5.0, 6.0]);
        // The result is 0-dimensional; sum() extracts its single value.
        assert_eq!(lhs.dot(&rhs).unwrap().sum(), 32.0); // 1*4 + 2*5 + 3*6
    }

    #[test]
    fn test_dot_2d() {
        let lhs = tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let rhs = tensor(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let product = lhs.dot(&rhs).unwrap();
        assert_eq!(product.shape(), [2, 2]);
        assert_eq!(product.get(&[0, 0]), Some(19.0)); // 1*5 + 2*7
        assert_eq!(product.get(&[0, 1]), Some(22.0)); // 1*6 + 2*8
    }

    #[test]
    fn test_reductions() {
        let values = TlTensor::from_list(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(
            (values.sum(), values.mean(), values.min(), values.max()),
            (15.0, 3.0, 1.0, 5.0)
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = TlTensor::from_list(vec![1.0, 0.0]);
        let same_direction = TlTensor::from_list(vec![1.0, 0.0]);
        let y_axis = TlTensor::from_list(vec![0.0, 1.0]);

        let parallel = x_axis.cosine_similarity(&same_direction).unwrap();
        assert!((parallel - 1.0).abs() < 1e-10);

        let orthogonal = x_axis.cosine_similarity(&y_axis).unwrap();
        assert!(orthogonal.abs() < 1e-10);
    }

    #[test]
    fn test_scale() {
        let doubled = TlTensor::from_list(vec![1.0, 2.0, 3.0]).scale(2.0);
        assert_eq!(doubled.to_vec(), [2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let flat = tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).flatten();
        assert_eq!(flat.shape(), [4]);
        assert_eq!(flat.to_vec(), [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_slice() {
        let series = TlTensor::from_list(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let window = series.slice(1, 4).unwrap();
        assert_eq!(window.to_vec(), [20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_display() {
        let rendered = TlTensor::from_list(vec![1.0, 2.0, 3.0]).to_string();
        assert_eq!(rendered, "tensor([1.0, 2.0, 3.0])");
    }
}