1use teaql_core::{DataType, Value};
2
/// SQL dialect that a compiled query targets.
///
/// `CompiledQuery::debug_sql` uses it to choose the placeholder syntax it expands
/// (`$n` for PostgreSQL, `?` for SQLite and MySQL) and how parameter literals are spelled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL: numbered `$1, $2, ...` placeholders.
    PostgreSql,
    /// SQLite: positional `?` placeholders.
    Sqlite,
    /// MySQL: positional `?` placeholders.
    MySql,
}
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct CompiledQuery {
12 pub sql: String,
13 pub params: Vec<Value>,
14 pub comment: Option<String>,
15}
16
17impl CompiledQuery {
18 pub fn sql_with_comment(&self) -> String {
19 match &self.comment {
20 Some(comment) if !comment.is_empty() => {
21 let escaped = comment.replace("*/", "* /");
22 format!("/* {escaped} */ {}", self.sql)
23 }
24 _ => self.sql.clone(),
25 }
26 }
27
28 pub fn debug_sql(&self, kind: DatabaseKind) -> String {
29 match kind {
30 DatabaseKind::PostgreSql => replace_postgres_placeholders(&self.sql, &self.params),
31 DatabaseKind::Sqlite => replace_positional_placeholders(&self.sql, &self.params, DatabaseKind::Sqlite),
32 DatabaseKind::MySql => replace_positional_placeholders(&self.sql, &self.params, DatabaseKind::MySql),
33 }
34 }
35}
36
37fn replace_postgres_placeholders(sql: &str, params: &[Value]) -> String {
38 let mut output = String::with_capacity(sql.len());
39 let mut chars = sql.chars().peekable();
40 let mut in_string = false;
41 while let Some(ch) = chars.next() {
42 if ch == '\'' {
43 output.push(ch);
44 if in_string && matches!(chars.peek(), Some('\'')) {
45 output.push(chars.next().expect("peeked quote must exist"));
46 } else {
47 in_string = !in_string;
48 }
49 continue;
50 }
51 if !in_string && ch == '$' && chars.peek().is_some_and(|next| next.is_ascii_digit()) {
52 let mut index = String::new();
53 while let Some(next) = chars.peek().copied().filter(char::is_ascii_digit) {
54 index.push(next);
55 chars.next();
56 }
57 if let Ok(index) = index.parse::<usize>() {
58 if let Some(value) = index.checked_sub(1).and_then(|idx| params.get(idx)) {
59 output.push_str(&sql_literal(value, DatabaseKind::PostgreSql));
60 continue;
61 }
62 }
63 output.push('$');
64 output.push_str(&index);
65 continue;
66 }
67 output.push(ch);
68 }
69 output
70}
71
72fn replace_positional_placeholders(sql: &str, params: &[Value], kind: DatabaseKind) -> String {
73 let mut output = String::with_capacity(sql.len());
74 let mut params = params.iter();
75 let mut in_string = false;
76 let mut chars = sql.chars().peekable();
77 while let Some(ch) = chars.next() {
78 if ch == '\'' {
79 output.push(ch);
80 if in_string && matches!(chars.peek(), Some('\'')) {
81 output.push(chars.next().expect("peeked quote must exist"));
82 } else {
83 in_string = !in_string;
84 }
85 continue;
86 }
87 if !in_string && ch == '?' {
88 if let Some(value) = params.next() {
89 output.push_str(&sql_literal(value, kind));
90 } else {
91 output.push(ch);
92 }
93 continue;
94 }
95 output.push(ch);
96 }
97 output
98}
99
100fn sql_literal(value: &Value, kind: DatabaseKind) -> String {
101 match value {
102 Value::Null => "NULL".to_owned(),
103 Value::Bool(value) => if *value { "TRUE" } else { "FALSE" }.to_owned(),
104 Value::I64(value) => value.to_string(),
105 Value::U64(value) => value.to_string(),
106 Value::F64(value) => value.to_string(),
107 Value::Decimal(value) => value.to_string(),
108 Value::Text(value) => quoted_sql_string(value),
109 Value::Json(value) => quoted_sql_string(&value.to_string()),
110 Value::Date(value) => quoted_sql_string(&value.to_string()),
111 Value::Timestamp(value) => quoted_sql_string(&value.to_rfc3339()),
112 Value::Object(value) => {
113 quoted_sql_string(&Value::Object(value.clone()).to_json_value().to_string())
114 }
115 Value::List(values) => {
116 let values = values
117 .iter()
118 .map(|v| sql_literal(v, kind))
119 .collect::<Vec<_>>()
120 .join(", ");
121 match kind {
122 DatabaseKind::PostgreSql => format!("ARRAY[{values}]"),
123 _ => format!("({values})"),
124 }
125 }
126 }
127}
128
/// Wraps `value` in single quotes, doubling every embedded `'` as standard SQL requires.
fn quoted_sql_string(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        // A doubled quote is how standard SQL escapes a quote inside a string literal.
        if ch == '\'' {
            quoted.push('\'');
        }
        quoted.push(ch);
    }
    quoted.push('\'');
    quoted
}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub enum SqlCompileError {
135 UnknownEntity(String),
136 UnknownField(String),
137 EmptyInList,
138 MissingIdProperty(String),
139 MissingVersionProperty(String),
140 EmptyMutation(String),
141 InvalidRecoverVersion(i64),
142 UnsupportedSchemaType(DataType),
143 InvalidFunctionArguments(String),
144 InvalidSubQueryOperator(String),
145}
146
147impl std::fmt::Display for SqlCompileError {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::UnknownEntity(entity) => write!(f, "unknown entity: {entity}"),
151 Self::UnknownField(field) => write!(f, "unknown field: {field}"),
152 Self::EmptyInList => write!(f, "IN requires at least one value"),
153 Self::MissingIdProperty(entity) => write!(f, "entity {entity} has no id property"),
154 Self::MissingVersionProperty(entity) => {
155 write!(f, "entity {entity} has no version property")
156 }
157 Self::EmptyMutation(kind) => write!(f, "{kind} requires at least one writable field"),
158 Self::InvalidRecoverVersion(version) => {
159 write!(f, "recover requires a negative version, got {version}")
160 }
161 Self::UnsupportedSchemaType(data_type) => {
162 write!(f, "unsupported schema type: {data_type:?}")
163 }
164 Self::InvalidFunctionArguments(message) => write!(f, "{message}"),
165 Self::InvalidSubQueryOperator(operator) => {
166 write!(f, "subquery does not support operator: {operator}")
167 }
168 }
169 }
170}
171
172impl std::error::Error for SqlCompileError {}