Skip to main content

trailbase_wasm/
db.rs

1use trailbase_sqlvalue::{Blob, DecodeError, SqlValue};
2use wstd::http::body::IntoBody;
3use wstd::http::{Client, Request};
4
5use crate::wit::trailbase::database::sqlite::Transaction as WasiTransaction;
6
7pub use crate::wit::trailbase::database::sqlite::{TxError, Value};
8pub use trailbase_wasm_common::{SqliteRequest, SqliteResponse};
9
/// Quotes `s` as a single-quoted SQL string literal, e.g. `foo` becomes `'foo'`.
///
/// Every embedded single quote is doubled (`'` -> `''`), which is what keeps the literal from
/// being terminated early. NUL characters are replaced by the two characters `\` and `0`; per
/// the original author this is purely defensive for downstream consumers, not injection-related.
/// NOTE(review): SQL has no backslash escapes, so a NUL in the input presumably does not
/// round-trip through the resulting literal — confirm this is intended.
pub fn escape(s: impl AsRef<str>) -> String {
  let raw = s.as_ref();
  // Upper bound: every byte doubled (quote or NUL) plus the two enclosing quotes.
  let mut literal = String::with_capacity(2 * raw.len() + 2);

  literal.push('\'');
  raw.chars().for_each(|ch| match ch {
    '\'' => literal.push_str("''"),
    '\0' => literal.push_str("\\0"),
    other => literal.push(other),
  });
  literal.push('\'');

  return literal;
}
28
29pub struct Transaction {
30  tx: WasiTransaction,
31  committed: bool,
32}
33
34impl Transaction {
35  pub fn begin() -> Result<Self, TxError> {
36    return Ok(Self {
37      tx: WasiTransaction::new(),
38      committed: false,
39    });
40  }
41
42  pub fn query(&mut self, query: &str, params: &[Value]) -> Result<Vec<Vec<Value>>, TxError> {
43    return self.tx.query(query, params);
44  }
45
46  pub fn execute(&mut self, query: &str, params: &[Value]) -> Result<u64, TxError> {
47    return self.tx.execute(query, params);
48  }
49
50  pub fn commit(&mut self) -> Result<(), TxError> {
51    if !self.committed {
52      self.committed = true;
53      self.tx.commit()?;
54    }
55    return Ok(());
56  }
57}
58
59impl Drop for Transaction {
60  fn drop(&mut self) {
61    if !self.committed
62      && let Err(err) = self.tx.rollback()
63    {
64      log::warn!("TX rollback failed: {err}");
65    }
66  }
67}
68
69#[derive(Debug, thiserror::Error)]
70#[non_exhaustive]
71pub enum Error {
72  #[error("Unexpected Type: {0}")]
73  UnexpectedType(Box<dyn std::error::Error>),
74  #[error("Decoding: {0}")]
75  Decoding(Box<dyn std::error::Error>),
76  #[error("Other: {0}")]
77  Other(Box<dyn std::error::Error>),
78}
79
80impl From<DecodeError> for Error {
81  fn from(err: DecodeError) -> Self {
82    return Self::Decoding(err.into());
83  }
84}
85
86impl From<serde_json::Error> for Error {
87  fn from(err: serde_json::Error) -> Self {
88    return Self::Decoding(err.into());
89  }
90}
91
92pub async fn query(
93  query: impl std::string::ToString,
94  params: impl Into<Vec<Value>>,
95) -> Result<Vec<Vec<Value>>, Error> {
96  let r = SqliteRequest {
97    query: query.to_string(),
98    params: params.into().into_iter().map(to_sql_value).collect(),
99  };
100  let request = Request::builder()
101    .uri("http://__sqlite/query")
102    .method("POST")
103    .body(serde_json::to_vec(&r)?.into_body())
104    .map_err(|err| Error::Other(err.into()))?;
105
106  let client = Client::new();
107  let (_parts, mut body) = client
108    .send(request)
109    .await
110    .map_err(|err| Error::Other(err.into()))?
111    .into_parts();
112
113  let bytes = body.bytes().await.map_err(|err| Error::Other(err.into()))?;
114
115  return match serde_json::from_slice(&bytes) {
116    Ok(SqliteResponse::Query { rows }) => Ok(
117      rows
118        .into_iter()
119        .map(|row| {
120          row
121            .into_iter()
122            .map(from_sql_value)
123            .collect::<Result<Vec<_>, _>>()
124        })
125        .collect::<Result<Vec<_>, _>>()?,
126    ),
127    Ok(SqliteResponse::Error(err)) => Err(Error::Other(err.into())),
128    Ok(resp) => Err(Error::UnexpectedType(
129      format!("Expected QueryResponse, got: {resp:?}").into(),
130    )),
131    Err(err) => Err(Error::Other(err.into())),
132  };
133}
134
135pub async fn execute(
136  query: impl std::string::ToString,
137  params: impl Into<Vec<Value>>,
138) -> Result<usize, Error> {
139  let r = SqliteRequest {
140    query: query.to_string(),
141    params: params.into().into_iter().map(to_sql_value).collect(),
142  };
143  let request = Request::builder()
144    .uri("http://__sqlite/execute")
145    .method("POST")
146    .body(serde_json::to_vec(&r)?.into_body())
147    .map_err(|err| Error::Other(err.into()))?;
148
149  let client = Client::new();
150  let (_parts, mut body) = client
151    .send(request)
152    .await
153    .map_err(|err| Error::Other(err.into()))?
154    .into_parts();
155
156  let bytes = body.bytes().await.map_err(|err| Error::Other(err.into()))?;
157
158  return match serde_json::from_slice(&bytes) {
159    Ok(SqliteResponse::Execute { rows_affected }) => Ok(rows_affected),
160    Ok(SqliteResponse::Error(err)) => Err(Error::Other(err.into())),
161    Ok(resp) => Err(Error::UnexpectedType(
162      format!("Expected ExecuteResponse, got: {resp:?}").into(),
163    )),
164    Err(err) => Err(Error::Other(err.into())),
165  };
166}
167
168fn from_sql_value(value: SqlValue) -> Result<Value, DecodeError> {
169  return match value {
170    SqlValue::Null => Ok(Value::Null),
171    SqlValue::Integer(v) => Ok(Value::Integer(v)),
172    SqlValue::Real(v) => Ok(Value::Real(v)),
173    SqlValue::Text(v) => Ok(Value::Text(v)),
174    SqlValue::Blob(v) => Ok(Value::Blob(v.into_bytes()?)),
175  };
176}
177
178pub fn to_sql_value(value: Value) -> SqlValue {
179  return match value {
180    Value::Null => SqlValue::Null,
181    Value::Text(s) => SqlValue::Text(s),
182    Value::Integer(i) => SqlValue::Integer(i),
183    Value::Real(f) => SqlValue::Real(f),
184    Value::Blob(b) => SqlValue::Blob(Blob::Array(b)),
185  };
186}
187
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn escape_test() {
    assert_eq!("'foo'", escape("foo"));
    assert_eq!("'f''oo'", escape("f'oo"));
    assert_eq!("'foo\\0more'", escape("foo\0more"));
  }

  /// Edge cases not covered by `escape_test`.
  #[test]
  fn escape_edge_cases_test() {
    // Empty input still yields a valid, empty literal.
    assert_eq!("''", escape(""));
    // Consecutive quotes and quotes at the boundaries are each doubled.
    assert_eq!("''''''", escape("''"));
    assert_eq!("'''a'''", escape("'a'"));
    // Non-ASCII characters pass through unchanged.
    assert_eq!("'grüße ''x'''", escape("grüße 'x'"));
    // Owned strings are accepted via `AsRef<str>`.
    assert_eq!("'bar'", escape(String::from("bar")));
  }
}