1use anyhow::{Context, Result};
6use polars::prelude::*;
7use polars::sql::SQLContext;
8use std::collections::HashMap;
9use std::path::Path;
10
11use crate::parser::DataParser;
12
13pub struct SqlEngine {
15 tables: HashMap<String, LazyFrame>,
17 parser: DataParser,
19}
20
21impl SqlEngine {
22 pub fn new() -> Self {
24 Self { tables: HashMap::new(), parser: DataParser::new() }
25 }
26
27 pub fn register_table(&mut self, table_name: &str, path: &Path) -> Result<()> {
29 let lf = self.parser.read_lazy(path).with_context(|| {
30 format!(
31 "无法加载表 '{}'\n --> 文件: {}\n💡 提示: 请确认文件存在且格式正确 (csv/parquet)",
32 table_name,
33 path.display()
34 )
35 })?;
36
37 self.tables.insert(table_name.to_string(), lf);
38 Ok(())
39 }
40
41 pub fn register_lazyframe(&mut self, table_name: &str, lf: LazyFrame) {
43 self.tables.insert(table_name.to_string(), lf);
44 }
45
46 pub fn execute(&self, sql: &str) -> Result<DataFrame> {
48 let mut ctx = SQLContext::new();
50
51 for (name, lf) in &self.tables {
53 ctx.register(name, lf.clone());
54 }
55
56 let result_lf = ctx.execute(sql).with_context(|| format!(
58 "SQL 查询执行失败\n --> SQL: {}\n💡 提示: 检查 SQL 语法,或运行 'xore agent explain \"{}\"' 获取分析",
59 sql, sql
60 ))?;
61
62 result_lf.collect().with_context(|| {
64 format!(
65 "收集查询结果失败\n --> SQL: {}\n💡 提示: 查询可能返回了过多数据,尝试添加 LIMIT 子句",
66 sql
67 )
68 })
69 }
70}
71
72impl Default for SqlEngine {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;
    use tempfile::NamedTempFile;

    /// Builds an engine that has `frame` registered under the name `table`.
    fn engine_with(table: &str, frame: DataFrame) -> SqlEngine {
        let mut engine = SqlEngine::new();
        engine.register_lazyframe(table, frame.lazy());
        engine
    }

    /// `SELECT *` returns every row and every column.
    #[test]
    fn test_simple_select() {
        let users = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
            "age" => &[25, 30, 35],
        }
        .unwrap();
        let engine = engine_with("users", users);

        let rows = engine.execute("SELECT * FROM users").unwrap();

        assert_eq!((rows.height(), rows.width()), (3, 3));
    }

    /// A `WHERE` clause filters rows: only ages 30 and 35 exceed 28.
    #[test]
    fn test_select_with_where() {
        let users = df! {
            "id" => &[1, 2, 3],
            "age" => &[25, 30, 35],
        }
        .unwrap();
        let engine = engine_with("users", users);

        let rows = engine.execute("SELECT * FROM users WHERE age > 28").unwrap();

        assert_eq!(rows.height(), 2);
    }

    /// Selecting named columns projects exactly those columns.
    #[test]
    fn test_select_columns() {
        let users = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
            "age" => &[25, 30, 35],
        }
        .unwrap();
        let engine = engine_with("users", users);

        let rows = engine.execute("SELECT name, age FROM users").unwrap();

        assert_eq!(rows.width(), 2);
        for wanted in ["name", "age"] {
            assert!(rows.column(wanted).is_ok());
        }
    }

    /// `GROUP BY` with `COUNT(*)` yields one row per distinct category (A, B, C).
    #[test]
    fn test_group_by_count() {
        let data = df! {
            "category" => &["A", "B", "A", "C", "B"],
            "value" => &[10, 20, 30, 40, 50],
        }
        .unwrap();
        let engine = engine_with("data", data);

        let sql = "SELECT category, COUNT(*) as count FROM data GROUP BY category";
        let rows = engine.execute(sql).unwrap();

        assert_eq!(rows.height(), 3);
        assert!(rows.column("count").is_ok());
    }

    /// `ORDER BY id` puts the smallest id first.
    #[test]
    fn test_order_by() {
        let users = df! {
            "id" => &[3, 1, 2],
            "name" => &["Charlie", "Alice", "Bob"],
        }
        .unwrap();
        let engine = engine_with("users", users);

        let rows = engine.execute("SELECT * FROM users ORDER BY id").unwrap();

        let first_id = rows.column("id").unwrap().i32().unwrap().get(0);
        assert_eq!(first_id, Some(1));
    }

    /// `LIMIT` caps the number of returned rows.
    #[test]
    fn test_limit() {
        let data = df! {
            "id" => &[1, 2, 3, 4, 5],
        }
        .unwrap();
        let engine = engine_with("data", data);

        let rows = engine.execute("SELECT * FROM data LIMIT 3").unwrap();

        assert_eq!(rows.height(), 3);
    }

    /// `SUM` and `AVG` aggregates produce their aliased output columns.
    #[test]
    fn test_aggregate_functions() {
        let data = df! {
            "category" => &["A", "A", "B", "B"],
            "value" => &[10, 20, 30, 40],
        }
        .unwrap();
        let engine = engine_with("data", data);

        let sql = "SELECT category, SUM(value) as total, AVG(value) as average FROM data GROUP BY category";
        let rows = engine.execute(sql).unwrap();

        assert_eq!(rows.height(), 2);
        for alias in ["total", "average"] {
            assert!(rows.column(alias).is_ok());
        }
    }

    /// A CSV file on disk can be registered as a table and queried.
    #[test]
    fn test_register_from_file() {
        let csv_file = NamedTempFile::with_suffix(".csv").unwrap();
        let csv_path = csv_file.path();
        std::fs::write(csv_path, "id,name,age\n1,Alice,25\n2,Bob,30\n").unwrap();

        let mut engine = SqlEngine::new();
        engine.register_table("users", csv_path).unwrap();

        // Only Bob (age 30) is older than 26.
        let rows = engine.execute("SELECT * FROM users WHERE age > 26").unwrap();
        assert_eq!(rows.height(), 1);
    }

    /// An inner join across two registered tables can be grouped and aggregated.
    #[test]
    fn test_join() {
        let users = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
        }
        .unwrap();
        let orders = df! {
            "user_id" => &[1, 1, 2],
            "amount" => &[100, 200, 150],
        }
        .unwrap();

        let mut engine = engine_with("users", users);
        engine.register_lazyframe("orders", orders.lazy());

        let sql = "SELECT users.name, SUM(orders.amount) as total FROM users INNER JOIN orders ON users.id = orders.user_id GROUP BY users.name";
        let rows = engine.execute(sql).unwrap();

        // Charlie has no orders, so only Alice and Bob survive the inner join.
        assert_eq!(rows.height(), 2);
        assert!(rows.column("total").is_ok());
    }
}