1use serde::{Deserialize, Serialize};
7use std::fmt;
8
/// Element type of the values described by a [`Schema`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. Treat the variant order as part of the
/// format: reordering would change the derived `Hash` output and the variant
/// indices serde assigns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DataType {
    /// 64-bit floating point; displayed as `f64`.
    Float64,
    /// 32-bit floating point; displayed as `f32`.
    Float32,
    /// 64-bit signed integer; displayed as `i64`.
    Int64,
    /// Boolean; displayed as `bool`.
    Bool,
    /// UTF-8 text; displayed as `str`.
    Utf8,
    /// Raw bytes; displayed as `bytes`. `Schema::bytes` pairs it with an unknown shape.
    Bytes,
    /// JSON data; displayed as `json`. `Schema::json` pairs it with an unknown shape.
    Json,
}
28
29impl fmt::Display for DataType {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::Float64 => write!(f, "f64"),
33 Self::Float32 => write!(f, "f32"),
34 Self::Int64 => write!(f, "i64"),
35 Self::Bool => write!(f, "bool"),
36 Self::Utf8 => write!(f, "str"),
37 Self::Bytes => write!(f, "bytes"),
38 Self::Json => write!(f, "json"),
39 }
40 }
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
51pub struct Schema {
52 pub dtype: DataType,
54
55 pub shape: Option<Vec<Dimension>>,
58}
59
/// Size of a single axis in a [`Schema`] shape.
///
/// Variant order is part of the derived `Hash` and serde encoding; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Dimension {
    /// Axis with an exact, known length.
    Fixed(usize),
    /// Axis whose length is not fixed, identified by a symbolic name
    /// (e.g. `"batch"` from `Schema::batched`). In `Schema::is_compatible_with`
    /// it matches any other dimension at the same position, whatever its name.
    Dynamic(String),
}
68
69impl fmt::Display for Dimension {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 match self {
72 Self::Fixed(n) => write!(f, "{n}"),
73 Self::Dynamic(name) => write!(f, "{name}"),
74 }
75 }
76}
77
78impl Schema {
79 pub fn vector(dtype: DataType, len: usize) -> Self {
81 Self {
82 dtype,
83 shape: Some(vec![Dimension::Fixed(len)]),
84 }
85 }
86
87 pub fn matrix(dtype: DataType, rows: usize, cols: usize) -> Self {
89 Self {
90 dtype,
91 shape: Some(vec![Dimension::Fixed(rows), Dimension::Fixed(cols)]),
92 }
93 }
94
95 pub fn batched(dtype: DataType, feature_dims: &[usize]) -> Self {
97 let mut dims = vec![Dimension::Dynamic("batch".into())];
98 dims.extend(feature_dims.iter().map(|&d| Dimension::Fixed(d)));
99 Self {
100 dtype,
101 shape: Some(dims),
102 }
103 }
104
105 pub fn scalar(dtype: DataType) -> Self {
107 Self {
108 dtype,
109 shape: Some(vec![]),
110 }
111 }
112
113 pub fn json() -> Self {
115 Self {
116 dtype: DataType::Json,
117 shape: None,
118 }
119 }
120
121 pub fn bytes() -> Self {
123 Self {
124 dtype: DataType::Bytes,
125 shape: None,
126 }
127 }
128
129 pub fn dynamic(dtype: DataType) -> Self {
131 Self { dtype, shape: None }
132 }
133
134 pub fn is_compatible_with(&self, other: &Schema) -> bool {
142 if self.dtype != other.dtype {
143 return false;
144 }
145
146 match (&self.shape, &other.shape) {
147 (None, _) | (_, None) => true, (Some(a), Some(b)) => {
149 if a.len() != b.len() {
150 return false;
151 }
152 a.iter().zip(b.iter()).all(|(da, db)| match (da, db) {
153 (Dimension::Fixed(x), Dimension::Fixed(y)) => x == y,
154 _ => true, })
156 }
157 }
158 }
159
160 pub fn rank(&self) -> Option<usize> {
162 self.shape.as_ref().map(|s| s.len())
163 }
164}
165
166impl fmt::Display for Schema {
167 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168 write!(f, "{}", self.dtype)?;
169 if let Some(shape) = &self.shape {
170 if shape.is_empty() {
171 write!(f, " (scalar)")?;
172 } else {
173 let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
174 write!(f, "[{}]", dims.join(", "))?;
175 }
176 }
177 Ok(())
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `f64` vector schema most tests compare against.
    fn f64_vector(len: usize) -> Schema {
        Schema::vector(DataType::Float64, len)
    }

    /// Asserts that `a` and `b` are compatible in both directions.
    fn assert_mutually_compatible(a: &Schema, b: &Schema) {
        assert!(a.is_compatible_with(b));
        assert!(b.is_compatible_with(a));
    }

    #[test]
    fn schema_display() {
        let cases = [
            (Schema::scalar(DataType::Float64), "f64 (scalar)"),
            (f64_vector(128), "f64[128]"),
            (Schema::matrix(DataType::Float64, 100, 50), "f64[100, 50]"),
            (Schema::batched(DataType::Float32, &[128]), "f32[batch, 128]"),
            (Schema::json(), "json"),
        ];
        for (schema, expected) in &cases {
            assert_eq!(schema.to_string(), *expected);
        }
    }

    #[test]
    fn compatible_same_schema() {
        let schema = f64_vector(128);
        assert!(schema.is_compatible_with(&schema));
    }

    #[test]
    fn compatible_dynamic_with_fixed() {
        let batched = Schema::batched(DataType::Float64, &[128]);
        let concrete = Schema::matrix(DataType::Float64, 32, 128);
        assert_mutually_compatible(&batched, &concrete);
    }

    #[test]
    fn compatible_unknown_shape() {
        let unknown = Schema::dynamic(DataType::Float64);
        assert_mutually_compatible(&unknown, &f64_vector(128));
    }

    #[test]
    fn incompatible_different_dtype() {
        let ints = Schema::vector(DataType::Int64, 128);
        assert!(!f64_vector(128).is_compatible_with(&ints));
    }

    #[test]
    fn incompatible_different_fixed_dims() {
        assert!(!f64_vector(128).is_compatible_with(&f64_vector(256)));
    }

    #[test]
    fn incompatible_different_rank() {
        let matrix = Schema::matrix(DataType::Float64, 128, 64);
        assert!(!f64_vector(128).is_compatible_with(&matrix));
    }

    #[test]
    fn json_compatible_with_json() {
        let json = Schema::json();
        assert!(json.is_compatible_with(&Schema::json()));
    }

    #[test]
    fn json_incompatible_with_tensor() {
        assert!(!Schema::json().is_compatible_with(&f64_vector(10)));
    }

    #[test]
    fn serde_roundtrip() {
        let schemas = [
            Schema::scalar(DataType::Float64),
            Schema::vector(DataType::Float32, 100),
            Schema::batched(DataType::Float64, &[128, 64]),
            Schema::json(),
            Schema::dynamic(DataType::Int64),
        ];
        for original in &schemas {
            let encoded = serde_json::to_string(original).expect("schema serializes to JSON");
            let decoded: Schema =
                serde_json::from_str(&encoded).expect("serialized schema deserializes");
            assert_eq!(original, &decoded);
        }
    }

    #[test]
    fn rank() {
        let expectations = [
            (Schema::scalar(DataType::Float64), Some(0)),
            (f64_vector(10), Some(1)),
            (Schema::matrix(DataType::Float64, 10, 5), Some(2)),
            (Schema::json(), None),
        ];
        for (schema, expected) in &expectations {
            assert_eq!(schema.rank(), *expected);
        }
    }
}