Skip to main content

teaql_sql/
types.rs

1use teaql_core::{DataType, Value};
2
/// SQL dialect targeted by a compiled query.
///
/// Selects the placeholder syntax and literal rendering used by
/// `CompiledQuery::debug_sql`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL: numbered `$1`, `$2`, … placeholders.
    PostgreSql,
    /// SQLite: positional `?` placeholders.
    Sqlite,
    /// MySQL: positional `?` placeholders.
    MySql,
}
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct CompiledQuery {
12    pub sql: String,
13    pub params: Vec<Value>,
14    pub comment: Option<String>,
15}
16
17impl CompiledQuery {
18    pub fn sql_with_comment(&self) -> String {
19        match &self.comment {
20            Some(comment) if !comment.is_empty() => {
21                let escaped = comment.replace("*/", "* /");
22                format!("/* {escaped} */ {}", self.sql)
23            }
24            _ => self.sql.clone(),
25        }
26    }
27
28    pub fn debug_sql(&self, kind: DatabaseKind) -> String {
29        match kind {
30            DatabaseKind::PostgreSql => replace_postgres_placeholders(&self.sql, &self.params),
31            DatabaseKind::Sqlite => replace_positional_placeholders(&self.sql, &self.params, DatabaseKind::Sqlite),
32            DatabaseKind::MySql => replace_positional_placeholders(&self.sql, &self.params, DatabaseKind::MySql),
33        }
34    }
35}
36
37fn replace_postgres_placeholders(sql: &str, params: &[Value]) -> String {
38    let mut output = String::with_capacity(sql.len());
39    let mut chars = sql.chars().peekable();
40    let mut in_string = false;
41    while let Some(ch) = chars.next() {
42        if ch == '\'' {
43            output.push(ch);
44            if in_string && matches!(chars.peek(), Some('\'')) {
45                output.push(chars.next().expect("peeked quote must exist"));
46            } else {
47                in_string = !in_string;
48            }
49            continue;
50        }
51        if !in_string && ch == '$' && chars.peek().is_some_and(|next| next.is_ascii_digit()) {
52            let mut index = String::new();
53            while let Some(next) = chars.peek().copied().filter(char::is_ascii_digit) {
54                index.push(next);
55                chars.next();
56            }
57            if let Ok(index) = index.parse::<usize>() {
58                if let Some(value) = index.checked_sub(1).and_then(|idx| params.get(idx)) {
59                    output.push_str(&sql_literal(value, DatabaseKind::PostgreSql));
60                    continue;
61                }
62            }
63            output.push('$');
64            output.push_str(&index);
65            continue;
66        }
67        output.push(ch);
68    }
69    output
70}
71
72fn replace_positional_placeholders(sql: &str, params: &[Value], kind: DatabaseKind) -> String {
73    let mut output = String::with_capacity(sql.len());
74    let mut params = params.iter();
75    let mut in_string = false;
76    let mut chars = sql.chars().peekable();
77    while let Some(ch) = chars.next() {
78        if ch == '\'' {
79            output.push(ch);
80            if in_string && matches!(chars.peek(), Some('\'')) {
81                output.push(chars.next().expect("peeked quote must exist"));
82            } else {
83                in_string = !in_string;
84            }
85            continue;
86        }
87        if !in_string && ch == '?' {
88            if let Some(value) = params.next() {
89                output.push_str(&sql_literal(value, kind));
90            } else {
91                output.push(ch);
92            }
93            continue;
94        }
95        output.push(ch);
96    }
97    output
98}
99
100fn sql_literal(value: &Value, kind: DatabaseKind) -> String {
101    match value {
102        Value::Null => "NULL".to_owned(),
103        Value::Bool(value) => if *value { "TRUE" } else { "FALSE" }.to_owned(),
104        Value::I64(value) => value.to_string(),
105        Value::U64(value) => value.to_string(),
106        Value::F64(value) => value.to_string(),
107        Value::Decimal(value) => value.to_string(),
108        Value::Text(value) => quoted_sql_string(value),
109        Value::Json(value) => quoted_sql_string(&value.to_string()),
110        Value::Date(value) => quoted_sql_string(&value.to_string()),
111        Value::Timestamp(value) => quoted_sql_string(&value.to_rfc3339()),
112        Value::Object(value) => {
113            quoted_sql_string(&Value::Object(value.clone()).to_json_value().to_string())
114        }
115        Value::List(values) => {
116            let values = values
117                .iter()
118                .map(|v| sql_literal(v, kind))
119                .collect::<Vec<_>>()
120                .join(", ");
121            match kind {
122                DatabaseKind::PostgreSql => format!("ARRAY[{values}]"),
123                _ => format!("({values})"),
124            }
125        }
126    }
127}
128
/// Wraps `value` in single quotes, doubling every embedded `'`.
///
/// Doubling is the standard SQL escape for a quote inside a string literal.
fn quoted_sql_string(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            quoted.push('\'');
        }
        quoted.push(ch);
    }
    quoted.push('\'');
    quoted
}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub enum SqlCompileError {
135    UnknownEntity(String),
136    UnknownField(String),
137    EmptyInList,
138    MissingIdProperty(String),
139    MissingVersionProperty(String),
140    EmptyMutation(String),
141    InvalidRecoverVersion(i64),
142    UnsupportedSchemaType(DataType),
143    InvalidFunctionArguments(String),
144    InvalidSubQueryOperator(String),
145}
146
147impl std::fmt::Display for SqlCompileError {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Self::UnknownEntity(entity) => write!(f, "unknown entity: {entity}"),
151            Self::UnknownField(field) => write!(f, "unknown field: {field}"),
152            Self::EmptyInList => write!(f, "IN requires at least one value"),
153            Self::MissingIdProperty(entity) => write!(f, "entity {entity} has no id property"),
154            Self::MissingVersionProperty(entity) => {
155                write!(f, "entity {entity} has no version property")
156            }
157            Self::EmptyMutation(kind) => write!(f, "{kind} requires at least one writable field"),
158            Self::InvalidRecoverVersion(version) => {
159                write!(f, "recover requires a negative version, got {version}")
160            }
161            Self::UnsupportedSchemaType(data_type) => {
162                write!(f, "unsupported schema type: {data_type:?}")
163            }
164            Self::InvalidFunctionArguments(message) => write!(f, "{message}"),
165            Self::InvalidSubQueryOperator(operator) => {
166                write!(f, "subquery does not support operator: {operator}")
167            }
168        }
169    }
170}
171
172impl std::error::Error for SqlCompileError {}