// teo_sql_connector/execution/mod.rs

1use std::borrow::Cow;
2use array_tool::vec::Uniq;
3use std::backtrace::Backtrace;
4use std::collections::HashMap;
5use async_recursion::async_recursion;
6use indexmap::IndexMap;
7use key_path::KeyPath;
8use quaint_forked::prelude::{Queryable, ResultRow};
9use quaint_forked::ast::{Query as QuaintQuery};
10use teo_parser::r#type::Type;
11use crate::query::Query;
12use crate::schema::dialect::SQLDialect;
13use crate::schema::value::decode::RowDecoder;
14use crate::schema::value::encode::{SQLEscape, ToSQLString, ToWrapped};
15use teo_runtime::action::Action;
16use teo_runtime::connection::transaction;
17use teo_runtime::model::field::column_named::ColumnNamed;
18use teo_runtime::model::field::is_optional::IsOptional;
19use teo_runtime::model::field::typed::Typed;
20use teo_runtime::model::object::input::Input;
21use teo_runtime::model::Model;
22use teo_runtime::model::Object;
23use teo_runtime::namespace::Namespace;
24use teo_runtime::error_ext;
25use teo_runtime::request;
26use teo_runtime::request::Request;
27use teo_runtime::traits::named::Named;
28use teo_runtime::value::Value;
29use teo_runtime::teon;
30
/// Stateless holder for the SQL query-execution helpers; it has no fields and every
/// operation lives on it as an associated function.
pub(crate) struct Execution;
32
33impl Execution {
34
35    pub(crate) fn row_to_value(namespace: &Namespace, model: &Model, row: &ResultRow, columns: &Vec<String>, dialect: SQLDialect) -> Value {
36        Value::Dictionary(columns.iter().filter_map(|column_name| {
37            if let Some(field) = model.field_with_column_name(column_name) {
38                if field.auto_increment() && dialect == SQLDialect::PostgreSQL {
39                    Some((field.name().to_owned(), RowDecoder::decode_serial(field.is_optional(), row, column_name)))
40                } else {
41                    Some((field.name().to_owned(), RowDecoder::decode(field.r#type(), field.is_optional(), row, column_name, dialect)))
42                }
43            } else if let Some(property) = model.property_with_column_name(column_name) {
44                Some((property.column_name().to_owned(), RowDecoder::decode(property.r#type(), property.is_optional(), row, column_name, dialect)))
45            } else if column_name.contains(".") {
46                let names: Vec<&str> = column_name.split(".").collect();
47                let relation_name = names[0];
48                let field_name = names[1];
49                if relation_name == "c" { // cursor fetch, should remove
50                    None
51                } else {
52                    let relation = model.relation(relation_name).unwrap();
53                    let opposite_model = namespace.model_at_path(&relation.model_path()).unwrap();
54                    let field = opposite_model.field(field_name).unwrap();
55                    Some((column_name.to_owned(), RowDecoder::decode(field.r#type(), field.is_optional(), row, column_name, dialect)))
56                }
57            } else {
58                panic!("Unhandled key {}.", column_name);
59            }
60        }).collect())
61    }
62
63    fn row_to_aggregate_value(model: &Model, row: &ResultRow, columns: &Vec<String>, dialect: SQLDialect) -> Value {
64        let mut retval: IndexMap<String, Value> = IndexMap::new();
65        for column in columns {
66            let result_key = column.as_str();
67            if result_key.contains(".") {
68                let splitted = result_key.split(".").collect::<Vec<&str>>();
69                let group = *splitted.get(0).unwrap();
70                let field_name = *splitted.get(1).unwrap();
71                if !retval.contains_key(group) {
72                    retval.insert(group.to_string(), Value::Dictionary(IndexMap::new()));
73                }
74                if group == "_count" { // force i64
75                    let count: i64 = row.get(result_key).unwrap().as_i64().unwrap();
76                    retval.get_mut(group).unwrap().as_dictionary_mut().unwrap().insert(field_name.to_string(), teon!(count));
77                } else if group == "_avg" || group == "_sum" { // force f64
78                    let v = RowDecoder::decode(&Type::Float, true, &row, result_key, dialect);
79                    retval.get_mut(group).unwrap().as_dictionary_mut().unwrap().insert(field_name.to_string(), v);
80                } else { // field type
81                    let field = model.field(field_name).unwrap();
82                    let v = RowDecoder::decode(field.r#type(), true, &row, result_key, dialect);
83                    retval.get_mut(group).unwrap().as_dictionary_mut().unwrap().insert(field_name.to_string(), v);
84                }
85            } else if let Some(field) = model.field_with_column_name(result_key) {
86                retval.insert(field.name().to_owned(), RowDecoder::decode(field.r#type(), field.is_optional(), row, result_key, dialect));
87            } else if let Some(property) = model.property(result_key) {
88                retval.insert(property.name().to_owned(), RowDecoder::decode(property.r#type(), property.is_optional(), row, result_key, dialect));
89            }
90        }
91        Value::Dictionary(retval)
92    }
93
94    pub(crate) async fn query_objects<'a>(namespace: &Namespace, conn: &'a dyn Queryable, model: &Model, finder: &'a Value, dialect: SQLDialect, action: Action, transaction_ctx: transaction::Ctx, request: Option<Request>, path: KeyPath) -> teo_result::Result<Vec<Object>> {
95        let values = Self::query(namespace, conn, model, finder, dialect, path).await?;
96        let select = finder.as_dictionary().unwrap().get("select");
97        let include = finder.as_dictionary().unwrap().get("include");
98        let mut results = vec![];
99        for value in values {
100            let object = transaction_ctx.new_object(model, action, request.clone())?;
101            object.set_from_database_result_value(&value, select, include);
102            results.push(object);
103        }
104        Ok(results)
105    }
106
107    #[async_recursion]
108    async fn query_internal(namespace: &Namespace, conn: &dyn Queryable, model: &Model, value: &Value, dialect: SQLDialect, additional_where: Option<String>, additional_left_join: Option<String>, join_table_results: Option<Vec<String>>, force_negative_take: bool, additional_distinct: Option<Vec<String>>, path: KeyPath) -> teo_result::Result<Vec<Value>> {
109        let _select = value.get("select");
110        let include = value.get("include");
111        let original_distinct = value.get("distinct").map(|v| if v.as_array().unwrap().is_empty() { None } else { Some(v.as_array().unwrap()) }).flatten();
112        let distinct = Self::merge_distinct(original_distinct, additional_distinct);
113        let skip = value.get("skip");
114        let take = value.get("take");
115        let should_in_memory_take_skip = distinct.is_some() && (skip.is_some() || take.is_some());
116        let value_for_build = if should_in_memory_take_skip {
117            Self::without_paging_and_skip_take(value)
118        } else {
119            Cow::Borrowed(value)
120        };
121        let stmt = Query::build(namespace, model, value_for_build.as_ref(), dialect, additional_where, additional_left_join, join_table_results, force_negative_take)?;
122        // println!("see sql query stmt: {}", &stmt);
123        let reverse = Input::has_negative_take(value);
124        let rows = match conn.query(QuaintQuery::from(stmt)).await {
125            Ok(rows) => rows,
126            Err(err) => {
127                return Err(error_ext::unknown_database_find_error(path.clone(), format!("{:?}", err)));
128            }
129        };
130        if rows.is_empty() {
131            return Ok(vec![])
132        }
133        let columns = rows.columns().clone();
134        let mut results = rows.into_iter().map(|row| Self::row_to_value(namespace, model, &row, &columns, dialect)).collect::<Vec<Value>>();
135        if reverse {
136            results.reverse();
137        }
138        if let Some(distinct) = distinct {
139            let distinct_keys = distinct.iter().map(|s| s.as_str()).collect::<Vec<&str>>();
140            results = results.unique_via(|a, b| {
141                Self::sub_hashmap(a, &distinct_keys) == Self::sub_hashmap(b, &distinct_keys)
142            });
143        }
144        if should_in_memory_take_skip {
145            let skip = skip.map(|s| s.as_int64().unwrap()).unwrap_or(0) as usize;
146            let take = take.map(|s| s.as_int64().unwrap().abs() as u64).unwrap_or(0) as usize;
147            results = results.into_iter().enumerate().filter(|(i, _r)| {
148                *i >= skip && *i < (skip + take)
149            }).map(|(_i, r)| r.clone()).collect();
150            if reverse {
151                results.reverse();
152            }
153        }
154        if let Some(include) = include.map(|i| i.as_dictionary().unwrap()) {
155            for (key, value) in include {
156                let skip = value.as_dictionary().map(|m| m.get("skip")).flatten().map(|v| v.as_int64().unwrap());
157                let take = value.as_dictionary().map(|m| m.get("take")).flatten().map(|v| v.as_int64().unwrap());
158                let take_abs = take.map(|t| t.abs() as u64);
159                let negative_take = take.map(|v| v.is_negative()).unwrap_or(false);
160                let inner_distinct = value.as_dictionary().map(|m| m.get("distinct")).flatten().map(|v| if v.as_array().unwrap().is_empty() { None } else { Some(v.as_array().unwrap()) }).flatten();
161                let relation = model.relation(key).unwrap();
162                let (opposite_model, _) = namespace.opposite_relation(relation);
163                if !relation.has_join_table() {
164                    let fields = relation.fields();
165                    let opposite_fields = relation.references();
166                    let names = if opposite_fields.len() == 1 {
167                        opposite_model.field(opposite_fields.get(0).unwrap()).unwrap().column_name().escape(dialect)
168                    } else {
169                        opposite_fields.iter().map(|f| opposite_model.field(f).unwrap().column_name().escape(dialect)).collect::<Vec<String>>().join(",").to_wrapped()
170                    };
171                    let values = if opposite_fields.len() == 1 {
172                        // in a (?,?,?,?,?) format
173                        let field_name = fields.get(0).unwrap();
174                        results.iter().map(|v| {
175                            ToSQLString::to_string(&v.as_dictionary().unwrap().get(field_name).unwrap(), dialect)
176                        }).collect::<Vec<String>>().join(",").to_wrapped()
177                    } else {
178                        // in a (VALUES (?,?),(?,?)) format
179                        format!("(VALUES {})", results.iter().map(|o| {
180                            fields.iter().map(|f| ToSQLString::to_string(&o.as_dictionary().unwrap().get(f).unwrap(), dialect)).collect::<Vec<String>>().join(",").to_wrapped()
181                        }).collect::<Vec<String>>().join(","))
182                    };
183                    let where_addition = Query::where_item(&names, "IN", &values);
184                    let nested_query = if value.is_dictionary() {
185                        Self::without_paging_and_skip_take_distinct(value)
186                    } else {
187                        Cow::Owned(teon!({}))
188                    };
189                    let included_values = Self::query_internal(namespace, conn, opposite_model, &nested_query, dialect, Some(where_addition), None, None, negative_take, None, path.clone()).await?;
190                    // println!("see included: {:?}", included_values);
191                    for result in results.iter_mut() {
192                        let mut skipped = 0;
193                        let mut taken = 0;
194                        if relation.is_vec() {
195                            result.as_dictionary_mut().unwrap().insert(relation.name().to_owned(), Value::Array(vec![]));
196                        }
197                        for included_value in included_values.iter() {
198                            let mut matched = true;
199                            for (field, reference) in relation.iter() {
200                                if included_value.get(reference).is_none() && result.get(field).is_none() {
201                                    matched = false;
202                                    break;
203                                }
204                                if included_value.get(reference) != result.get(field) {
205                                    matched = false;
206                                    break;
207                                }
208                            }
209                            if matched {
210                                if (skip.is_none() || skip.unwrap() <= skipped) && (take.is_none() || taken < take_abs.unwrap()) {
211                                    if result.get(relation.name()).is_none() {
212                                        result.as_dictionary_mut().unwrap().insert(relation.name().to_owned(), Value::Array(vec![]));
213                                    }
214                                    if negative_take {
215                                        result.as_dictionary_mut().unwrap().get_mut(relation.name()).unwrap().as_array_mut().unwrap().insert(0, included_value.clone());
216                                    } else {
217                                        result.as_dictionary_mut().unwrap().get_mut(relation.name()).unwrap().as_array_mut().unwrap().push(included_value.clone());
218                                    }
219                                    taken += 1;
220                                    if take.is_some() && (taken >= take_abs.unwrap()) {
221                                        break;
222                                    }
223                                } else {
224                                    skipped += 1;
225                                }
226                            }
227                        }
228                    }
229                } else {
230                    let (opposite_model, opposite_relation) = namespace.opposite_relation(relation);
231                    let (through_model, through_opposite_relation) = namespace.through_opposite_relation(relation);
232                    let mut join_parts: Vec<String> = vec![];
233                    for (field, reference) in through_opposite_relation.iter() {
234                        let field_column_name = through_model.field(field).unwrap().column_name();
235                        let reference_column_name = opposite_model.field(reference).unwrap().column_name();
236                        join_parts.push(format!("t.{} = j.{}", reference_column_name.escape(dialect), field_column_name.escape(dialect)));
237                    }
238                    let joins = join_parts.join(" AND ");
239                    let left_join = format!("{} AS j ON {}", &through_model.table_name().escape(dialect), joins);
240                    let (through_table, through_relation) = namespace.through_relation(relation);
241                    let names = if through_relation.len() == 1 { // todo: column name
242                        format!("j.{}", through_table.field(through_relation.fields().get(0).unwrap()).unwrap().column_name().escape(dialect))
243                    } else {
244                        through_relation.fields().iter().map(|f| format!("j.{}", through_table.field(f).unwrap().column_name().escape(dialect))).collect::<Vec<String>>().join(",").to_wrapped()
245                    };
246                    let values = if through_relation.len() == 1 { // (?,?,?,?,?) format
247                        let references = through_relation.references();
248                        let field_name = references.get(0).unwrap();
249                        results.iter().map(|v| {
250                            ToSQLString::to_string(&v.as_dictionary().unwrap().get(field_name).unwrap(), dialect)
251                        }).collect::<Vec<String>>().join(",").to_wrapped()
252                    } else { // (VALUES (?,?),(?,?)) format
253                        let pairs = results.iter().map(|o| {
254                            through_relation.references().iter().map(|f| ToSQLString::to_string(&o.as_dictionary().unwrap().get(f).unwrap(), dialect)).collect::<Vec<String>>().join(",").to_wrapped()
255                        }).collect::<Vec<String>>().join(",");
256                        format!("(VALUES {})", pairs)
257                    };
258                    let where_addition = Query::where_item(&names, "IN", &values);
259                    let nested_query = if value.is_dictionary() {
260                        Self::without_paging_and_skip_take(value)
261                    } else {
262                        Cow::Owned(teon!({}))
263                    };
264                    let join_table_results = through_relation.iter().map(|(f, r)| {
265                        let through_column_name = through_model.field(f).unwrap().column_name().to_string();
266                        if dialect == SQLDialect::PostgreSQL {
267                            format!("j.{} AS \"{}.{}\"", through_column_name.as_str().escape(dialect), opposite_relation.unwrap().name(), r)
268                        } else {
269                            format!("j.{} AS `{}.{}`", through_column_name, opposite_relation.unwrap().name(), r)
270                        }
271                    }).collect();
272                    let additional_inner_distinct = if inner_distinct.is_some() {
273                        Some(through_relation.iter().map(|(_f, r)| {
274                            format!("{}.{}", opposite_relation.unwrap().name(), r)
275                        }).collect())
276                    } else {
277                        None
278                    };
279                    let included_values = Self::query_internal(namespace, conn, opposite_model, &nested_query, dialect, Some(where_addition), Some(left_join), Some(join_table_results), negative_take, additional_inner_distinct, path.clone()).await?;
280                    // println!("see included {:?}", included_values);
281                    for result in results.iter_mut() {
282                        result.as_dictionary_mut().unwrap().insert(relation.name().to_owned(), Value::Array(vec![]));
283                        let mut skipped = 0;
284                        let mut taken = 0;
285                        for included_value in included_values.iter() {
286                            let mut matched = true;
287                            for (_field, reference) in through_relation.iter() {
288                                let key = format!("{}.{}", opposite_relation.unwrap().name(), reference);
289                                if result.get(reference).is_none() && included_value.get(&key).is_none() {
290                                    matched = false;
291                                    break;
292                                }
293                                if result.get(reference) != included_value.get(&key) {
294                                    matched = false;
295                                    break;
296                                }
297                            }
298                            if matched {
299                                if (skip.is_none() || skip.unwrap() <= skipped) && (take.is_none() || taken < take_abs.unwrap()) {
300                                    if negative_take {
301                                        result.as_dictionary_mut().unwrap().get_mut(relation.name()).unwrap().as_array_mut().unwrap().insert(0, included_value.clone());
302                                    } else {
303                                        result.as_dictionary_mut().unwrap().get_mut(relation.name()).unwrap().as_array_mut().unwrap().push(included_value.clone());
304                                    }
305                                    taken += 1;
306                                    if take.is_some() && (taken >= take_abs.unwrap()) {
307                                        break;
308                                    }
309                                } else {
310                                    skipped += 1;
311                                }
312                            }
313                        }
314                    }
315                }
316            }
317        }
318        Ok(results)
319    }
320
321    pub(crate) async fn query(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<Vec<Value>> {
322       Self::query_internal(namespace, conn, model, finder, dialect, None, None, None, false, None, path).await
323    }
324
325    pub(crate) async fn query_aggregate(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<Value> {
326        let stmt = Query::build_for_aggregate(namespace, model, finder, dialect)?;
327        match conn.query(QuaintQuery::from(&*stmt)).await {
328            Ok(result_set) => {
329                let columns = result_set.columns().clone();
330                let result = result_set.into_iter().next().unwrap();
331                Ok(Self::row_to_aggregate_value(model, &result, &columns, dialect))
332            },
333            Err(err) => {
334                return Err(error_ext::unknown_database_find_error(path, format!("{:?}", err)));
335            }
336        }
337    }
338
339    pub(crate) async fn query_group_by(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<Vec<Value>> {
340        let stmt = Query::build_for_group_by(namespace, model, finder, dialect)?;
341        let rows = match conn.query(QuaintQuery::from(stmt)).await {
342            Ok(rows) => rows,
343            Err(err) => {
344                return Err(error_ext::unknown_database_find_error(path.clone(), format!("{:?}", err)));
345            }
346        };
347        let columns = rows.columns().clone();
348        Ok(rows.into_iter().map(|r| {
349            Self::row_to_aggregate_value(model, &r, &columns, dialect)
350        }).collect::<Vec<Value>>())
351    }
352
353    pub(crate) async fn query_count(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<Value> {
354        if finder.get("select").is_some() {
355            Self::query_count_fields(namespace, conn, model, finder, dialect, path).await
356        } else {
357            let result = Self::query_count_objects(namespace, conn, model, finder, dialect, path).await?;
358            Ok(Value::Int64(result as i64))
359        }
360    }
361
362    pub(crate) async fn query_count_objects(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<usize> {
363        let stmt = Query::build_for_count(namespace, model, finder, dialect, None, None, None, false)?;
364        match conn.query(QuaintQuery::from(stmt)).await {
365            Ok(result) => {
366                let result = result.into_iter().next().unwrap();
367                let count: i64 = result.into_iter().next().unwrap().as_i64().unwrap();
368                Ok(count as usize)
369            },
370            Err(err) => {
371                return Err(error_ext::unknown_database_find_error(path.clone(), format!("{:?}", err)));
372            }
373        }
374    }
375
376    pub(crate) async fn query_count_fields(namespace: &Namespace, conn: &dyn Queryable, model: &Model, finder: &Value, dialect: SQLDialect, path: KeyPath) -> teo_result::Result<Value> {
377        let new_finder = Value::Dictionary(finder.as_dictionary().unwrap().iter().map(|(k, v)| {
378            if k.as_str() == "select" {
379                ("_count".to_owned(), v.clone())
380            } else {
381                (k.to_owned(), v.clone())
382            }
383        }).collect());
384        let aggregate_value = Self::query_aggregate(namespace, conn, model, &new_finder, dialect, path).await?;
385        Ok(aggregate_value.get("_count").unwrap().clone())
386    }
387
388    fn without_paging_and_skip_take(value: &Value) -> Cow<Value> {
389        let map = value.as_dictionary().unwrap();
390        if map.contains_key("take") || map.contains_key("skip") || map.contains_key("pageSize") || map.contains_key("pageNumber") {
391            let mut map = map.clone();
392            map.remove("take");
393            map.remove("skip");
394            map.remove("pageSize");
395            map.remove("pageNumber");
396            Cow::Owned(Value::Dictionary(map))
397        } else {
398            Cow::Borrowed(value)
399        }
400    }
401
402    fn without_paging_and_skip_take_distinct(value: &Value) -> Cow<Value> {
403        let map = value.as_dictionary().unwrap();
404        if map.contains_key("take") || map.contains_key("skip") || map.contains_key("pageSize") || map.contains_key("pageNumber") {
405            let mut map = map.clone();
406            map.remove("take");
407            map.remove("skip");
408            map.remove("pageSize");
409            map.remove("pageNumber");
410            map.remove("distinct");
411            Cow::Owned(Value::Dictionary(map))
412        } else {
413            Cow::Borrowed(value)
414        }
415    }
416
417    fn sub_hashmap(value: &Value, keys: &Vec<&str>) -> HashMap<String, Value> {
418        let map = value.as_dictionary().unwrap();
419        let mut retval = HashMap::new();
420        for key in keys {
421            retval.insert(key.to_string(), map.get(*key).cloned().unwrap_or(Value::Null));
422        }
423        retval
424    }
425
426    fn merge_distinct(value1: Option<&Vec<Value>>, value2: Option<Vec<String>>) -> Option<Vec<String>> {
427        let mut result: Vec<String> = vec![];
428        if let Some(value1) = value1 {
429            for v in value1 {
430                result.push(v.as_str().unwrap().to_string());
431            }
432        }
433        if let Some(value2) = value2 {
434            for v in value2 {
435                result.push(v.to_string())
436            }
437        }
438        if result.is_empty() {
439            None
440        } else {
441            Some(result)
442        }
443    }
444}