1use crate::Scalar;
7use crate::dtype::Float;
8use crate::error::{CoreError, Result};
9
10use super::{Tensor, compute_strides};
11
12#[derive(Debug, Clone)]
28pub struct NamedTensor<T: Scalar> {
29 tensor: Tensor<T>,
30 names: Vec<Option<String>>,
31}
32
33impl<T: Scalar> NamedTensor<T> {
34 pub fn new(tensor: Tensor<T>, names: Vec<Option<String>>) -> Result<Self> {
42 if names.len() != tensor.ndim() {
43 return Err(CoreError::InvalidArgument {
44 reason: "number of dimension names must match tensor rank",
45 });
46 }
47 let named: Vec<&str> = names.iter().filter_map(|n| n.as_deref()).collect();
49 let mut sorted = named.clone();
50 sorted.sort_unstable();
51 for window in sorted.windows(2) {
52 if window[0] == window[1] {
53 return Err(CoreError::InvalidArgument {
54 reason: "duplicate dimension names are not allowed",
55 });
56 }
57 }
58 Ok(Self { tensor, names })
59 }
60
61 pub fn from_tensor(tensor: Tensor<T>) -> Self {
63 let ndim = tensor.ndim();
64 Self {
65 tensor,
66 names: vec![None; ndim],
67 }
68 }
69
70 #[inline]
72 pub fn tensor(&self) -> &Tensor<T> {
73 &self.tensor
74 }
75
76 #[inline]
78 pub fn names(&self) -> &[Option<String>] {
79 &self.names
80 }
81
82 #[inline]
84 pub fn into_tensor(self) -> Tensor<T> {
85 self.tensor
86 }
87
88 pub fn dim_index(&self, name: &str) -> Result<usize> {
94 self.names
95 .iter()
96 .position(|n| n.as_deref() == Some(name))
97 .ok_or(CoreError::InvalidArgument {
98 reason: "dimension name not found",
99 })
100 }
101
102 pub fn rename(&mut self, old: &str, new: &str) -> Result<()> {
108 if self.names.iter().any(|n| n.as_deref() == Some(new)) {
110 return Err(CoreError::InvalidArgument {
111 reason: "new dimension name already exists",
112 });
113 }
114 let idx = self.dim_index(old)?;
115 self.names[idx] = Some(new.to_string());
116 Ok(())
117 }
118
119 pub fn set_names(&mut self, names: Vec<Option<String>>) -> Result<()> {
125 if names.len() != self.tensor.ndim() {
126 return Err(CoreError::InvalidArgument {
127 reason: "number of dimension names must match tensor rank",
128 });
129 }
130 self.names = names;
131 Ok(())
132 }
133
134 pub fn align_to(&self, target_names: &[&str]) -> Result<NamedTensor<T>> {
144 if target_names.len() != self.tensor.ndim() {
145 return Err(CoreError::InvalidArgument {
146 reason: "target names length must match tensor rank",
147 });
148 }
149
150 let perm: Vec<usize> = target_names
152 .iter()
153 .map(|name| self.dim_index(name))
154 .collect::<Result<Vec<_>>>()?;
155
156 let src_shape = self.tensor.shape();
157 let src_strides = self.tensor.strides();
158 let src_data = self.tensor.as_slice();
159
160 let new_shape: Vec<usize> = perm.iter().map(|&p| src_shape[p]).collect();
162 let new_names: Vec<Option<String>> = perm.iter().map(|&p| self.names[p].clone()).collect();
163 let new_strides = compute_strides(&new_shape);
164
165 let numel: usize = new_shape.iter().product();
166 let mut new_data = vec![T::zero(); numel];
167
168 for (out_flat, dest) in new_data.iter_mut().enumerate() {
170 let mut remaining = out_flat;
172 let mut src_flat = 0usize;
173 for (dim, &stride) in new_strides.iter().enumerate() {
174 let idx = remaining / stride;
175 remaining %= stride;
176 src_flat += idx * src_strides[perm[dim]];
177 }
178 *dest = src_data[src_flat];
179 }
180
181 let new_tensor = Tensor::from_vec(new_data, new_shape)?;
182 Ok(NamedTensor {
183 tensor: new_tensor,
184 names: new_names,
185 })
186 }
187
188 pub fn select(&self, name: &str, index: usize) -> Result<NamedTensor<T>> {
194 let axis = self.dim_index(name)?;
195 let shape = self.tensor.shape();
196 if index >= shape[axis] {
197 return Err(CoreError::IndexOutOfBounds {
198 index: vec![index],
199 shape: shape.to_vec(),
200 });
201 }
202
203 let ndim = shape.len();
204 let strides = self.tensor.strides();
205 let src_data = self.tensor.as_slice();
206
207 let new_shape: Vec<usize> = shape
209 .iter()
210 .enumerate()
211 .filter(|&(i, _)| i != axis)
212 .map(|(_, &s)| s)
213 .collect();
214 let new_names: Vec<Option<String>> = self
215 .names
216 .iter()
217 .enumerate()
218 .filter(|&(i, _)| i != axis)
219 .map(|(_, n)| n.clone())
220 .collect();
221
222 let numel: usize = new_shape.iter().product();
223 let new_strides = compute_strides(&new_shape);
224 let mut new_data = vec![T::zero(); numel];
225
226 let dim_map: Vec<usize> = (0..ndim).filter(|&d| d != axis).collect();
228
229 for (out_flat, dest) in new_data.iter_mut().enumerate() {
230 let mut remaining = out_flat;
231 let mut src_flat = index * strides[axis];
232 for (out_dim, &src_dim) in dim_map.iter().enumerate() {
233 let idx = if out_dim < new_strides.len() {
234 let i = remaining / new_strides[out_dim];
235 remaining %= new_strides[out_dim];
236 i
237 } else {
238 remaining
239 };
240 src_flat += idx * strides[src_dim];
241 }
242 *dest = src_data[src_flat];
243 }
244
245 let new_tensor = Tensor::from_vec(new_data, new_shape)?;
246 Ok(NamedTensor {
247 tensor: new_tensor,
248 names: new_names,
249 })
250 }
251}
252
253impl<T: Scalar + Float> NamedTensor<T> {
254 pub fn sum_dim(&self, name: &str) -> Result<NamedTensor<T>> {
260 let axis = self.dim_index(name)?;
261 let shape = self.tensor.shape();
262 let strides = self.tensor.strides();
263 let src_data = self.tensor.as_slice();
264 let ndim = shape.len();
265 let axis_len = shape[axis];
266
267 let new_shape: Vec<usize> = shape
268 .iter()
269 .enumerate()
270 .filter(|&(i, _)| i != axis)
271 .map(|(_, &s)| s)
272 .collect();
273 let new_names: Vec<Option<String>> = self
274 .names
275 .iter()
276 .enumerate()
277 .filter(|&(i, _)| i != axis)
278 .map(|(_, n)| n.clone())
279 .collect();
280
281 let numel: usize = new_shape.iter().product();
282 let new_strides = compute_strides(&new_shape);
283 let mut new_data = vec![T::zero(); numel];
284
285 let dim_map: Vec<usize> = (0..ndim).filter(|&d| d != axis).collect();
287
288 for (out_flat, dest) in new_data.iter_mut().enumerate() {
289 let mut remaining = out_flat;
290 let mut out_indices = vec![0usize; ndim];
292 for (out_dim, &src_dim) in dim_map.iter().enumerate() {
293 let idx = if out_dim < new_strides.len() {
294 let i = remaining / new_strides[out_dim];
295 remaining %= new_strides[out_dim];
296 i
297 } else {
298 remaining
299 };
300 out_indices[src_dim] = idx;
301 }
302
303 let mut acc = T::zero();
304 for k in 0..axis_len {
305 out_indices[axis] = k;
306 let src_flat: usize = out_indices
307 .iter()
308 .zip(strides.iter())
309 .map(|(&idx, &s)| idx * s)
310 .sum();
311 acc += src_data[src_flat];
312 }
313 *dest = acc;
314 }
315
316 let new_tensor = Tensor::from_vec(new_data, new_shape)?;
317 Ok(NamedTensor {
318 tensor: new_tensor,
319 names: new_names,
320 })
321 }
322
323 pub fn mean_dim(&self, name: &str) -> Result<NamedTensor<T>> {
329 let axis = self.dim_index(name)?;
330 let axis_len = self.tensor.shape()[axis];
331 let summed = self.sum_dim(name)?;
332 let divisor = T::from_usize(axis_len);
333 let result_tensor = summed.tensor.map(|x| x / divisor);
334 Ok(NamedTensor {
335 tensor: result_tensor,
336 names: summed.names,
337 })
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 2x3 tensor [[1, 2, 3], [4, 5, 6]] with dimensions "rows" and "cols".
    fn rows_cols() -> NamedTensor<f64> {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        NamedTensor::new(t, vec![Some("rows".into()), Some("cols".into())]).unwrap()
    }

    #[test]
    fn test_named_tensor_basic() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("batch".into()), Some("feature".into())]).unwrap();
        assert_eq!(nt.names().len(), 2);
        assert_eq!(nt.names()[0].as_deref(), Some("batch"));
        assert_eq!(nt.names()[1].as_deref(), Some("feature"));
        assert_eq!(nt.tensor().shape(), &[2, 3]);
    }

    #[test]
    fn test_new_rejects_duplicate_names() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let result = NamedTensor::new(t, vec![Some("a".into()), Some("a".into())]);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_allows_repeated_unnamed_dims() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let nt = NamedTensor::new(t, vec![None, None]).unwrap();
        assert!(nt.names().iter().all(Option::is_none));
    }

    #[test]
    fn test_from_tensor_is_unnamed() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let nt = NamedTensor::from_tensor(t);
        assert_eq!(nt.names().len(), 1);
        assert!(nt.names()[0].is_none());
        assert!(nt.dim_index("x").is_err());
        assert_eq!(nt.into_tensor().shape(), &[3]);
    }

    #[test]
    fn test_rename_dimension() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let mut nt = NamedTensor::new(t, vec![Some("rows".into()), Some("cols".into())]).unwrap();
        nt.rename("rows", "samples").unwrap();
        assert_eq!(nt.names()[0].as_deref(), Some("samples"));
        assert_eq!(nt.dim_index("samples").unwrap(), 0);
        assert!(nt.dim_index("rows").is_err());
    }

    #[test]
    fn test_rename_to_existing_name_fails() {
        let mut nt = rows_cols();
        assert!(nt.rename("rows", "cols").is_err());
        // A failed rename must leave the names untouched.
        assert_eq!(nt.names()[0].as_deref(), Some("rows"));
        assert_eq!(nt.names()[1].as_deref(), Some("cols"));
    }

    #[test]
    fn test_rename_missing_dimension_fails() {
        let mut nt = rows_cols();
        assert!(nt.rename("depth", "z").is_err());
    }

    #[test]
    fn test_align_to() {
        let numel = 2 * 3 * 4;
        let data: Vec<f64> = (0..numel).map(f64::from).collect();
        let t = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();
        let nt = NamedTensor::new(
            t.clone(),
            vec![
                Some("batch".into()),
                Some("channel".into()),
                Some("width".into()),
            ],
        )
        .unwrap();

        let aligned = nt.align_to(&["channel", "width", "batch"]).unwrap();
        assert_eq!(aligned.tensor().shape(), &[3, 4, 2]);
        assert_eq!(aligned.names()[0].as_deref(), Some("channel"));
        assert_eq!(aligned.names()[1].as_deref(), Some("width"));
        assert_eq!(aligned.names()[2].as_deref(), Some("batch"));

        // Element (batch=1, channel=2, width=3) moves to (channel=2, width=3, batch=1).
        let original_val = *t.get(&[1, 2, 3]).unwrap();
        let aligned_val = *aligned.tensor().get(&[2, 3, 1]).unwrap();
        assert!((original_val - aligned_val).abs() < 1e-15);
    }

    #[test]
    fn test_align_to_length_mismatch_fails() {
        let nt = rows_cols();
        assert!(nt.align_to(&["rows"]).is_err());
    }

    #[test]
    fn test_dim_index() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("time".into())]).unwrap();
        assert_eq!(nt.dim_index("time").unwrap(), 0);
        assert!(nt.dim_index("space").is_err());
    }

    #[test]
    fn test_sum_dim() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("rows".into()), Some("cols".into())]).unwrap();
        let summed = nt.sum_dim("rows").unwrap();
        assert_eq!(summed.tensor().shape(), &[3]);
        let data = summed.tensor().as_slice();
        assert!((data[0] - 5.0).abs() < 1e-15);
        assert!((data[1] - 7.0).abs() < 1e-15);
        assert!((data[2] - 9.0).abs() < 1e-15);
        assert_eq!(summed.names()[0].as_deref(), Some("cols"));
    }

    #[test]
    fn test_sum_dim_inner_axis() {
        let summed = rows_cols().sum_dim("cols").unwrap();
        assert_eq!(summed.tensor().shape(), &[2]);
        let data = summed.tensor().as_slice();
        assert!((data[0] - 6.0).abs() < 1e-12);
        assert!((data[1] - 15.0).abs() < 1e-12);
        assert_eq!(summed.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_mean_dim() {
        let mean = rows_cols().mean_dim("cols").unwrap();
        assert_eq!(mean.tensor().shape(), &[2]);
        let data = mean.tensor().as_slice();
        assert!((data[0] - 2.0).abs() < 1e-12);
        assert!((data[1] - 5.0).abs() < 1e-12);
        assert_eq!(mean.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_select() {
        let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let nt = NamedTensor::new(t, vec![Some("rows".into()), Some("cols".into())]).unwrap();
        let selected = nt.select("rows", 1).unwrap();
        assert_eq!(selected.tensor().shape(), &[3]);
        assert_eq!(selected.tensor().as_slice(), &[4.0, 5.0, 6.0]);
        assert_eq!(selected.names()[0].as_deref(), Some("cols"));
    }

    #[test]
    fn test_select_inner_axis() {
        let selected = rows_cols().select("cols", 2).unwrap();
        assert_eq!(selected.tensor().shape(), &[2]);
        assert_eq!(selected.tensor().as_slice(), &[3.0, 6.0]);
        assert_eq!(selected.names()[0].as_deref(), Some("rows"));
    }

    #[test]
    fn test_select_out_of_bounds() {
        let nt = rows_cols();
        assert!(matches!(
            nt.select("rows", 2),
            Err(CoreError::IndexOutOfBounds { .. })
        ));
        assert!(nt.select("depth", 0).is_err());
    }

    #[test]
    fn test_invalid_names_length() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let result = NamedTensor::new(t, vec![Some("a".into()), Some("b".into())]);
        assert!(result.is_err());
    }
}