1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::util::pretty::pretty_format_batches;
3use datafusion::execution::context::SessionContext;
4use datafusion::execution::disk_manager::DiskManagerConfig;
5use datafusion::execution::memory_pool::FairSpillPool;
6use datafusion::execution::runtime_env::RuntimeEnvBuilder;
7use datafusion::prelude::*;
8use std::sync::Arc;
9use tokio::runtime::Runtime;
10
/// Configuration for a [`DataEngine`]: memory budget and spill behavior.
#[derive(Debug, Clone)]
pub struct DataEngineConfig {
    /// Upper bound, in bytes, on query-execution memory (enforced via the
    /// engine's fair-spill memory pool).
    pub max_memory_bytes: usize,
    /// When `true`, a disk manager is configured so operators can spill
    /// intermediate state to disk instead of failing under memory pressure.
    pub spill_to_disk: bool,
    /// Directory for spill files; `None` lets the OS pick a temp directory.
    pub spill_path: Option<String>,
}
20
21impl Default for DataEngineConfig {
22 fn default() -> Self {
23 DataEngineConfig {
24 max_memory_bytes: 512 * 1024 * 1024, spill_to_disk: true,
26 spill_path: None,
27 }
28 }
29}
30
/// Query engine pairing a DataFusion [`SessionContext`] with a dedicated
/// Tokio runtime, so synchronous callers can run async DataFusion APIs
/// via `block_on`.
pub struct DataEngine {
    /// DataFusion session: table registry, SQL planning, and session config.
    pub ctx: SessionContext,
    /// Runtime used by `collect`/`sql` to drive DataFusion futures to completion.
    pub rt: Arc<Runtime>,
}
36
37impl Default for DataEngine {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl DataEngine {
44 pub fn new() -> Self {
47 Self::with_config(DataEngineConfig::default())
48 }
49
50 pub fn with_config(config: DataEngineConfig) -> Self {
52 let rt = Arc::new(Runtime::new().expect("Failed to create tokio runtime for DataEngine"));
53
54 let pool = FairSpillPool::new(config.max_memory_bytes);
56
57 let mut rt_builder = RuntimeEnvBuilder::new().with_memory_pool(Arc::new(pool));
58
59 if config.spill_to_disk {
60 let disk_config = if let Some(ref path) = config.spill_path {
61 DiskManagerConfig::new_specified(vec![path.clone().into()])
62 } else {
63 DiskManagerConfig::NewOs
64 };
65 rt_builder = rt_builder.with_disk_manager(disk_config);
66 }
67
68 let runtime_env = rt_builder.build().expect("Failed to build RuntimeEnv");
69
70 let target_partitions = num_cpus::get();
72 let session_config = SessionConfig::new().with_target_partitions(target_partitions);
73
74 let ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime_env));
75
76 DataEngine { ctx, rt }
77 }
78
79 pub fn collect(&self, df: DataFrame) -> Result<Vec<RecordBatch>, String> {
81 self.rt
82 .block_on(df.collect())
83 .map_err(|e| format!("DataFusion collect error: {e}"))
84 }
85
86 pub fn format_batches(batches: &[RecordBatch]) -> Result<String, String> {
88 pretty_format_batches(batches)
89 .map(|t| t.to_string())
90 .map_err(|e| format!("Format error: {e}"))
91 }
92
93 pub fn register_batch(&self, name: &str, batch: RecordBatch) -> Result<(), String> {
95 let schema = batch.schema();
96 let provider = datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]])
97 .map_err(|e| format!("MemTable error: {e}"))?;
98 self.ctx
99 .register_table(name, Arc::new(provider))
100 .map_err(|e| format!("Register table error: {e}"))?;
101 Ok(())
102 }
103
104 pub fn sql(&self, query: &str) -> Result<DataFrame, String> {
106 self.rt
107 .block_on(self.ctx.sql(query))
108 .map_err(|e| format!("SQL error: {e}"))
109 }
110
111 pub fn session_ctx(&self) -> &SessionContext {
113 &self.ctx
114 }
115
116 pub fn runtime(&self) -> &Arc<Runtime> {
118 &self.rt
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// End-to-end: register an in-memory batch, filter via SQL, collect.
    #[test]
    fn test_engine_basic() {
        let engine = DataEngine::new();

        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let ids: Int64Array = vec![1, 2, 3].into();
        let names: StringArray = vec!["Alice", "Bob", "Charlie"].into();
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(names)],
        )
        .unwrap();

        engine.register_batch("test_table", batch).unwrap();
        let frame = engine.sql("SELECT * FROM test_table WHERE id > 1").unwrap();
        let batches = engine.collect(frame).unwrap();
        assert_eq!(batches[0].num_rows(), 2);
    }

    /// An engine built from an explicit config still plans and executes.
    #[test]
    fn test_engine_with_config() {
        let engine = DataEngine::with_config(DataEngineConfig {
            max_memory_bytes: 256 * 1024 * 1024,
            spill_to_disk: true,
            spill_path: None,
        });

        let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]);
        let column: Int64Array = vec![1, 2, 3].into();
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)]).unwrap();

        engine.register_batch("t", batch).unwrap();
        let frame = engine.sql("SELECT * FROM t").unwrap();
        let batches = engine.collect(frame).unwrap();
        assert_eq!(batches[0].num_rows(), 3);
    }
}