term_guard/core/
context.rs

1//! DataFusion context management for Term validation library.
2//!
3//! This module provides [`TermContext`], an abstraction layer over DataFusion's
4//! [`SessionContext`] with optimized settings for data validation workloads.
5
6use crate::prelude::*;
7use datafusion::datasource::TableProvider;
8use datafusion::execution::context::{SessionConfig, SessionContext};
9use datafusion::execution::memory_pool::{FairSpillPool, MemoryPool};
10use datafusion::execution::runtime_env::RuntimeEnvBuilder;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::instrument;
14
/// Tunable settings consumed when building a [`TermContext`].
///
/// Use [`Default::default`] for sensible starting values and override
/// individual fields with struct-update syntax.
#[derive(Debug, Clone)]
pub struct TermContextConfig {
    /// Batch size used for query execution.
    pub batch_size: usize,
    /// Desired number of partitions for parallel execution.
    pub target_partitions: usize,
    /// Upper bound, in bytes, on memory available to query execution.
    pub max_memory: usize,
    /// Share of `max_memory` to use before spilling (0.0 to 1.0).
    pub memory_fraction: f64,
}

impl Default for TermContextConfig {
    /// Returns the stock configuration:
    /// a batch size of 8192, one partition per available CPU (4 when the
    /// parallelism cannot be queried), a 2 GiB memory ceiling and a 0.9
    /// memory fraction.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 8192;
        const FALLBACK_PARTITIONS: usize = 4;
        const DEFAULT_MAX_MEMORY: usize = 2 * 1024 * 1024 * 1024; // 2GB
        const DEFAULT_MEMORY_FRACTION: f64 = 0.9;

