Skip to main content

xore_process/
sql.rs

1//! SQL 查询引擎
2//!
3//! 基于 Polars 内置 SQL 引擎执行查询。
4
5use anyhow::{Context, Result};
6use polars::prelude::*;
7use polars::sql::SQLContext;
8use std::collections::HashMap;
9use std::path::Path;
10
11use crate::parser::DataParser;
12
13/// SQL 引擎
14pub struct SqlEngine {
15    /// 已注册的表(表名 -> LazyFrame)
16    tables: HashMap<String, LazyFrame>,
17    /// 数据解析器
18    parser: DataParser,
19}
20
21impl SqlEngine {
22    /// 创建新的 SQL 引擎
23    pub fn new() -> Self {
24        Self { tables: HashMap::new(), parser: DataParser::new() }
25    }
26
27    /// 注册表(从文件加载)
28    pub fn register_table(&mut self, table_name: &str, path: &Path) -> Result<()> {
29        let lf = self.parser.read_lazy(path).with_context(|| {
30            format!(
31                "无法加载表 '{}'\n  --> 文件: {}\n💡 提示: 请确认文件存在且格式正确 (csv/parquet)",
32                table_name,
33                path.display()
34            )
35        })?;
36
37        self.tables.insert(table_name.to_string(), lf);
38        Ok(())
39    }
40
41    /// 注册表(从 LazyFrame)
42    pub fn register_lazyframe(&mut self, table_name: &str, lf: LazyFrame) {
43        self.tables.insert(table_name.to_string(), lf);
44    }
45
46    /// 执行 SQL 查询
47    pub fn execute(&self, sql: &str) -> Result<DataFrame> {
48        // 创建 SQL 上下文
49        let mut ctx = SQLContext::new();
50
51        // 注册所有表
52        for (name, lf) in &self.tables {
53            ctx.register(name, lf.clone());
54        }
55
56        // 执行查询
57        let result_lf = ctx.execute(sql).with_context(|| format!(
58            "SQL 查询执行失败\n  --> SQL: {}\n💡 提示: 检查 SQL 语法,或运行 'xore agent explain \"{}\"' 获取分析",
59            sql, sql
60        ))?;
61
62        // 收集结果
63        result_lf.collect().with_context(|| {
64            format!(
65            "收集查询结果失败\n  --> SQL: {}\n💡 提示: 查询可能返回了过多数据,尝试添加 LIMIT 子句",
66            sql
67        )
68        })
69    }
70}
71
72impl Default for SqlEngine {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;
    use tempfile::NamedTempFile;

    /// Builds an engine with a single in-memory table `df` registered as `table`.
    fn engine_with(table: &str, df: DataFrame) -> SqlEngine {
        let mut engine = SqlEngine::new();
        engine.register_lazyframe(table, df.lazy());
        engine
    }

    /// Collects an `i32` column of `result` into a vector (nulls appear as `None`).
    fn i32_values(result: &DataFrame, column: &str) -> Vec<Option<i32>> {
        result.column(column).unwrap().i32().unwrap().into_iter().collect()
    }

    #[test]
    fn test_simple_select() {
        let df = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
            "age" => &[25, 30, 35],
        }
        .unwrap();

        let engine = engine_with("users", df);

        let result = engine.execute("SELECT * FROM users").unwrap();
        assert_eq!(result.height(), 3);
        assert_eq!(result.width(), 3);
    }

    #[test]
    fn test_select_with_where() {
        let df = df! {
            "id" => &[1, 2, 3],
            "age" => &[25, 30, 35],
        }
        .unwrap();

        let engine = engine_with("users", df);

        let result = engine.execute("SELECT * FROM users WHERE age > 28").unwrap();
        // Only the rows with age 30 and 35 survive the filter.
        assert_eq!(result.height(), 2);
        assert_eq!(i32_values(&result, "age"), vec![Some(30), Some(35)]);
    }

    #[test]
    fn test_select_columns() {
        let df = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
            "age" => &[25, 30, 35],
        }
        .unwrap();

        let engine = engine_with("users", df);

        let result = engine.execute("SELECT name, age FROM users").unwrap();
        assert_eq!(result.width(), 2);
        assert!(result.column("name").is_ok());
        assert!(result.column("age").is_ok());
    }

    #[test]
    fn test_group_by_count() {
        let df = df! {
            "category" => &["A", "B", "A", "C", "B"],
            "value" => &[10, 20, 30, 40, 50],
        }
        .unwrap();

        let engine = engine_with("data", df);

        let result = engine
            .execute("SELECT category, COUNT(*) as count FROM data GROUP BY category")
            .unwrap();
        // One row per distinct category: A, B, C.
        assert_eq!(result.height(), 3);
        assert!(result.column("count").is_ok());
    }

    #[test]
    fn test_order_by() {
        let df = df! {
            "id" => &[3, 1, 2],
            "name" => &["Charlie", "Alice", "Bob"],
        }
        .unwrap();

        let engine = engine_with("users", df);

        let result = engine.execute("SELECT * FROM users ORDER BY id").unwrap();
        // Check the whole column, not just the first row, so a partial sort cannot pass.
        assert_eq!(i32_values(&result, "id"), vec![Some(1), Some(2), Some(3)]);
    }

    #[test]
    fn test_limit() {
        let df = df! {
            "id" => &[1, 2, 3, 4, 5],
        }
        .unwrap();

        let engine = engine_with("data", df);

        let result = engine.execute("SELECT * FROM data LIMIT 3").unwrap();
        assert_eq!(result.height(), 3);
    }

    #[test]
    fn test_aggregate_functions() {
        let df = df! {
            "category" => &["A", "A", "B", "B"],
            "value" => &[10, 20, 30, 40],
        }
        .unwrap();

        let engine = engine_with("data", df);

        let result = engine
            .execute(
                "SELECT category, SUM(value) as total, AVG(value) as average FROM data GROUP BY category",
            )
            .unwrap();
        assert_eq!(result.height(), 2);
        assert!(result.column("total").is_ok());
        assert!(result.column("average").is_ok());
    }

    #[test]
    fn test_register_from_file() {
        // The temp file needs a `.csv` suffix so the parser picks the CSV reader.
        let temp_file = NamedTempFile::with_suffix(".csv").unwrap();
        let path = temp_file.path();

        std::fs::write(path, "id,name,age\n1,Alice,25\n2,Bob,30\n").unwrap();

        let mut engine = SqlEngine::new();
        engine.register_table("users", path).unwrap();

        let result = engine.execute("SELECT * FROM users WHERE age > 26").unwrap();
        assert_eq!(result.height(), 1);
    }

    #[test]
    fn test_join() {
        let users = df! {
            "id" => &[1, 2, 3],
            "name" => &["Alice", "Bob", "Charlie"],
        }
        .unwrap();

        let orders = df! {
            "user_id" => &[1, 1, 2],
            "amount" => &[100, 200, 150],
        }
        .unwrap();

        let mut engine = engine_with("users", users);
        engine.register_lazyframe("orders", orders.lazy());

        let result = engine
            .execute("SELECT users.name, SUM(orders.amount) as total FROM users INNER JOIN orders ON users.id = orders.user_id GROUP BY users.name")
            .unwrap();

        // Only Alice and Bob have orders.
        assert_eq!(result.height(), 2);
        assert!(result.column("total").is_ok());
    }
}