Skip to main content

we_trust_sqlite/
executor.rs

1use crate::reader::NativeReader;
2use crate::writer::NativeWriter;
3use std::cmp::Ordering;
4use std::collections::{BTreeMap, HashMap};
5use std::sync::Arc;
6use yykv_operators::{
7    AggExpr, Expr, OperatorNode, OpsGraph,
8    sql::{SqlParser, SqlStatement},
9};
10use yykv_types::{DsError, DsValue};
11
/// Result alias used throughout this file: every fallible operation reports a [`DsError`].
type Result<T> = std::result::Result<T, DsError>;
13
14/// Wrapper to make YYValue hashable and comparable
15#[derive(Debug, Clone)]
16struct HashableYYValue(DsValue);
17
18impl PartialEq for HashableYYValue {
19    fn eq(&self, other: &Self) -> bool {
20        self.0.partial_cmp(&other.0) == Some(Ordering::Equal)
21    }
22}
23
24impl Eq for HashableYYValue {}
25
26impl std::hash::Hash for HashableYYValue {
27    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
28        match &self.0 {
29            DsValue::Null => 0.hash(state),
30            DsValue::Bool(b) => b.hash(state),
31            DsValue::Int(i) => i.hash(state),
32            DsValue::Float(f) => f.to_bits().hash(state),
33            DsValue::Text(s) => s.hash(state),
34            DsValue::Bytes(b) | DsValue::Binary(b) => b.hash(state),
35            DsValue::Uuid(u) => u.hash(state),
36            _ => 0.hash(state), // Simplified for other types
37        }
38    }
39}
40
/// Executor dedicated to the SQLite backend.
///
/// Holds the page reader and the page writer; both are shared through `Arc`.
pub struct SqliteExecutor {
    /// Read side: used in this file for schema lookup, table scans and cache invalidation.
    reader: Arc<NativeReader>,
    /// Write side: used in this file for page allocation/initialisation and for
    /// inserting and deleting rows.
    writer: Arc<NativeWriter>,
}
46
47impl SqliteExecutor {
48    pub fn new(reader: Arc<NativeReader>, writer: Arc<NativeWriter>) -> Self {
49        Self { reader, writer }
50    }
51
52    /// 执行 SQL 查询
53    pub async fn execute_query(&self, sql: &str) -> Result<Vec<DsValue>> {
54        self.writer.ensure_initialized().await?;
55        let stmt = SqlParser::parse_sql(sql).map_err(|e| DsError::query(e.to_string()))?;
56        match stmt {
57            SqlStatement::Query(graph) => self.execute_graph(graph).await,
58            SqlStatement::Insert { table, key, value } => {
59                let schemas = self.reader.get_schemas().await?;
60                let schema = schemas
61                    .iter()
62                    .find(|s| s.name == table)
63                    .ok_or_else(|| DsError::query(format!("Table not found: {}", table)))?;
64
65                let row_id = match key {
66                    DsValue::Int(i) => i,
67                    _ => {
68                        return Err(DsError::query(
69                            "Only integer primary keys (rowid) are supported for now",
70                        ));
71                    }
72                };
73
74                let values = match value {
75                    DsValue::List(l) => l,
76                    _ => vec![value],
77                };
78
79                self.writer
80                    .insert_into_leaf_page(schema.rootpage, row_id, &values, vec![])
81                    .await?;
82                self.reader.clear_cache().await;
83                Ok(vec![DsValue::Text(format!(
84                    "INSERT INTO {} SUCCESS",
85                    table
86                ))])
87            }
88            SqlStatement::CreateTable { name, schema } => {
89                // 1. 分配新页面
90                let new_page_id = self.writer.allocate_page().await?;
91
92                // 2. 初始化为叶子表页面
93                self.writer.init_leaf_page(new_page_id).await?;
94
95                // 3. 写入 sqlite_master (Page 1)
96                // SQLite master schema: (type, name, tbl_name, rootpage, sql)
97                let sql_text = if let Some(s) = schema {
98                    let mut cols = Vec::new();
99                    for field in s.fields.values() {
100                        cols.push(format!("{} {:?}", field.name, field.field_type));
101                    }
102                    format!("CREATE TABLE {} ({})", name, cols.join(", "))
103                } else {
104                    format!("CREATE TABLE {} (dummy)", name)
105                };
106
107                let master_values = vec![
108                    DsValue::Text("table".to_string()),
109                    DsValue::Text(name.clone()),
110                    DsValue::Text(name.clone()),
111                    DsValue::Int(new_page_id as i64),
112                    DsValue::Text(sql_text),
113                ];
114
115                // 简单的 row_id 生成逻辑 (实际应查找最大 row_id)
116                let schemas = self.reader.get_schemas().await?;
117                let next_row_id = schemas.len() as i64 + 1;
118
119                self.writer
120                    .insert_into_leaf_page(1, next_row_id, &master_values, vec![])
121                    .await?;
122
123                self.reader.clear_cache().await;
124
125                Ok(vec![DsValue::Text(format!(
126                    "CREATE TABLE {} SUCCESS AT PAGE {}",
127                    name, new_page_id
128                ))])
129            }
130            SqlStatement::Delete { table, selection } => {
131                let schemas = self.reader.get_schemas().await?;
132                let schema = schemas
133                    .iter()
134                    .find(|s| s.name == table)
135                    .ok_or_else(|| DsError::query(format!("Table not found: {}", table)))?;
136
137                let rows = self.reader.scan_table(schema.rootpage).await?;
138                let mut deleted_count = 0;
139
140                for row in rows {
141                    let mut fields = BTreeMap::new();
142                    let columns = schema.get_columns();
143                    for (i, val) in row.values.iter().enumerate() {
144                        let col_name = columns
145                            .get(i)
146                            .cloned()
147                            .unwrap_or_else(|| format!("col_{}", i));
148                        fields.insert(col_name, val.clone());
149                    }
150                    let dict_val = DsValue::Dict(fields);
151
152                    let should_delete = if let Some(expr) = &selection {
153                        match expr.evaluate(&dict_val).map_err(|e| DsError::query(e))? {
154                            DsValue::Bool(b) => b,
155                            _ => false,
156                        }
157                    } else {
158                        true
159                    };
160
161                    if should_delete {
162                        self.writer
163                            .delete_from_leaf_page(schema.rootpage, row.row_id)
164                            .await?;
165                        deleted_count += 1;
166                    }
167                }
168                self.reader.clear_cache().await;
169                Ok(vec![DsValue::Text(format!(
170                    "DELETE {} ROWS FROM {} SUCCESS",
171                    deleted_count, table
172                ))])
173            }
174            SqlStatement::Update {
175                table,
176                assignments,
177                selection,
178            } => {
179                let schemas = self.reader.get_schemas().await?;
180                let schema = schemas
181                    .iter()
182                    .find(|s| s.name == table)
183                    .ok_or_else(|| DsError::query(format!("Table not found: {}", table)))?;
184
185                let rows = self.reader.scan_table(schema.rootpage).await?;
186                let mut updated_count = 0;
187
188                for row in rows {
189                    let mut fields = BTreeMap::new();
190                    let columns = schema.get_columns();
191                    for (i, val) in row.values.iter().enumerate() {
192                        let col_name = columns
193                            .get(i)
194                            .cloned()
195                            .unwrap_or_else(|| format!("col_{}", i));
196                        fields.insert(col_name, val.clone());
197                    }
198                    let dict_val = DsValue::Dict(fields.clone());
199
200                    let should_update = if let Some(expr) = &selection {
201                        match expr.evaluate(&dict_val).map_err(|e| DsError::query(e))? {
202                            DsValue::Bool(b) => b,
203                            _ => false,
204                        }
205                    } else {
206                        true
207                    };
208
209                    if should_update {
210                        let mut new_values = row.values.clone();
211                        for (col_name, expr) in &assignments {
212                            if let Some(idx) = columns.iter().position(|c| c == col_name) {
213                                new_values[idx] =
214                                    expr.evaluate(&dict_val).map_err(|e| DsError::query(e))?;
215                            }
216                        }
217
218                        // re-insert with same row_id (will overwrite if implementation supports it)
219                        self.writer
220                            .insert_into_leaf_page(schema.rootpage, row.row_id, &new_values, vec![])
221                            .await?;
222                        updated_count += 1;
223                    }
224                }
225                self.reader.clear_cache().await;
226                Ok(vec![DsValue::Text(format!(
227                    "UPDATE {} ROWS IN {} SUCCESS",
228                    updated_count, table
229                ))])
230            }
231            _ => Err(DsError::query("Unsupported SQL statement")),
232        }
233    }
234
    /// Executes an insert (placeholder).
    ///
    /// Currently a no-op: all three arguments are ignored, nothing is written, and
    /// `Ok(())` is returned unconditionally. `INSERT` statements are handled by
    /// `execute_query` instead.
    pub async fn execute_insert(
        &self,
        _table: &str,
        _columns: Vec<String>,
        _values: Vec<Vec<DsValue>>,
    ) -> Result<()> {
        // ... (implementation pending adaptation to new SqlStatement)
        Ok(())
    }
245
246    /// 执行算子图
247    pub async fn execute_graph(&self, graph: OpsGraph) -> Result<Vec<DsValue>> {
248        let mut results = Vec::new();
249
250        for node in &graph.nodes {
251            match node {
252                OperatorNode::Source(op) => {
253                    let rows = self.reader.scan_table_by_name(&op.table).await?;
254                    for row in rows {
255                        let mut fields = BTreeMap::new();
256                        let schemas = self.reader.get_schemas().await?;
257                        let schema = schemas.iter().find(|s| s.name == op.table);
258                        let columns = if let Some(s) = schema {
259                            s.get_columns()
260                        } else {
261                            op.columns.clone()
262                        };
263
264                        for (i, val) in row.values.into_iter().enumerate() {
265                            let col_name = columns
266                                .get(i)
267                                .cloned()
268                                .unwrap_or_else(|| format!("col_{}", i));
269                            fields.insert(col_name, val);
270                        }
271                        results.push(DsValue::Dict(fields));
272                    }
273                }
274                OperatorNode::Filter(op) => {
275                    let mut filtered_results = Vec::new();
276                    for val in results {
277                        match op
278                            .expression
279                            .evaluate(&val)
280                            .map_err(|e| DsError::query(e))?
281                        {
282                            DsValue::Bool(b) if b => filtered_results.push(val),
283                            _ => {}
284                        }
285                    }
286                    results = filtered_results;
287                }
288                OperatorNode::Project(op) => {
289                    let mut projected_results = Vec::new();
290                    for val in results {
291                        let mut new_fields = BTreeMap::new();
292                        for (alias, expr) in &op.projections {
293                            new_fields.insert(
294                                alias.clone(),
295                                expr.evaluate(&val).map_err(|e| DsError::query(e))?,
296                            );
297                        }
298                        projected_results.push(DsValue::Dict(new_fields));
299                    }
300                    results = projected_results;
301                }
302                OperatorNode::Aggregate(op) => {
303                    let mut groups: HashMap<Vec<HashableYYValue>, Vec<DsValue>> = HashMap::new();
304
305                    for val in results {
306                        let mut group_key = Vec::new();
307                        for expr in &op.group_by {
308                            group_key.push(HashableYYValue(
309                                expr.evaluate(&val).map_err(|e| DsError::query(e))?,
310                            ));
311                        }
312                        groups.entry(group_key).or_default().push(val);
313                    }
314
315                    let mut agg_results = Vec::new();
316                    for (key, group_rows) in groups {
317                        let mut agg_row = BTreeMap::new();
318                        // 放入 group by 的列
319                        for (i, expr) in op.group_by.iter().enumerate() {
320                            if let Expr::Column(name) = expr {
321                                agg_row.insert(name.clone(), key[i].0.clone());
322                            }
323                        }
324
325                        // 执行聚合
326                        for agg_expr in &op.aggs {
327                            match agg_expr {
328                                AggExpr::Count(_) => {
329                                    agg_row.insert(
330                                        "count".to_string(),
331                                        DsValue::Int(group_rows.len() as i64),
332                                    );
333                                }
334                                AggExpr::Sum(expr) => {
335                                    let sum: i64 = group_rows
336                                        .iter()
337                                        .map(|row| {
338                                            if let Ok(DsValue::Int(i)) = expr.evaluate(row) {
339                                                i
340                                            } else {
341                                                0
342                                            }
343                                        })
344                                        .sum();
345                                    agg_row.insert("sum".to_string(), DsValue::Int(sum));
346                                }
347                                _ => {} // 其他聚合函数暂未实现
348                            }
349                        }
350                        agg_results.push(DsValue::Dict(agg_row));
351                    }
352                    results = agg_results;
353                }
354                OperatorNode::VectorSearch(op) => {
355                    // 简单的暴力余弦相似度搜索
356                    let mut scored_results: Vec<(f32, DsValue)> = results
357                        .into_iter()
358                        .map(|val| {
359                            let score = if let Ok(DsValue::Vector(v)) =
360                                Expr::Column(op.column.clone()).evaluate(&val)
361                            {
362                                self.cosine_similarity(&op.vector, &v.0)
363                            } else {
364                                0.0
365                            };
366                            (score, val)
367                        })
368                        .collect();
369
370                    scored_results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
371                    results = scored_results
372                        .into_iter()
373                        .take(op.limit)
374                        .map(|(_, val)| val)
375                        .collect();
376                }
377                _ => {
378                    return Err(DsError::query(format!("Unsupported operator: {:?}", node)));
379                }
380            }
381        }
382
383        Ok(results)
384    }
385
386    fn cosine_similarity(&self, v1: &[f32], v2: &[f32]) -> f32 {
387        if v1.len() != v2.len() || v1.is_empty() {
388            return 0.0;
389        }
390        let dot_product: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
391        let norm_v1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
392        let norm_v2: f32 = v2.iter().map(|a| a * a).sum::<f32>().sqrt();
393        if norm_v1 == 0.0 || norm_v2 == 0.0 {
394            return 0.0;
395        }
396        dot_product / (norm_v1 * norm_v2)
397    }
398}