Skip to main content

tensorlogic_adapters/
axis.rs

1//! Axis metadata for variable-to-dimension mappings.
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5
/// Axis metadata linking variables to tensor dimensions.
///
/// [`AxisMetadata::assign`] populates all three maps together: a variable name
/// is mapped to an axis index, and that axis index is mapped both to the
/// domain name supplied at assignment time and to a single-character label.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AxisMetadata {
    /// Variable name → axis index, in insertion order.
    pub var_to_axis: IndexMap<String, usize>,
    /// Axis index → domain name recorded for that axis when it was assigned.
    pub axis_to_domain: IndexMap<usize, String>,
    /// Axis index → character label; `build_spec` concatenates these labels.
    pub axis_to_char: IndexMap<usize, char>,
}
13
14impl AxisMetadata {
15    pub fn new() -> Self {
16        AxisMetadata {
17            var_to_axis: IndexMap::new(),
18            axis_to_domain: IndexMap::new(),
19            axis_to_char: IndexMap::new(),
20        }
21    }
22
23    pub fn assign(&mut self, var: impl Into<String>, domain: impl Into<String>) -> usize {
24        let var = var.into();
25        let domain = domain.into();
26
27        if let Some(&axis) = self.var_to_axis.get(&var) {
28            return axis;
29        }
30
31        let axis = self.var_to_axis.len();
32        let axis_char = (b'a' + axis as u8) as char;
33
34        self.var_to_axis.insert(var, axis);
35        self.axis_to_domain.insert(axis, domain);
36        self.axis_to_char.insert(axis, axis_char);
37
38        axis
39    }
40
41    pub fn get_axis(&self, var: &str) -> Option<usize> {
42        self.var_to_axis.get(var).copied()
43    }
44
45    pub fn get_domain(&self, axis: usize) -> Option<&str> {
46        self.axis_to_domain.get(&axis).map(|s| s.as_str())
47    }
48
49    pub fn get_char(&self, axis: usize) -> Option<char> {
50        self.axis_to_char.get(&axis).copied()
51    }
52
53    pub fn build_spec(&self, vars: &[String]) -> String {
54        vars.iter()
55            .filter_map(|v| self.var_to_axis.get(v))
56            .filter_map(|&axis| self.axis_to_char.get(&axis))
57            .collect()
58    }
59}
60
61impl Default for AxisMetadata {
62    fn default() -> Self {
63        Self::new()
64    }
65}