1use trailbase_sqlvalue::{Blob, DecodeError, SqlValue};
2use wstd::http::body::IntoBody;
3use wstd::http::{Client, Request};
4
5use crate::wit::trailbase::database::sqlite::Transaction as WasiTransaction;
6
7pub use crate::wit::trailbase::database::sqlite::{TxError, Value};
8pub use trailbase_wasm_common::{SqliteRequest, SqliteResponse};
9
/// Wraps `s` in single quotes so it can be spliced into SQL text as a string literal.
///
/// Each embedded `'` is doubled (`'` → `''`) and each NUL character is replaced by the
/// two-character sequence `\0`; every other character is copied through unchanged.
///
/// NOTE(review): SQLite does not interpret backslash escapes inside string literals, so a
/// NUL in the input comes back as a literal backslash followed by `0` — confirm this lossy
/// mapping is intended.
pub fn escape(s: impl AsRef<str>) -> String {
    let raw = s.as_ref();
    // Worst case every char doubles, plus the two surrounding quotes.
    let mut quoted = String::with_capacity(2 * raw.len() + 2);

    quoted.push('\'');
    for ch in raw.chars() {
        if ch == '\'' {
            quoted.push_str("''");
        } else if ch == '\0' {
            quoted.push_str("\\0");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');

    return quoted;
}
28
29pub struct Transaction {
30 tx: WasiTransaction,
31 committed: bool,
32}
33
34impl Transaction {
35 pub fn begin() -> Result<Self, TxError> {
36 return Ok(Self {
37 tx: WasiTransaction::new(),
38 committed: false,
39 });
40 }
41
42 pub fn query(&mut self, query: &str, params: &[Value]) -> Result<Vec<Vec<Value>>, TxError> {
43 return self.tx.query(query, params);
44 }
45
46 pub fn execute(&mut self, query: &str, params: &[Value]) -> Result<u64, TxError> {
47 return self.tx.execute(query, params);
48 }
49
50 pub fn commit(&mut self) -> Result<(), TxError> {
51 if !self.committed {
52 self.committed = true;
53 self.tx.commit()?;
54 }
55 return Ok(());
56 }
57}
58
59impl Drop for Transaction {
60 fn drop(&mut self) {
61 if !self.committed
62 && let Err(err) = self.tx.rollback()
63 {
64 log::warn!("TX rollback failed: {err}");
65 }
66 }
67}
68
/// Errors returned by the HTTP-bridged database helpers [`query`] and [`execute`].
///
/// NOTE(review): the boxed payloads are `Box<dyn std::error::Error>` without `Send + Sync`,
/// so this type cannot cross threads; presumably fine for a single-threaded WASM guest —
/// confirm before using it with multi-threaded executors.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The host replied with a response variant other than the one the call expected.
    #[error("Unexpected Type: {0}")]
    UnexpectedType(Box<dyn std::error::Error>),
    /// Produced from [`DecodeError`] and from `serde_json` errors via `From`; note the latter
    /// also covers failures to *serialize* the outgoing request.
    #[error("Decoding: {0}")]
    Decoding(Box<dyn std::error::Error>),
    /// Transport failures, unparsable response bodies, and errors reported by the host.
    #[error("Other: {0}")]
    Other(Box<dyn std::error::Error>),
}
79
80impl From<DecodeError> for Error {
81 fn from(err: DecodeError) -> Self {
82 return Self::Decoding(err.into());
83 }
84}
85
86impl From<serde_json::Error> for Error {
87 fn from(err: serde_json::Error) -> Self {
88 return Self::Decoding(err.into());
89 }
90}
91
92pub async fn query(
93 query: impl std::string::ToString,
94 params: impl Into<Vec<Value>>,
95) -> Result<Vec<Vec<Value>>, Error> {
96 let r = SqliteRequest {
97 query: query.to_string(),
98 params: params.into().into_iter().map(to_sql_value).collect(),
99 };
100 let request = Request::builder()
101 .uri("http://__sqlite/query")
102 .method("POST")
103 .body(serde_json::to_vec(&r)?.into_body())
104 .map_err(|err| Error::Other(err.into()))?;
105
106 let client = Client::new();
107 let (_parts, mut body) = client
108 .send(request)
109 .await
110 .map_err(|err| Error::Other(err.into()))?
111 .into_parts();
112
113 let bytes = body.bytes().await.map_err(|err| Error::Other(err.into()))?;
114
115 return match serde_json::from_slice(&bytes) {
116 Ok(SqliteResponse::Query { rows }) => Ok(
117 rows
118 .into_iter()
119 .map(|row| {
120 row
121 .into_iter()
122 .map(from_sql_value)
123 .collect::<Result<Vec<_>, _>>()
124 })
125 .collect::<Result<Vec<_>, _>>()?,
126 ),
127 Ok(SqliteResponse::Error(err)) => Err(Error::Other(err.into())),
128 Ok(resp) => Err(Error::UnexpectedType(
129 format!("Expected QueryResponse, got: {resp:?}").into(),
130 )),
131 Err(err) => Err(Error::Other(err.into())),
132 };
133}
134
135pub async fn execute(
136 query: impl std::string::ToString,
137 params: impl Into<Vec<Value>>,
138) -> Result<usize, Error> {
139 let r = SqliteRequest {
140 query: query.to_string(),
141 params: params.into().into_iter().map(to_sql_value).collect(),
142 };
143 let request = Request::builder()
144 .uri("http://__sqlite/execute")
145 .method("POST")
146 .body(serde_json::to_vec(&r)?.into_body())
147 .map_err(|err| Error::Other(err.into()))?;
148
149 let client = Client::new();
150 let (_parts, mut body) = client
151 .send(request)
152 .await
153 .map_err(|err| Error::Other(err.into()))?
154 .into_parts();
155
156 let bytes = body.bytes().await.map_err(|err| Error::Other(err.into()))?;
157
158 return match serde_json::from_slice(&bytes) {
159 Ok(SqliteResponse::Execute { rows_affected }) => Ok(rows_affected),
160 Ok(SqliteResponse::Error(err)) => Err(Error::Other(err.into())),
161 Ok(resp) => Err(Error::UnexpectedType(
162 format!("Expected ExecuteResponse, got: {resp:?}").into(),
163 )),
164 Err(err) => Err(Error::Other(err.into())),
165 };
166}
167
168fn from_sql_value(value: SqlValue) -> Result<Value, DecodeError> {
169 return match value {
170 SqlValue::Null => Ok(Value::Null),
171 SqlValue::Integer(v) => Ok(Value::Integer(v)),
172 SqlValue::Real(v) => Ok(Value::Real(v)),
173 SqlValue::Text(v) => Ok(Value::Text(v)),
174 SqlValue::Blob(v) => Ok(Value::Blob(v.into_bytes()?)),
175 };
176}
177
178pub fn to_sql_value(value: Value) -> SqlValue {
179 return match value {
180 Value::Null => SqlValue::Null,
181 Value::Text(s) => SqlValue::Text(s),
182 Value::Integer(i) => SqlValue::Integer(i),
183 Value::Real(f) => SqlValue::Real(f),
184 Value::Blob(b) => SqlValue::Blob(Blob::Array(b)),
185 };
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn escape_test() {
        assert_eq!("'foo'", escape("foo"));
        assert_eq!("'f''oo'", escape("f'oo"));
        assert_eq!("'foo\\0more'", escape("foo\0more"));
    }

    #[test]
    fn escape_edge_cases() {
        // Empty input must still produce a valid (empty) SQL literal.
        assert_eq!("''", escape(""));
        // Every quote is doubled, including consecutive quotes and quotes at the edges.
        assert_eq!("''''''", escape("''"));
        assert_eq!("'''a'''", escape("'a'"));
        // Non-ASCII text passes through untouched; owned `String`s are accepted as well.
        assert_eq!("'héllo wörld'", escape(String::from("héllo wörld")));
    }
}