        // Querying the parallelism can fail on some platforms; fall back to
        // a fixed partition count rather than surfacing the error.
        let target_partitions = match std::thread::available_parallelism() {
            Ok(cores) => cores.get(),
            Err(_) => FALLBACK_PARTITIONS,
        };

        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            target_partitions,
            max_memory: DEFAULT_MAX_MEMORY,
            memory_fraction: DEFAULT_MEMORY_FRACTION,
        }
    }
}
40
/// A managed DataFusion context for Term validation operations.
///
/// `TermContext` wraps DataFusion's [`SessionContext`] and provides:
/// - Optimized default settings for data validation workloads
/// - Memory management with configurable limits
/// - Table registration helpers with tracking
/// - Automatic resource cleanup (tracked tables are deregistered on drop)
///
/// # Examples
///
/// ```rust,ignore
/// use term_guard::core::TermContext;
///
/// # async fn example() -> Result<()> {
/// // Create context with default settings
/// let mut ctx = TermContext::new()?;
///
/// // Register a table (requires a mutable context)
/// ctx.register_csv("users", "data/users.csv").await?;
///
/// // Use the underlying SessionContext for queries
/// let df = ctx.inner().sql("SELECT COUNT(*) FROM users").await?;
/// # Ok(())
/// # }
/// ```
pub struct TermContext {
    // The wrapped DataFusion session that actually executes queries.
    inner: SessionContext,
    // Providers registered through this wrapper, keyed by the name passed
    // at registration time. Tables registered directly on `inner` are not
    // recorded here.
    pub(crate) tables: HashMap<String, Arc<dyn TableProvider>>,
    // The configuration this context was constructed with.
    config: TermContextConfig,
}
71
72impl TermContext {
73    /// Creates a new context with default configuration.
74    ///
75    /// # Examples
76    ///
77    /// ```rust,ignore
78    /// use term_guard::core::TermContext;
79    ///
80    /// let ctx = TermContext::new()?;
81    /// ```
82    #[instrument]
83    pub fn new() -> Result<Self> {
84        Self::with_config(TermContextConfig::default())
85    }
86
87    /// Creates a new context with custom configuration.
88    ///
89    /// # Examples
90    ///
91    /// ```rust,ignore
92    /// use term_guard::core::{TermContext, TermContextConfig};
93    ///
94    /// let config = TermContextConfig {
95    ///     batch_size: 16384,
96    ///     max_memory: 4 * 1024 * 1024 * 1024, // 4GB
97    ///     ..Default::default()
98    /// };
99    ///
100    /// let ctx = TermContext::with_config(config)?;
101    /// ```
102    #[instrument(skip(config))]
103    pub fn with_config(config: TermContextConfig) -> Result<Self> {
104        // Create session configuration
105        let session_config = SessionConfig::new()
106            .with_batch_size(config.batch_size)
107            .with_target_partitions(config.target_partitions)
108            .with_information_schema(true);
109
110        // Create memory pool
111        let memory_pool = Arc::new(FairSpillPool::new(config.max_memory)) as Arc<dyn MemoryPool>;
112
113        // Create runtime environment
114        let runtime_env = RuntimeEnvBuilder::new()
115            .with_memory_pool(memory_pool)
116            .with_temp_file_path(std::env::temp_dir())
117            .build()
118            .map(Arc::new)?;
119
120        // Create session context
121        let inner = SessionContext::new_with_config_rt(session_config, runtime_env);
122
123        Ok(Self {
124            inner,
125            tables: HashMap::new(),
126            config,
127        })
128    }
129
130    /// Returns a reference to the underlying DataFusion [`SessionContext`].
131    ///
132    /// This allows direct access to all DataFusion functionality while
133    /// still benefiting from Term's resource management.
134    ///
135    /// # Examples
136    ///
137    /// ```rust,ignore
138    /// # use term_guard::core::TermContext;
139    /// # async fn example(ctx: &TermContext) -> Result<()> {
140    /// let df = ctx.inner().sql("SELECT * FROM data").await?;
141    /// # Ok(())
142    /// # }
143    /// ```
144    pub fn inner(&self) -> &SessionContext {
145        &self.inner
146    }
147
148    /// Returns a mutable reference to the underlying [`SessionContext`].
149    pub fn inner_mut(&mut self) -> &mut SessionContext {
150        &mut self.inner
151    }
152
153    /// Returns the configuration used to create this context.
154    pub fn config(&self) -> &TermContextConfig {
155        &self.config
156    }
157
158    /// Returns the names of all registered tables.
159    ///
160    /// # Examples
161    ///
162    /// ```rust,ignore
163    /// # use term_guard::core::TermContext;
164    /// # async fn example(ctx: &TermContext) -> Result<()> {
165    /// let tables = ctx.registered_tables();
166    /// println!("Registered tables: {:?}", tables);
167    /// # Ok(())
168    /// # }
169    /// ```
170    pub fn registered_tables(&self) -> Vec<&str> {
171        self.tables.keys().map(|s| s.as_str()).collect()
172    }
173
174    /// Checks if a table is registered.
175    ///
176    /// # Examples
177    ///
178    /// ```rust,ignore
179    /// # use term_guard::core::TermContext;
180    /// # async fn example(ctx: &TermContext) -> Result<()> {
181    /// if ctx.has_table("users") {
182    ///     println!("Users table is registered");
183    /// }
184    /// # Ok(())
185    /// # }
186    /// ```
187    pub fn has_table(&self, name: &str) -> bool {
188        self.tables.contains_key(name)
189    }
190
191    /// Registers a CSV file as a table.
192    ///
193    /// This is a convenience method that reads a CSV file and registers it
194    /// as a table in the context.
195    ///
196    /// # Examples
197    ///
198    /// ```rust,ignore
199    /// # use term_guard::core::TermContext;
200    /// # async fn example() -> Result<()> {
201    /// let mut ctx = TermContext::new()?;
202    /// ctx.register_csv("users", "data/users.csv").await?;
203    /// # Ok(())
204    /// # }
205    /// ```
206    #[instrument(skip(self))]
207    pub async fn register_csv(&mut self, name: &str, path: &str) -> Result<()> {
208        self.inner
209            .register_csv(name, path, Default::default())
210            .await?;
211
212        // Track the table
213        let source = self.inner.table_provider(name).await?;
214        self.tables.insert(name.to_string(), source);
215
216        Ok(())
217    }
218
219    /// Registers a Parquet file as a table.
220    ///
221    /// # Examples
222    ///
223    /// ```rust,ignore
224    /// # use term_guard::core::TermContext;
225    /// # async fn example() -> Result<()> {
226    /// let mut ctx = TermContext::new()?;
227    /// ctx.register_parquet("events", "data/events.parquet").await?;
228    /// # Ok(())
229    /// # }
230    /// ```
231    #[instrument(skip(self))]
232    pub async fn register_parquet(&mut self, name: &str, path: &str) -> Result<()> {
233        self.inner
234            .register_parquet(name, path, Default::default())
235            .await?;
236
237        // Track the table
238        let source = self.inner.table_provider(name).await?;
239        self.tables.insert(name.to_string(), source);
240
241        Ok(())
242    }
243
244    /// Deregisters a table from the context.
245    ///
246    /// # Examples
247    ///
248    /// ```rust,ignore
249    /// # use term_guard::core::TermContext;
250    /// # async fn example(ctx: &mut TermContext) -> Result<()> {
251    /// ctx.deregister_table("users")?;
252    /// # Ok(())
253    /// # }
254    /// ```
255    pub fn deregister_table(&mut self, name: &str) -> Result<()> {
256        self.inner.deregister_table(name)?;
257        self.tables.remove(name);
258        Ok(())
259    }
260
261    /// Registers a table directly and tracks it.
262    ///
263    /// This is a lower-level method that allows registering any TableProvider
264    /// directly. The table is automatically tracked for cleanup.
265    ///
266    /// # Examples
267    ///
268    /// ```rust,ignore
269    /// # use term_guard::core::TermContext;
270    /// # use datafusion::datasource::MemTable;
271    /// # async fn example(ctx: &mut TermContext, table: Arc<MemTable>) -> Result<()> {
272    /// ctx.register_table_provider("data", table).await?;
273    /// # Ok(())
274    /// # }
275    /// ```
276    #[instrument(skip(self, provider))]
277    pub async fn register_table_provider(
278        &mut self,
279        name: &str,
280        provider: Arc<dyn TableProvider>,
281    ) -> Result<()> {
282        self.inner.register_table(name, provider.clone())?;
283        self.tables.insert(name.to_string(), provider);
284        Ok(())
285    }
286
287    /// Clears all registered tables.
288    ///
289    /// This is useful for resetting the context between validation runs.
290    ///
291    /// # Examples
292    ///
293    /// ```rust,ignore
294    /// # use term_guard::core::TermContext;
295    /// # fn example(ctx: &mut TermContext) -> Result<()> {
296    /// ctx.clear_tables()?;
297    /// # Ok(())
298    /// # }
299    /// ```
300    pub fn clear_tables(&mut self) -> Result<()> {
301        let table_names: Vec<_> = self.tables.keys().cloned().collect();
302        for name in table_names {
303            self.deregister_table(&name)?;
304        }
305        Ok(())
306    }
307}
308
309/// Ensure proper cleanup when the context is dropped.
310impl Drop for TermContext {
311    fn drop(&mut self) {
312        // Clear all tables to ensure proper cleanup
313        if let Err(e) = self.clear_tables() {
314            tracing::warn!("Failed to clear tables during TermContext drop: {}", e);
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;

    /// Builds a single-partition in-memory table with one non-nullable
    /// `id: Int64` column holding `ids`.
    fn id_table(ids: Vec<i64>) -> Arc<MemTable> {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int64Array::from(ids))],
        )
        .unwrap();
        Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap())
    }

    /// The default configuration exposes the documented values.
    #[test]
    fn test_default_config() {
        let expected_partitions = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);

        let config = TermContextConfig::default();
        assert_eq!(config.batch_size, 8192);
        assert_eq!(config.target_partitions, expected_partitions);
        assert_eq!(config.max_memory, 2 * 1024 * 1024 * 1024);
        assert_eq!(config.memory_fraction, 0.9);
    }

    /// A freshly created context tracks no tables.
    #[tokio::test]
    async fn test_context_creation() {
        let ctx = TermContext::new().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }

    /// Custom settings are retained and readable through `config()`.
    #[tokio::test]
    async fn test_context_with_custom_config() {
        let config = TermContextConfig {
            batch_size: 16384,
            max_memory: 4 * 1024 * 1024 * 1024,
            ..Default::default()
        };

        let ctx = TermContext::with_config(config).unwrap();
        let stored = ctx.config();
        assert_eq!(stored.batch_size, 16384);
        assert_eq!(stored.max_memory, 4 * 1024 * 1024 * 1024);
    }

    /// Registering a provider makes it visible through the tracking API.
    #[tokio::test]
    async fn test_table_registration() {
        let mut ctx = TermContext::new().unwrap();

        // Two-column test data: ids and names.
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap();
        let users = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        ctx.register_table_provider("users", Arc::new(users))
            .await
            .unwrap();

        assert!(ctx.has_table("users"));
        assert_eq!(ctx.registered_tables(), vec!["users"]);
    }

    /// Deregistering a table removes it from the tracking API.
    #[tokio::test]
    async fn test_table_deregistration() {
        let mut ctx = TermContext::new().unwrap();

        ctx.register_table_provider("test", id_table(vec![1, 2, 3]))
            .await
            .unwrap();
        assert!(ctx.has_table("test"));

        ctx.deregister_table("test").unwrap();
        assert!(!ctx.has_table("test"));
        assert!(ctx.registered_tables().is_empty());
    }

    /// `clear_tables` removes every tracked table.
    #[tokio::test]
    async fn test_clear_tables() {
        let mut ctx = TermContext::new().unwrap();

        // Register three single-row tables: table0, table1, table2.
        for i in 0..3 {
            let name = format!("table{i}");
            ctx.register_table_provider(&name, id_table(vec![i]))
                .await
                .unwrap();
        }
        assert_eq!(ctx.registered_tables().len(), 3);

        ctx.clear_tables().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }
}
439}