Skip to main content

somatize_core/
schema.rs

//! Schema — dtype and shape descriptions used for compile-time type checking between filters.
//!
//! The pipeline compiler validates that connected filters have compatible schemas
//! before execution begins, catching shape/type mismatches early.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9/// Primitive data types that Soma values can contain.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[non_exhaustive]
12pub enum DataType {
13    /// 64-bit floating point.
14    Float64,
15    /// 32-bit floating point.
16    Float32,
17    /// 64-bit signed integer.
18    Int64,
19    /// Boolean.
20    Bool,
21    /// UTF-8 string.
22    Utf8,
23    /// Raw bytes.
24    Bytes,
25    /// Structured JSON (any shape).
26    Json,
27}
28
29impl fmt::Display for DataType {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::Float64 => write!(f, "f64"),
33            Self::Float32 => write!(f, "f32"),
34            Self::Int64 => write!(f, "i64"),
35            Self::Bool => write!(f, "bool"),
36            Self::Utf8 => write!(f, "str"),
37            Self::Bytes => write!(f, "bytes"),
38            Self::Json => write!(f, "json"),
39        }
40    }
41}
42
43/// Describes the shape and type of a Value, without holding the actual data.
44///
45/// Used by:
46/// - Filters: declare what they accept (input) and produce (output)
47/// - Compiler: validate type compatibility between connected filters
48/// - VirtualValue: know schema without materializing
49/// - Cache metadata: describe stored entries
50#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
51pub struct Schema {
52    /// The primitive data type.
53    pub dtype: DataType,
54
55    /// Shape dimensions. Empty for scalars, [n] for vectors, [r,c] for matrices, etc.
56    /// `None` means shape is dynamic/unknown.
57    pub shape: Option<Vec<Dimension>>,
58}
59
60/// A single dimension in a tensor shape.
61#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
62pub enum Dimension {
63    /// Fixed size (e.g., 128 features).
64    Fixed(usize),
65    /// Dynamic size (e.g., batch dimension). Named for documentation.
66    Dynamic(String),
67}
68
69impl fmt::Display for Dimension {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        match self {
72            Self::Fixed(n) => write!(f, "{n}"),
73            Self::Dynamic(name) => write!(f, "{name}"),
74        }
75    }
76}
77
78impl Schema {
79    /// Create a schema for a 1D tensor (vector) of known length.
80    pub fn vector(dtype: DataType, len: usize) -> Self {
81        Self {
82            dtype,
83            shape: Some(vec![Dimension::Fixed(len)]),
84        }
85    }
86
87    /// Create a schema for a 2D tensor (matrix) with known dimensions.
88    pub fn matrix(dtype: DataType, rows: usize, cols: usize) -> Self {
89        Self {
90            dtype,
91            shape: Some(vec![Dimension::Fixed(rows), Dimension::Fixed(cols)]),
92        }
93    }
94
95    /// Create a schema for a tensor with a dynamic batch dimension.
96    pub fn batched(dtype: DataType, feature_dims: &[usize]) -> Self {
97        let mut dims = vec![Dimension::Dynamic("batch".into())];
98        dims.extend(feature_dims.iter().map(|&d| Dimension::Fixed(d)));
99        Self {
100            dtype,
101            shape: Some(dims),
102        }
103    }
104
105    /// Create a schema for a scalar value.
106    pub fn scalar(dtype: DataType) -> Self {
107        Self {
108            dtype,
109            shape: Some(vec![]),
110        }
111    }
112
113    /// Create a schema for JSON data (shape is irrelevant).
114    pub fn json() -> Self {
115        Self {
116            dtype: DataType::Json,
117            shape: None,
118        }
119    }
120
121    /// Create a schema for raw bytes.
122    pub fn bytes() -> Self {
123        Self {
124            dtype: DataType::Bytes,
125            shape: None,
126        }
127    }
128
129    /// Create a schema with fully dynamic (unknown) shape.
130    pub fn dynamic(dtype: DataType) -> Self {
131        Self { dtype, shape: None }
132    }
133
134    /// Check if this schema is compatible with another (can be connected in a pipeline).
135    ///
136    /// Compatibility rules:
137    /// - Same dtype required (no implicit coercion)
138    /// - If both shapes are known, fixed dimensions must match
139    /// - Dynamic dimensions are compatible with any size
140    /// - Unknown shape (None) is compatible with anything of the same dtype
141    pub fn is_compatible_with(&self, other: &Schema) -> bool {
142        if self.dtype != other.dtype {
143            return false;
144        }
145
146        match (&self.shape, &other.shape) {
147            (None, _) | (_, None) => true, // unknown shape is flexible
148            (Some(a), Some(b)) => {
149                if a.len() != b.len() {
150                    return false;
151                }
152                a.iter().zip(b.iter()).all(|(da, db)| match (da, db) {
153                    (Dimension::Fixed(x), Dimension::Fixed(y)) => x == y,
154                    _ => true, // dynamic is compatible with anything
155                })
156            }
157        }
158    }
159
160    /// Number of known dimensions (rank).
161    pub fn rank(&self) -> Option<usize> {
162        self.shape.as_ref().map(|s| s.len())
163    }
164}
165
166impl fmt::Display for Schema {
167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168        write!(f, "{}", self.dtype)?;
169        if let Some(shape) = &self.shape {
170            if shape.is_empty() {
171                write!(f, " (scalar)")?;
172            } else {
173                let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
174                write!(f, "[{}]", dims.join(", "))?;
175            }
176        }
177        Ok(())
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn schema_display() {
        assert_eq!(
            Schema::scalar(DataType::Float64).to_string(),
            "f64 (scalar)"
        );
        assert_eq!(
            Schema::vector(DataType::Float64, 128).to_string(),
            "f64[128]"
        );
        assert_eq!(
            Schema::matrix(DataType::Float64, 100, 50).to_string(),
            "f64[100, 50]"
        );
        assert_eq!(
            Schema::batched(DataType::Float32, &[128]).to_string(),
            "f32[batch, 128]"
        );
        assert_eq!(Schema::json().to_string(), "json");
    }

    #[test]
    fn display_unknown_shape_shows_dtype_only() {
        assert_eq!(Schema::dynamic(DataType::Int64).to_string(), "i64");
        assert_eq!(Schema::bytes().to_string(), "bytes");
    }

    #[test]
    fn display_batched_without_features() {
        assert_eq!(
            Schema::batched(DataType::Bool, &[]).to_string(),
            "bool[batch]"
        );
    }

    #[test]
    fn dimension_display() {
        assert_eq!(Dimension::Fixed(42).to_string(), "42");
        assert_eq!(Dimension::Dynamic("seq".into()).to_string(), "seq");
    }

    #[test]
    fn compatible_same_schema() {
        let s = Schema::vector(DataType::Float64, 128);
        assert!(s.is_compatible_with(&s));
    }

    #[test]
    fn compatible_scalars() {
        let s = Schema::scalar(DataType::Bool);
        assert!(s.is_compatible_with(&Schema::scalar(DataType::Bool)));
    }

    #[test]
    fn compatible_dynamic_with_fixed() {
        let dynamic = Schema::batched(DataType::Float64, &[128]);
        let fixed = Schema::matrix(DataType::Float64, 32, 128);
        assert!(dynamic.is_compatible_with(&fixed));
        assert!(fixed.is_compatible_with(&dynamic));
    }

    #[test]
    fn compatible_differently_named_dynamic_dims() {
        // Dynamic labels are documentation only; they do not affect compatibility.
        let named = Schema {
            dtype: DataType::Float32,
            shape: Some(vec![Dimension::Dynamic("n".into()), Dimension::Fixed(4)]),
        };
        let batched = Schema::batched(DataType::Float32, &[4]);
        assert!(named.is_compatible_with(&batched));
        assert!(batched.is_compatible_with(&named));
    }

    #[test]
    fn compatible_unknown_shape() {
        let unknown = Schema::dynamic(DataType::Float64);
        let known = Schema::vector(DataType::Float64, 128);
        assert!(unknown.is_compatible_with(&known));
        assert!(known.is_compatible_with(&unknown));
    }

    #[test]
    fn unknown_shape_still_requires_same_dtype() {
        let unknown = Schema::dynamic(DataType::Float32);
        let known = Schema::vector(DataType::Float64, 8);
        assert!(!unknown.is_compatible_with(&known));
        assert!(!known.is_compatible_with(&unknown));
    }

    #[test]
    fn incompatible_different_dtype() {
        let f64_schema = Schema::vector(DataType::Float64, 128);
        let i64_schema = Schema::vector(DataType::Int64, 128);
        assert!(!f64_schema.is_compatible_with(&i64_schema));
    }

    #[test]
    fn incompatible_different_fixed_dims() {
        let a = Schema::vector(DataType::Float64, 128);
        let b = Schema::vector(DataType::Float64, 256);
        assert!(!a.is_compatible_with(&b));
    }

    #[test]
    fn incompatible_different_rank() {
        let vec = Schema::vector(DataType::Float64, 128);
        let mat = Schema::matrix(DataType::Float64, 128, 64);
        assert!(!vec.is_compatible_with(&mat));
    }

    #[test]
    fn incompatible_scalar_and_vector() {
        let scalar = Schema::scalar(DataType::Float64);
        let vector = Schema::vector(DataType::Float64, 1);
        assert!(!scalar.is_compatible_with(&vector));
        assert!(!vector.is_compatible_with(&scalar));
    }

    #[test]
    fn json_compatible_with_json() {
        assert!(Schema::json().is_compatible_with(&Schema::json()));
    }

    #[test]
    fn json_incompatible_with_tensor() {
        assert!(!Schema::json().is_compatible_with(&Schema::vector(DataType::Float64, 10)));
    }

    #[test]
    fn serde_roundtrip() {
        let schemas = vec![
            Schema::scalar(DataType::Float64),
            Schema::vector(DataType::Float32, 100),
            Schema::batched(DataType::Float64, &[128, 64]),
            Schema::json(),
            Schema::dynamic(DataType::Int64),
        ];
        for s in schemas {
            let json = serde_json::to_string(&s).unwrap();
            let deserialized: Schema = serde_json::from_str(&json).unwrap();
            assert_eq!(s, deserialized);
        }
    }

    #[test]
    fn rank() {
        assert_eq!(Schema::scalar(DataType::Float64).rank(), Some(0));
        assert_eq!(Schema::vector(DataType::Float64, 10).rank(), Some(1));
        assert_eq!(Schema::matrix(DataType::Float64, 10, 5).rank(), Some(2));
        assert_eq!(Schema::json().rank(), None);
    }

    #[test]
    fn rank_of_dynamic_and_batched() {
        assert_eq!(Schema::dynamic(DataType::Float64).rank(), None);
        assert_eq!(Schema::batched(DataType::Float64, &[3, 4]).rank(), Some(3));
    }
}