// scouter_dataframe/parquet/bifrost/catalog.rs
1use async_trait::async_trait;
2use dashmap::DashMap;
3use datafusion::catalog::{CatalogProvider, SchemaProvider};
4use datafusion::common::DataFusionError;
5use datafusion::datasource::TableProvider;
6use scouter_types::dataset::DatasetNamespace;
7use std::any::Any;
8use std::fmt::Debug;
9use std::sync::Arc;
10
/// Custom DataFusion catalog provider that maps the first level of the
/// three-level `catalog.schema.table` SQL namespace.
///
/// Wraps a `DashMap` of schema names → `DatasetSchemaProvider` instances.
/// Registered on the shared `query_ctx` so DataFusion can resolve table
/// references in SQL queries.
#[derive(Debug)]
pub struct DatasetCatalogProvider {
    // schema name → schema provider; DashMap gives lock-free-ish concurrent
    // access so schemas can be created and looked up without an outer lock
    schemas: DashMap<String, Arc<DatasetSchemaProvider>>,
}
21
22impl DatasetCatalogProvider {
23    pub fn new() -> Self {
24        Self {
25            schemas: DashMap::new(),
26        }
27    }
28
29    /// Get or create a schema provider for the given schema name.
30    pub fn get_or_create_schema(&self, schema_name: &str) -> Arc<DatasetSchemaProvider> {
31        self.schemas
32            .entry(schema_name.to_string())
33            .or_insert_with(|| Arc::new(DatasetSchemaProvider::new()))
34            .clone()
35    }
36
37    /// Atomically swap the `TableProvider` for a table after a Delta write.
38    /// In-flight queries that already obtained a `DataFrame` hold a reference
39    /// to the old snapshot and complete normally.
40    pub fn swap_table(&self, namespace: &DatasetNamespace, provider: Arc<dyn TableProvider>) {
41        let schema = self.get_or_create_schema(&namespace.schema_name);
42        schema.tables.insert(namespace.table.clone(), provider);
43    }
44
45    /// Remove a table from the catalog (used during TTL eviction).
46    pub fn remove_table(&self, namespace: &DatasetNamespace) {
47        if let Some(schema) = self.schemas.get(&namespace.schema_name) {
48            schema.tables.remove(&namespace.table);
49        }
50    }
51
52    /// Check if a table exists in the catalog.
53    pub fn has_table(&self, namespace: &DatasetNamespace) -> bool {
54        self.schemas
55            .get(&namespace.schema_name)
56            .map(|s| s.tables.contains_key(&namespace.table))
57            .unwrap_or(false)
58    }
59}
60
61impl Default for DatasetCatalogProvider {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl CatalogProvider for DatasetCatalogProvider {
68    fn as_any(&self) -> &dyn Any {
69        self
70    }
71
72    fn schema_names(&self) -> Vec<String> {
73        self.schemas.iter().map(|e| e.key().clone()).collect()
74    }
75
76    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
77        self.schemas
78            .get(name)
79            .map(|s| Arc::clone(&*s) as Arc<dyn SchemaProvider>)
80    }
81
82    fn register_schema(
83        &self,
84        name: &str,
85        schema: Arc<dyn SchemaProvider>,
86    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
87        let dataset_schema = schema
88            .as_any()
89            .downcast_ref::<DatasetSchemaProvider>()
90            .ok_or_else(|| {
91                DataFusionError::Internal("Expected DatasetSchemaProvider".to_string())
92            })?;
93        let prev = self
94            .schemas
95            .insert(name.to_string(), Arc::new(dataset_schema.clone()));
96        Ok(prev.map(|p| p as Arc<dyn SchemaProvider>))
97    }
98}
99
/// Custom DataFusion schema provider that maps the second level of the
/// three-level namespace. Holds `DashMap<table_name, TableProvider>`.
#[derive(Debug, Clone)]
pub struct DatasetSchemaProvider {
    // table name → table provider; accessed directly by the enclosing
    // module's catalog (swap_table / remove_table / has_table)
    tables: DashMap<String, Arc<dyn TableProvider>>,
}
106
107impl DatasetSchemaProvider {
108    pub fn new() -> Self {
109        Self {
110            tables: DashMap::new(),
111        }
112    }
113}
114
115impl Default for DatasetSchemaProvider {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121#[async_trait]
122impl SchemaProvider for DatasetSchemaProvider {
123    fn as_any(&self) -> &dyn Any {
124        self
125    }
126
127    fn table_names(&self) -> Vec<String> {
128        self.tables.iter().map(|e| e.key().clone()).collect()
129    }
130
131    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
132        Ok(self.tables.get(name).map(|t| Arc::clone(&*t)))
133    }
134
135    fn table_exist(&self, name: &str) -> bool {
136        self.tables.contains_key(name)
137    }
138
139    fn register_table(
140        &self,
141        name: String,
142        table: Arc<dyn TableProvider>,
143    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
144        Ok(self.tables.insert(name, table))
145    }
146
147    fn deregister_table(
148        &self,
149        name: &str,
150    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
151        Ok(self.tables.remove(name).map(|(_, t)| t))
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_catalog_schema_lifecycle() {
        let catalog = DatasetCatalogProvider::new();

        // Initially empty
        assert!(catalog.schema_names().is_empty());
        assert!(catalog.schema("test_schema").is_none());

        // Get or create a schema
        let schema = catalog.get_or_create_schema("test_schema");
        assert!(catalog.schema_names().contains(&"test_schema".to_string()));
        assert!(schema.table_names().is_empty());

        // Getting the same schema again must return the SAME instance.
        // Comparing table_names() (as the test previously did) is vacuous —
        // two distinct empty providers would also compare equal. Arc::ptr_eq
        // actually verifies instance identity.
        let schema2 = catalog.get_or_create_schema("test_schema");
        assert!(Arc::ptr_eq(&schema, &schema2));
    }

    #[test]
    fn test_catalog_has_table() {
        let catalog = DatasetCatalogProvider::new();
        let ns = DatasetNamespace::new("cat", "sch", "tbl").unwrap();

        assert!(!catalog.has_table(&ns));

        // Add a table via swap_table using an in-memory table with a single
        // empty batch (schema: one non-nullable Int64 column).
        let schema = arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
            "id",
            arrow::datatypes::DataType::Int64,
            false,
        )]);
        let batch = arrow_array::RecordBatch::new_empty(Arc::new(schema));
        let provider = Arc::new(
            datafusion::datasource::MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(),
        );
        catalog.swap_table(&ns, provider);

        assert!(catalog.has_table(&ns));

        // Remove it
        catalog.remove_table(&ns);
        assert!(!catalog.has_table(&ns));
    }
}