1use ndarray::{ArrayD, Axis, IxDyn};
5use std::fmt;
6
7#[derive(Clone)]
9pub struct TlTensor {
10 pub data: ArrayD<f64>,
11 pub name: Option<String>,
12}
13
14impl fmt::Debug for TlTensor {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 write!(f, "Tensor(shape={:?})", self.data.shape())
17 }
18}
19
20impl fmt::Display for TlTensor {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 let shape = self.data.shape();
23 if shape.len() == 1 && shape[0] <= 10 {
24 write!(f, "tensor([")?;
25 for (i, v) in self.data.iter().enumerate() {
26 if i > 0 {
27 write!(f, ", ")?;
28 }
29 if v.fract() == 0.0 {
30 write!(f, "{v:.1}")?;
31 } else {
32 write!(f, "{v}")?;
33 }
34 }
35 write!(f, "])")
36 } else {
37 write!(f, "tensor(shape={:?})", shape)
38 }
39 }
40}
41
42impl TlTensor {
43 pub fn zeros(shape: &[usize]) -> Self {
45 TlTensor {
46 data: ArrayD::zeros(IxDyn(shape)),
47 name: None,
48 }
49 }
50
51 pub fn ones(shape: &[usize]) -> Self {
53 TlTensor {
54 data: ArrayD::ones(IxDyn(shape)),
55 name: None,
56 }
57 }
58
59 pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, String> {
61 let expected: usize = shape.iter().product();
62 if data.len() != expected {
63 return Err(format!(
64 "Shape {:?} requires {} elements, got {}",
65 shape,
66 expected,
67 data.len()
68 ));
69 }
70 let arr = ArrayD::from_shape_vec(IxDyn(shape), data)
71 .map_err(|e| format!("Failed to create tensor: {e}"))?;
72 Ok(TlTensor {
73 data: arr,
74 name: None,
75 })
76 }
77
78 pub fn from_list(data: Vec<f64>) -> Self {
80 let len = data.len();
81 TlTensor {
82 data: ArrayD::from_shape_vec(IxDyn(&[len]), data).unwrap(),
83 name: None,
84 }
85 }
86
87 pub fn shape(&self) -> Vec<usize> {
89 self.data.shape().to_vec()
90 }
91
92 pub fn reshape(&self, new_shape: &[usize]) -> Result<Self, String> {
94 let new_data = self
95 .data
96 .clone()
97 .into_shape(IxDyn(new_shape))
98 .map_err(|e| format!("Reshape failed: {e}"))?;
99 Ok(TlTensor {
100 data: new_data,
101 name: self.name.clone(),
102 })
103 }
104
105 pub fn transpose(&self) -> Result<Self, String> {
107 if self.data.ndim() != 2 {
108 return Err(format!(
109 "Transpose requires 2D tensor, got {}D",
110 self.data.ndim()
111 ));
112 }
113 let transposed = self.data.clone().reversed_axes();
114 Ok(TlTensor {
115 data: transposed,
116 name: self.name.clone(),
117 })
118 }
119
120 pub fn flatten(&self) -> Self {
122 let flat: Vec<f64> = self.data.iter().cloned().collect();
123 TlTensor::from_list(flat)
124 }
125
126 pub fn sum(&self) -> f64 {
128 self.data.sum()
129 }
130
131 pub fn mean(&self) -> f64 {
133 let n = self.data.len() as f64;
134 if n == 0.0 { 0.0 } else { self.data.sum() / n }
135 }
136
137 pub fn min(&self) -> f64 {
139 self.data.iter().cloned().fold(f64::INFINITY, f64::min)
140 }
141
142 pub fn max(&self) -> f64 {
144 self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
145 }
146
147 pub fn get(&self, indices: &[usize]) -> Option<f64> {
149 self.data.get(IxDyn(indices)).cloned()
150 }
151
152 pub fn slice(&self, start: usize, end: usize) -> Result<Self, String> {
154 if self.data.ndim() == 0 {
155 return Err("Cannot slice a scalar tensor".to_string());
156 }
157 let sliced = self
158 .data
159 .slice_axis(Axis(0), ndarray::Slice::from(start..end));
160 Ok(TlTensor {
161 data: sliced.to_owned(),
162 name: self.name.clone(),
163 })
164 }
165
166 pub fn to_vec(&self) -> Vec<f64> {
168 self.data.iter().cloned().collect()
169 }
170
171 pub fn add(&self, other: &TlTensor) -> Result<Self, String> {
173 let result = &self.data + &other.data;
174 Ok(TlTensor {
175 data: result,
176 name: None,
177 })
178 }
179
180 pub fn sub(&self, other: &TlTensor) -> Result<Self, String> {
182 let result = &self.data - &other.data;
183 Ok(TlTensor {
184 data: result,
185 name: None,
186 })
187 }
188
189 pub fn mul(&self, other: &TlTensor) -> Result<Self, String> {
191 let result = &self.data * &other.data;
192 Ok(TlTensor {
193 data: result,
194 name: None,
195 })
196 }
197
198 pub fn div(&self, other: &TlTensor) -> Result<Self, String> {
200 let result = &self.data / &other.data;
201 Ok(TlTensor {
202 data: result,
203 name: None,
204 })
205 }
206
207 pub fn dot(&self, other: &TlTensor) -> Result<Self, String> {
209 if self.data.ndim() == 1 && other.data.ndim() == 1 {
211 let a = self.data.as_slice().ok_or("Non-contiguous tensor")?;
212 let b = other.data.as_slice().ok_or("Non-contiguous tensor")?;
213 if a.len() != b.len() {
214 return Err(format!(
215 "Dot product dimension mismatch: {} vs {}",
216 a.len(),
217 b.len()
218 ));
219 }
220 let result: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
221 Ok(TlTensor {
222 data: ArrayD::from_elem(IxDyn(&[]), result),
223 name: None,
224 })
225 }
226 else if self.data.ndim() == 2 && other.data.ndim() == 2 {
228 let a = self
229 .data
230 .view()
231 .into_dimensionality::<ndarray::Ix2>()
232 .map_err(|e| format!("Shape error: {e}"))?;
233 let b = other
234 .data
235 .view()
236 .into_dimensionality::<ndarray::Ix2>()
237 .map_err(|e| format!("Shape error: {e}"))?;
238 let c = a.dot(&b);
239 Ok(TlTensor {
240 data: c.into_dyn(),
241 name: None,
242 })
243 } else {
244 Err(format!(
245 "Dot product not supported for {}D and {}D tensors",
246 self.data.ndim(),
247 other.data.ndim()
248 ))
249 }
250 }
251
252 pub fn scale(&self, scalar: f64) -> Self {
254 TlTensor {
255 data: &self.data * scalar,
256 name: self.name.clone(),
257 }
258 }
259
260 pub fn cosine_similarity(&self, other: &TlTensor) -> Result<f64, String> {
262 let a = self.data.as_slice().ok_or("Non-contiguous tensor")?;
263 let b = other.data.as_slice().ok_or("Non-contiguous tensor")?;
264 if a.len() != b.len() {
265 return Err(format!(
266 "Dimension mismatch for cosine similarity: {} vs {}",
267 a.len(),
268 b.len()
269 ));
270 }
271 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
272 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
273 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
274 if norm_a == 0.0 || norm_b == 0.0 {
275 return Ok(0.0);
276 }
277 Ok(dot / (norm_a * norm_b))
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros_ones() {
        let z = TlTensor::zeros(&[2, 3]);
        assert_eq!(z.shape(), vec![2, 3]);
        assert_eq!(z.sum(), 0.0);

        let o = TlTensor::ones(&[2, 3]);
        assert_eq!(o.sum(), 6.0);
    }

    #[test]
    fn test_from_vec() {
        let t = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.shape(), vec![2, 2]);
        assert_eq!(t.get(&[0, 0]), Some(1.0));
        assert_eq!(t.get(&[1, 1]), Some(4.0));
    }

    #[test]
    fn test_from_vec_length_mismatch() {
        assert!(TlTensor::from_vec(vec![1.0, 2.0, 3.0], &[2, 2]).is_err());
    }

    #[test]
    fn test_from_list() {
        let t = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), vec![3]);
        assert_eq!(t.sum(), 6.0);
    }

    #[test]
    fn test_arithmetic() {
        let a = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let b = TlTensor::from_list(vec![4.0, 5.0, 6.0]);

        let sum = a.add(&b).unwrap();
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);

        let diff = a.sub(&b).unwrap();
        assert_eq!(diff.to_vec(), vec![-3.0, -3.0, -3.0]);

        let prod = a.mul(&b).unwrap();
        assert_eq!(prod.to_vec(), vec![4.0, 10.0, 18.0]);

        let quot = b.div(&a).unwrap();
        assert_eq!(quot.to_vec(), vec![4.0, 2.5, 2.0]);
    }

    #[test]
    fn test_reshape() {
        let t = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), vec![3, 2]);
        assert_eq!(r.get(&[0, 0]), Some(1.0));
    }

    #[test]
    fn test_reshape_incompatible() {
        let t = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        assert!(t.reshape(&[2, 2]).is_err());
    }

    #[test]
    fn test_transpose() {
        let t = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let tr = t.transpose().unwrap();
        assert_eq!(tr.shape(), vec![3, 2]);
        assert_eq!(tr.get(&[0, 0]), Some(1.0));
        assert_eq!(tr.get(&[0, 1]), Some(4.0));
        assert_eq!(tr.get(&[2, 1]), Some(6.0));
    }

    #[test]
    fn test_transpose_requires_2d() {
        assert!(TlTensor::from_list(vec![1.0, 2.0]).transpose().is_err());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let t = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.get(&[2, 0]), None);
    }

    #[test]
    fn test_dot_1d() {
        let a = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let b = TlTensor::from_list(vec![4.0, 5.0, 6.0]);
        let dot = a.dot(&b).unwrap();
        // 1-D · 1-D yields a 0-D tensor holding 1*4 + 2*5 + 3*6.
        assert_eq!(dot.shape(), Vec::<usize>::new());
        assert_eq!(dot.sum(), 32.0);
    }

    #[test]
    fn test_dot_1d_mismatch() {
        let a = TlTensor::from_list(vec![1.0, 2.0]);
        let b = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        assert!(a.dot(&b).is_err());
    }

    #[test]
    fn test_dot_2d() {
        let a = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = TlTensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
        let c = a.dot(&b).unwrap();
        assert_eq!(c.shape(), vec![2, 2]);
        assert_eq!(c.get(&[0, 0]), Some(19.0));
        assert_eq!(c.get(&[0, 1]), Some(22.0));
        assert_eq!(c.get(&[1, 0]), Some(43.0));
        assert_eq!(c.get(&[1, 1]), Some(50.0));
    }

    #[test]
    fn test_reductions() {
        let t = TlTensor::from_list(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(t.sum(), 15.0);
        assert_eq!(t.mean(), 3.0);
        assert_eq!(t.min(), 1.0);
        assert_eq!(t.max(), 5.0);
    }

    #[test]
    fn test_mean_empty() {
        assert_eq!(TlTensor::from_list(vec![]).mean(), 0.0);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = TlTensor::from_list(vec![1.0, 0.0]);
        let b = TlTensor::from_list(vec![1.0, 0.0]);
        let sim = a.cosine_similarity(&b).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);

        let c = TlTensor::from_list(vec![0.0, 1.0]);
        let sim2 = a.cosine_similarity(&c).unwrap();
        assert!(sim2.abs() < 1e-10);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = TlTensor::from_list(vec![0.0, 0.0]);
        let b = TlTensor::from_list(vec![1.0, 2.0]);
        assert_eq!(a.cosine_similarity(&b).unwrap(), 0.0);
    }

    #[test]
    fn test_scale() {
        let t = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let scaled = t.scale(2.0);
        assert_eq!(scaled.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let t = TlTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flat = t.flatten();
        assert_eq!(flat.shape(), vec![4]);
        assert_eq!(flat.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_slice() {
        let t = TlTensor::from_list(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let sliced = t.slice(1, 4).unwrap();
        assert_eq!(sliced.to_vec(), vec![20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_slice_scalar_errors() {
        let a = TlTensor::from_list(vec![1.0, 2.0]);
        let scalar = a.dot(&a).unwrap();
        assert!(scalar.slice(0, 1).is_err());
    }

    #[test]
    fn test_display() {
        let t = TlTensor::from_list(vec![1.0, 2.0, 3.0]);
        let s = format!("{t}");
        assert_eq!(s, "tensor([1.0, 2.0, 3.0])");
    }

    #[test]
    fn test_display_fractional_and_large() {
        assert_eq!(
            format!("{}", TlTensor::from_list(vec![1.5, 2.0])),
            "tensor([1.5, 2.0])"
        );
        assert_eq!(format!("{}", TlTensor::zeros(&[11])), "tensor(shape=[11])");
        assert_eq!(
            format!("{}", TlTensor::zeros(&[2, 3])),
            "tensor(shape=[2, 3])"
        );
    }
}