1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::util::pretty::pretty_format_batches;
3use datafusion::execution::context::SessionContext;
4use datafusion::execution::disk_manager::DiskManagerConfig;
5use datafusion::execution::memory_pool::FairSpillPool;
6use datafusion::execution::runtime_env::RuntimeEnvBuilder;
7use datafusion::prelude::*;
8use std::sync::Arc;
9use tokio::runtime::Runtime;
10
/// Configuration for constructing a [`DataEngine`].
#[derive(Debug, Clone)]
pub struct DataEngineConfig {
    /// Upper bound, in bytes, for the fair-spill memory pool.
    pub max_memory_bytes: usize,
    /// When true, operators are allowed to spill intermediate state to disk.
    pub spill_to_disk: bool,
    /// Directory for spill files; `None` lets the OS pick a temp directory.
    pub spill_path: Option<String>,
}
20
21impl Default for DataEngineConfig {
22 fn default() -> Self {
23 DataEngineConfig {
24 max_memory_bytes: 512 * 1024 * 1024, spill_to_disk: true,
26 spill_path: None,
27 }
28 }
29}
30
31pub struct DataEngine {
33 pub ctx: SessionContext,
34 pub rt: Arc<Runtime>,
35}
36
37impl Default for DataEngine {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl DataEngine {
44 pub fn new() -> Self {
47 Self::with_config(DataEngineConfig::default())
48 }
49
50 pub fn with_config(config: DataEngineConfig) -> Self {
52 let rt = Arc::new(Runtime::new().expect("Failed to create tokio runtime for DataEngine"));
53
54 let pool = FairSpillPool::new(config.max_memory_bytes);
56
57 let mut rt_builder = RuntimeEnvBuilder::new().with_memory_pool(Arc::new(pool));
58
59 if config.spill_to_disk {
60 let disk_config = if let Some(ref path) = config.spill_path {
61 DiskManagerConfig::new_specified(vec![path.clone().into()])
62 } else {
63 DiskManagerConfig::NewOs
64 };
65 rt_builder = rt_builder.with_disk_manager(disk_config);
66 }
67
68 let runtime_env = rt_builder.build().expect("Failed to build RuntimeEnv");
69
70 let target_partitions = num_cpus::get();
72 let session_config = SessionConfig::new().with_target_partitions(target_partitions);
73
74 let ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime_env));
75
76 DataEngine { ctx, rt }
77 }
78
79 pub fn collect(&self, df: DataFrame) -> Result<Vec<RecordBatch>, String> {
81 self.rt
82 .block_on(df.collect())
83 .map_err(|e| format!("DataFusion collect error: {e}"))
84 }
85
86 pub fn format_batches(batches: &[RecordBatch]) -> Result<String, String> {
88 pretty_format_batches(batches)
89 .map(|t| t.to_string())
90 .map_err(|e| format!("Format error: {e}"))
91 }
92
93 pub fn register_batch(&self, name: &str, batch: RecordBatch) -> Result<(), String> {
95 let schema = batch.schema();
96 let provider = datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])
97 .map_err(|e| format!("MemTable error: {e}"))?;
98 self.ctx
99 .register_table(name, Arc::new(provider))
100 .map_err(|e| format!("Register table error: {e}"))?;
101 Ok(())
102 }
103
104 pub fn register_batches(
107 &self,
108 name: &str,
109 schema: Arc<datafusion::arrow::datatypes::Schema>,
110 batches: Vec<RecordBatch>,
111 ) -> Result<(), String> {
112 if batches.is_empty() {
113 let provider = datafusion::datasource::MemTable::try_new(schema, vec![])
115 .map_err(|e| format!("MemTable error: {e}"))?;
116 self.ctx
117 .register_table(name, Arc::new(provider))
118 .map_err(|e| format!("Register table error: {e}"))?;
119 return Ok(());
120 }
121 let partitions: Vec<Vec<RecordBatch>> = batches.into_iter().map(|b| vec![b]).collect();
123 let provider = datafusion::datasource::MemTable::try_new(schema, partitions)
124 .map_err(|e| format!("MemTable error: {e}"))?;
125 let _ = self.ctx.deregister_table(name);
127 self.ctx
128 .register_table(name, Arc::new(provider))
129 .map_err(|e| format!("Register table error: {e}"))?;
130 Ok(())
131 }
132
133 pub fn sql(&self, query: &str) -> Result<DataFrame, String> {
135 self.rt
136 .block_on(self.ctx.sql(query))
137 .map_err(|e| format!("SQL error: {e}"))
138 }
139
140 pub fn session_ctx(&self) -> &SessionContext {
142 &self.ctx
143 }
144
145 pub fn runtime(&self) -> &Arc<Runtime> {
147 &self.rt
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    #[test]
    fn test_engine_basic() {
        let engine = DataEngine::new();

        // Two-column table: (id: Int64, name: Utf8).
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let ids = Int64Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec!["Alice", "Bob", "Charlie"]);
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(names)],
        )
        .unwrap();

        engine.register_batch("test_table", batch).unwrap();
        let df = engine.sql("SELECT * FROM test_table WHERE id > 1").unwrap();
        let batches = engine.collect(df).unwrap();
        // Rows with id 2 and 3 survive the filter.
        assert_eq!(batches[0].num_rows(), 2);
    }

    #[test]
    fn test_engine_with_config() {
        let engine = DataEngine::with_config(DataEngineConfig {
            max_memory_bytes: 256 * 1024 * 1024,
            spill_to_disk: true,
            spill_path: None,
        });

        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let column: Arc<Int64Array> = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

        engine.register_batch("t", batch).unwrap();
        let df = engine.sql("SELECT * FROM t").unwrap();
        let rows = engine.collect(df).unwrap();
        assert_eq!(rows[0].num_rows(), 3);
    }
}