1use crate::error::{Result, SqlStreamError};
8use datafusion::arrow::util::pretty::print_batches;
9use datafusion::prelude::*;
10use std::path::Path;
11use tracing::{debug, info, instrument};
12
/// SQL query engine wrapping a DataFusion [`SessionContext`].
///
/// Files are registered as tables with [`QueryEngine::register_file`] and then
/// queried with [`QueryEngine::execute_query`].
pub struct QueryEngine {
    /// DataFusion session that owns the registered tables and plans queries.
    ctx: SessionContext,
}
21
22impl QueryEngine {
23 #[instrument]
29 pub fn new() -> Result<Self> {
30 info!("Initializing query engine");
31 let ctx = SessionContext::new();
32 Ok(Self { ctx })
33 }
34
35 #[instrument(skip(self))]
53 pub async fn register_file(&mut self, file_path: &str, table_name: &str) -> Result<()> {
54 let path = Path::new(file_path);
55
56 if !path.exists() {
58 return Err(SqlStreamError::FileNotFound(path.to_path_buf()));
59 }
60
61 info!("Registering file: {} as table: {}", file_path, table_name);
62
63 let extension = path
65 .extension()
66 .and_then(|ext| ext.to_str())
67 .ok_or_else(|| SqlStreamError::UnsupportedFormat(path.to_string_lossy().to_string()))?;
68
69 match extension.to_lowercase().as_str() {
70 "csv" => {
71 debug!("Detected CSV format");
72 self.ctx
73 .register_csv(table_name, file_path, CsvReadOptions::new())
74 .await
75 .map_err(|e| {
76 SqlStreamError::TableRegistration(table_name.to_string(), e.to_string())
77 })?;
78 }
79 "json" => {
80 debug!("Detected JSON format");
81 self.ctx
82 .register_json(table_name, file_path, NdJsonReadOptions::default())
83 .await
84 .map_err(|e| {
85 SqlStreamError::TableRegistration(table_name.to_string(), e.to_string())
86 })?;
87 }
88 _ => {
89 return Err(SqlStreamError::UnsupportedFormat(extension.to_string()));
90 }
91 }
92
93 info!("Successfully registered table: {}", table_name);
94 Ok(())
95 }
96
97 #[instrument(skip(self))]
107 pub async fn execute_query(&self, sql: &str) -> Result<DataFrame> {
108 info!("Executing SQL query");
109 debug!("Query: {}", sql);
110
111 let df = self
112 .ctx
113 .sql(sql)
114 .await
115 .map_err(|e| SqlStreamError::QueryExecution(e.to_string()))?;
116
117 Ok(df)
118 }
119
120 #[instrument(skip(self, dataframe))]
133 pub async fn print_results(&self, dataframe: DataFrame) -> Result<()> {
134 info!("Collecting and printing results");
135
136 let batches = dataframe.collect().await?;
138
139 print_batches(&batches).map_err(|e| {
141 SqlStreamError::QueryExecution(format!("Failed to print results: {}", e))
142 })?;
143
144 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
145 info!("Query returned {} rows", total_rows);
146
147 Ok(())
148 }
149}
150
151impl Default for QueryEngine {
152 fn default() -> Self {
153 Self::new().expect("Failed to create default QueryEngine")
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// `QueryEngine::new` is synchronous, so this test needs no async runtime.
    #[test]
    fn test_engine_creation() {
        let engine = QueryEngine::new();
        assert!(engine.is_ok());
    }

    /// A path that does not exist is rejected before any format detection.
    #[tokio::test]
    async fn test_file_not_found() {
        let mut engine = QueryEngine::new().unwrap();
        let result = engine.register_file("nonexistent.csv", "test").await;
        assert!(matches!(result, Err(SqlStreamError::FileNotFound(_))));
    }

    /// An existing file with an unrecognised extension reports that extension.
    #[tokio::test]
    async fn test_unsupported_format() {
        // The crate manifest always exists and has a `.toml` extension.
        let manifest = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        let mut engine = QueryEngine::new().unwrap();
        let result = engine.register_file(manifest, "test").await;
        assert!(matches!(
            &result,
            Err(SqlStreamError::UnsupportedFormat(ext)) if ext.as_str() == "toml"
        ));
    }
}