term_guard/core/context.rs
1//! DataFusion context management for Term validation library.
2//!
3//! This module provides [`TermContext`], an abstraction layer over DataFusion's
4//! [`SessionContext`] with optimized settings for data validation workloads.
5
6use crate::prelude::*;
7use datafusion::datasource::TableProvider;
8use datafusion::execution::context::{SessionConfig, SessionContext};
9use datafusion::execution::memory_pool::{FairSpillPool, MemoryPool};
10use datafusion::execution::runtime_env::RuntimeEnvBuilder;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::instrument;
14
/// Tuning knobs used when building a [`TermContext`].
///
/// All fields are public, so a configuration can be created with struct-update
/// syntax on top of [`Default::default`]. The type implements [`PartialEq`] so
/// that configurations can be compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct TermContextConfig {
    /// Batch size for query execution (number of rows per record batch).
    pub batch_size: usize,
    /// Target number of partitions for parallel execution.
    pub target_partitions: usize,
    /// Maximum memory for query execution, in bytes.
    pub max_memory: usize,
    /// Fraction of `max_memory` (0.0 to 1.0) to use before spilling.
    pub memory_fraction: f64,
}
27
28impl Default for TermContextConfig {
29 fn default() -> Self {
30 Self {
31 batch_size: 8192,
32 target_partitions: std::thread::available_parallelism()
33 .map(|p| p.get())
34 .unwrap_or(4),
35 max_memory: 2 * 1024 * 1024 * 1024, // 2GB
36 memory_fraction: 0.9,
37 }
38 }
39}
40
41/// A managed DataFusion context for Term validation operations.
42///
43/// `TermContext` wraps DataFusion's [`SessionContext`] and provides:
44/// - Optimized default settings for data validation workloads
45/// - Memory management with configurable limits
46/// - Table registration helpers with tracking
47/// - Automatic resource cleanup
48///
49/// # Examples
50///
51/// ```rust,ignore
52/// use term_guard::core::TermContext;
53///
54/// # async fn example() -> Result<()> {
55/// // Create context with default settings
56/// let ctx = TermContext::new()?;
57///
58/// // Register a table
59/// ctx.register_csv("users", "data/users.csv").await?;
60///
61/// // Use the underlying SessionContext for queries
62/// let df = ctx.inner().sql("SELECT COUNT(*) FROM users").await?;
63/// # Ok(())
64/// # }
65/// ```
66pub struct TermContext {
67 inner: SessionContext,
68 pub(crate) tables: HashMap<String, Arc<dyn TableProvider>>,
69 config: TermContextConfig,
70}
71
72impl TermContext {
73 /// Creates a new context with default configuration.
74 ///
75 /// # Examples
76 ///
77 /// ```rust,ignore
78 /// use term_guard::core::TermContext;
79 ///
80 /// let ctx = TermContext::new()?;
81 /// ```
82 #[instrument]
83 pub fn new() -> Result<Self> {
84 Self::with_config(TermContextConfig::default())
85 }
86
87 /// Creates a new context with custom configuration.
88 ///
89 /// # Examples
90 ///
91 /// ```rust,ignore
92 /// use term_guard::core::{TermContext, TermContextConfig};
93 ///
94 /// let config = TermContextConfig {
95 /// batch_size: 16384,
96 /// max_memory: 4 * 1024 * 1024 * 1024, // 4GB
97 /// ..Default::default()
98 /// };
99 ///
100 /// let ctx = TermContext::with_config(config)?;
101 /// ```
102 #[instrument(skip(config))]
103 pub fn with_config(config: TermContextConfig) -> Result<Self> {
104 // Create session configuration
105 let session_config = SessionConfig::new()
106 .with_batch_size(config.batch_size)
107 .with_target_partitions(config.target_partitions)
108 .with_information_schema(true);
109
110 // Create memory pool
111 let memory_pool = Arc::new(FairSpillPool::new(config.max_memory)) as Arc<dyn MemoryPool>;
112
113 // Create runtime environment
114 let runtime_env = RuntimeEnvBuilder::new()
115 .with_memory_pool(memory_pool)
116 .with_temp_file_path(std::env::temp_dir())
117 .build()
118 .map(Arc::new)?;
119
120 // Create session context
121 let inner = SessionContext::new_with_config_rt(session_config, runtime_env);
122
123 Ok(Self {
124 inner,
125 tables: HashMap::new(),
126 config,
127 })
128 }
129
130 /// Returns a reference to the underlying DataFusion [`SessionContext`].
131 ///
132 /// This allows direct access to all DataFusion functionality while
133 /// still benefiting from Term's resource management.
134 ///
135 /// # Examples
136 ///
137 /// ```rust,ignore
138 /// # use term_guard::core::TermContext;
139 /// # async fn example(ctx: &TermContext) -> Result<()> {
140 /// let df = ctx.inner().sql("SELECT * FROM data").await?;
141 /// # Ok(())
142 /// # }
143 /// ```
144 pub fn inner(&self) -> &SessionContext {
145 &self.inner
146 }
147
148 /// Returns a mutable reference to the underlying [`SessionContext`].
149 pub fn inner_mut(&mut self) -> &mut SessionContext {
150 &mut self.inner
151 }
152
153 /// Returns the configuration used to create this context.
154 pub fn config(&self) -> &TermContextConfig {
155 &self.config
156 }
157
158 /// Returns the names of all registered tables.
159 ///
160 /// # Examples
161 ///
162 /// ```rust,ignore
163 /// # use term_guard::core::TermContext;
164 /// # async fn example(ctx: &TermContext) -> Result<()> {
165 /// let tables = ctx.registered_tables();
166 /// println!("Registered tables: {:?}", tables);
167 /// # Ok(())
168 /// # }
169 /// ```
170 pub fn registered_tables(&self) -> Vec<&str> {
171 self.tables.keys().map(|s| s.as_str()).collect()
172 }
173
174 /// Checks if a table is registered.
175 ///
176 /// # Examples
177 ///
178 /// ```rust,ignore
179 /// # use term_guard::core::TermContext;
180 /// # async fn example(ctx: &TermContext) -> Result<()> {
181 /// if ctx.has_table("users") {
182 /// println!("Users table is registered");
183 /// }
184 /// # Ok(())
185 /// # }
186 /// ```
187 pub fn has_table(&self, name: &str) -> bool {
188 self.tables.contains_key(name)
189 }
190
191 /// Registers a CSV file as a table.
192 ///
193 /// This is a convenience method that reads a CSV file and registers it
194 /// as a table in the context.
195 ///
196 /// # Examples
197 ///
198 /// ```rust,ignore
199 /// # use term_guard::core::TermContext;
200 /// # async fn example() -> Result<()> {
201 /// let mut ctx = TermContext::new()?;
202 /// ctx.register_csv("users", "data/users.csv").await?;
203 /// # Ok(())
204 /// # }
205 /// ```
206 #[instrument(skip(self))]
207 pub async fn register_csv(&mut self, name: &str, path: &str) -> Result<()> {
208 self.inner
209 .register_csv(name, path, Default::default())
210 .await?;
211
212 // Track the table
213 let source = self.inner.table_provider(name).await?;
214 self.tables.insert(name.to_string(), source);
215
216 Ok(())
217 }
218
219 /// Registers a Parquet file as a table.
220 ///
221 /// # Examples
222 ///
223 /// ```rust,ignore
224 /// # use term_guard::core::TermContext;
225 /// # async fn example() -> Result<()> {
226 /// let mut ctx = TermContext::new()?;
227 /// ctx.register_parquet("events", "data/events.parquet").await?;
228 /// # Ok(())
229 /// # }
230 /// ```
231 #[instrument(skip(self))]
232 pub async fn register_parquet(&mut self, name: &str, path: &str) -> Result<()> {
233 self.inner
234 .register_parquet(name, path, Default::default())
235 .await?;
236
237 // Track the table
238 let source = self.inner.table_provider(name).await?;
239 self.tables.insert(name.to_string(), source);
240
241 Ok(())
242 }
243
244 /// Deregisters a table from the context.
245 ///
246 /// # Examples
247 ///
248 /// ```rust,ignore
249 /// # use term_guard::core::TermContext;
250 /// # async fn example(ctx: &mut TermContext) -> Result<()> {
251 /// ctx.deregister_table("users")?;
252 /// # Ok(())
253 /// # }
254 /// ```
255 pub fn deregister_table(&mut self, name: &str) -> Result<()> {
256 self.inner.deregister_table(name)?;
257 self.tables.remove(name);
258 Ok(())
259 }
260
261 /// Registers a table directly and tracks it.
262 ///
263 /// This is a lower-level method that allows registering any TableProvider
264 /// directly. The table is automatically tracked for cleanup.
265 ///
266 /// # Examples
267 ///
268 /// ```rust,ignore
269 /// # use term_guard::core::TermContext;
270 /// # use datafusion::datasource::MemTable;
271 /// # async fn example(ctx: &mut TermContext, table: Arc<MemTable>) -> Result<()> {
272 /// ctx.register_table_provider("data", table).await?;
273 /// # Ok(())
274 /// # }
275 /// ```
276 #[instrument(skip(self, provider))]
277 pub async fn register_table_provider(
278 &mut self,
279 name: &str,
280 provider: Arc<dyn TableProvider>,
281 ) -> Result<()> {
282 self.inner.register_table(name, provider.clone())?;
283 self.tables.insert(name.to_string(), provider);
284 Ok(())
285 }
286
287 /// Clears all registered tables.
288 ///
289 /// This is useful for resetting the context between validation runs.
290 ///
291 /// # Examples
292 ///
293 /// ```rust,ignore
294 /// # use term_guard::core::TermContext;
295 /// # fn example(ctx: &mut TermContext) -> Result<()> {
296 /// ctx.clear_tables()?;
297 /// # Ok(())
298 /// # }
299 /// ```
300 pub fn clear_tables(&mut self) -> Result<()> {
301 let table_names: Vec<_> = self.tables.keys().cloned().collect();
302 for name in table_names {
303 self.deregister_table(&name)?;
304 }
305 Ok(())
306 }
307}
308
309/// Ensure proper cleanup when the context is dropped.
310impl Drop for TermContext {
311 fn drop(&mut self) {
312 // Clear all tables to ensure proper cleanup
313 if let Err(e) = self.clear_tables() {
314 tracing::warn!("Failed to clear tables during TermContext drop: {}", e);
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;

    /// Builds a single-partition in-memory table with one non-null `id` column.
    fn id_table(ids: Vec<i64>) -> Arc<MemTable> {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(ids))]).unwrap();
        Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap())
    }

    /// Builds a three-row in-memory table with non-null `id` and `name` columns.
    fn users_table() -> Arc<MemTable> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap();
        Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap())
    }

    #[test]
    fn test_default_config() {
        let expected_partitions = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(4);

        let config = TermContextConfig::default();
        assert_eq!(config.batch_size, 8192);
        assert_eq!(config.target_partitions, expected_partitions);
        assert_eq!(config.max_memory, 2 * 1024 * 1024 * 1024);
        assert_eq!(config.memory_fraction, 0.9);
    }

    #[tokio::test]
    async fn test_context_creation() {
        let ctx = TermContext::new().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }

    #[tokio::test]
    async fn test_context_with_custom_config() {
        let four_gib = 4 * 1024 * 1024 * 1024;
        let config = TermContextConfig {
            batch_size: 16384,
            max_memory: four_gib,
            ..Default::default()
        };

        let ctx = TermContext::with_config(config).unwrap();
        assert_eq!(ctx.config().batch_size, 16384);
        assert_eq!(ctx.config().max_memory, four_gib);
    }

    #[tokio::test]
    async fn test_table_registration() {
        let mut ctx = TermContext::new().unwrap();

        ctx.register_table_provider("users", users_table())
            .await
            .unwrap();

        assert!(ctx.has_table("users"));
        assert_eq!(ctx.registered_tables(), vec!["users"]);
    }

    #[tokio::test]
    async fn test_table_deregistration() {
        let mut ctx = TermContext::new().unwrap();

        ctx.register_table_provider("test", id_table(vec![1, 2, 3]))
            .await
            .unwrap();
        assert!(ctx.has_table("test"));

        // Once deregistered, the table must no longer be tracked.
        ctx.deregister_table("test").unwrap();
        assert!(!ctx.has_table("test"));
        assert!(ctx.registered_tables().is_empty());
    }

    #[tokio::test]
    async fn test_clear_tables() {
        let mut ctx = TermContext::new().unwrap();

        // Register three single-row tables named table0..table2.
        for i in 0..3_i64 {
            let name = format!("table{i}");
            ctx.register_table_provider(&name, id_table(vec![i]))
                .await
                .unwrap();
        }
        assert_eq!(ctx.registered_tables().len(), 3);

        ctx.clear_tables().unwrap();
        assert!(ctx.registered_tables().is_empty());
    }
}
439}