shape_runtime/
type_mapping.rs1use crate::data::DataFrame;
13use shape_ast::error::{Result, ShapeError};
14use std::collections::HashMap;
15use std::sync::{Arc, RwLock};
16
/// Describes how a named data type maps its fields onto [`DataFrame`] columns.
///
/// A mapping carries the type's name, a field-name → column-name lookup table,
/// and the list of columns a [`DataFrame`] must contain for
/// [`TypeMapping::validate`] to succeed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeMapping {
    /// Name of the type this mapping describes (included in validation errors).
    pub type_name: String,

    /// Maps a field name of the type to the DataFrame column that backs it.
    pub field_to_column: HashMap<String, String>,

    /// Columns that must be present in a DataFrame for it to satisfy this type.
    pub required_columns: Vec<String>,
}
44
45impl TypeMapping {
46 pub fn new(type_name: String) -> Self {
48 Self {
49 type_name,
50 field_to_column: HashMap::new(),
51 required_columns: Vec::new(),
52 }
53 }
54
55 pub fn add_field(mut self, field: &str, column: &str) -> Self {
61 self.field_to_column
62 .insert(field.to_string(), column.to_string());
63 self
64 }
65
66 pub fn add_required(mut self, column: &str) -> Self {
68 self.required_columns.push(column.to_string());
69 self
70 }
71
72 pub fn with_mapping(mut self, custom: HashMap<String, String>) -> Self {
76 self.field_to_column.extend(custom);
77 self
78 }
79
80 pub fn validate(&self, df: &DataFrame) -> Result<()> {
97 let mut missing = Vec::new();
98
99 for col in &self.required_columns {
100 if !df.has_column(col) {
101 missing.push(col.clone());
102 }
103 }
104
105 if !missing.is_empty() {
106 return Err(ShapeError::RuntimeError {
107 message: format!(
108 "DataFrame missing required columns for type '{}': {}",
109 self.type_name,
110 missing.join(", ")
111 ),
112 location: None,
113 });
114 }
115
116 Ok(())
117 }
118
119 pub fn get_column(&self, field: &str) -> Option<&str> {
129 self.field_to_column.get(field).map(|s| s.as_str())
130 }
131}
132
133#[derive(Clone)]
138pub struct TypeMappingRegistry {
139 mappings: Arc<RwLock<HashMap<String, TypeMapping>>>,
141}
142
143impl TypeMappingRegistry {
144 pub fn new() -> Self {
149 Self {
150 mappings: Arc::new(RwLock::new(HashMap::new())),
151 }
152 }
153
154 pub fn register(&self, type_name: &str, mapping: TypeMapping) {
161 let mut mappings = self.mappings.write().unwrap();
162 mappings.insert(type_name.to_string(), mapping);
163 }
164
165 pub fn get(&self, type_name: &str) -> Option<TypeMapping> {
175 let mappings = self.mappings.read().unwrap();
176 mappings.get(type_name).cloned()
177 }
178
179 pub fn has(&self, type_name: &str) -> bool {
181 let mappings = self.mappings.read().unwrap();
182 mappings.contains_key(type_name)
183 }
184
185 pub fn list_types(&self) -> Vec<String> {
187 let mappings = self.mappings.read().unwrap();
188 mappings.keys().cloned().collect()
189 }
190
191 pub fn unregister(&self, type_name: &str) -> bool {
193 let mut mappings = self.mappings.write().unwrap();
194 mappings.remove(type_name).is_some()
195 }
196
197 pub fn clear(&self) {
199 let mut mappings = self.mappings.write().unwrap();
200 mappings.clear();
201 }
202}
203
204impl Default for TypeMappingRegistry {
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Timeframe;

    #[test]
    fn test_generic_mapping() {
        let mapping = TypeMapping::new("SensorData".to_string())
            .add_field("temp", "temperature")
            .add_field("humid", "humidity")
            .add_required("temperature")
            .add_required("humidity");

        assert_eq!(mapping.type_name, "SensorData");
        assert_eq!(mapping.required_columns.len(), 2);
        assert!(
            mapping
                .required_columns
                .contains(&"temperature".to_string())
        );
        assert!(mapping.required_columns.contains(&"humidity".to_string()));
        assert_eq!(mapping.get_column("temp"), Some("temperature"));
        // Fields without a mapping resolve to nothing.
        assert_eq!(mapping.get_column("pressure"), None);
    }

    #[test]
    fn test_validate_success() {
        let mapping = TypeMapping::new("Metrics".to_string())
            .add_required("value")
            .add_required("count");

        let mut df = DataFrame::new("TEST", Timeframe::d1());
        df.add_column("value", vec![100.0, 101.0]);
        df.add_column("count", vec![5.0, 6.0]);

        assert!(mapping.validate(&df).is_ok());
    }

    #[test]
    fn test_validate_missing_column() {
        let mapping = TypeMapping::new("DataPoint".to_string())
            .add_required("value")
            .add_required("count");

        let mut df = DataFrame::new("TEST", Timeframe::d1());
        df.add_column("value", vec![100.0]);

        // Only the absent column is reported, and the type name is included so
        // the caller can tell which mapping failed.
        match mapping.validate(&df) {
            Err(ShapeError::RuntimeError { message, .. }) => assert_eq!(
                message,
                "DataFrame missing required columns for type 'DataPoint': count"
            ),
            _ => panic!("expected a RuntimeError for the missing 'count' column"),
        }
    }

    #[test]
    fn test_custom_mapping() {
        let mapping = TypeMapping::new("CustomType".to_string())
            .add_field("price", "close")
            .add_field("size", "volume")
            .add_required("close")
            .add_required("volume");

        assert_eq!(mapping.type_name, "CustomType");
        assert_eq!(mapping.get_column("price"), Some("close"));
        assert_eq!(mapping.get_column("size"), Some("volume"));
        assert_eq!(mapping.required_columns.len(), 2);
    }

    #[test]
    fn test_with_mapping_overrides_fields() {
        let mut custom = HashMap::new();
        custom.insert("price".to_string(), "last".to_string());
        custom.insert("qty".to_string(), "volume".to_string());

        let mapping = TypeMapping::new("Trade".to_string())
            .add_field("price", "close")
            .with_mapping(custom);

        assert_eq!(mapping.get_column("price"), Some("last"));
        assert_eq!(mapping.get_column("qty"), Some("volume"));
    }

    #[test]
    fn test_registry_operations() {
        let registry = TypeMappingRegistry::new();

        assert!(!registry.has("Candle"));
        assert!(registry.get("Candle").is_none());
        assert_eq!(registry.list_types().len(), 0);

        let custom = TypeMapping::new("DataPoint".to_string()).add_required("value");
        registry.register("DataPoint", custom);

        assert!(registry.has("DataPoint"));
        assert_eq!(registry.list_types(), vec!["DataPoint".to_string()]);

        // `get` hands back a copy of the registered mapping.
        let fetched = registry
            .get("DataPoint")
            .expect("mapping was just registered");
        assert_eq!(fetched.type_name, "DataPoint");
        assert_eq!(fetched.required_columns, vec!["value".to_string()]);

        assert!(registry.unregister("DataPoint"));
        assert!(!registry.has("DataPoint"));
        // A second removal finds nothing to remove.
        assert!(!registry.unregister("DataPoint"));
    }

    #[test]
    fn test_registry_clear() {
        let registry = TypeMappingRegistry::new();
        registry.register("A", TypeMapping::new("A".to_string()));
        registry.register("B", TypeMapping::new("B".to_string()));
        assert_eq!(registry.list_types().len(), 2);

        registry.clear();

        assert!(registry.list_types().is_empty());
        assert!(!registry.has("A"));
        assert!(!registry.has("B"));
    }

    #[test]
    fn test_registry_clones_share_storage() {
        let registry = TypeMappingRegistry::new();
        let handle = registry.clone();

        handle.register("Shared", TypeMapping::new("Shared".to_string()));

        assert!(registry.has("Shared"));
    }
}