1use base64::prelude::*;
2use serde::Serialize;
3use wstd::http::body::IntoBody;
4use wstd::http::{Client, Request};
5
6use crate::wit::trailbase::runtime::host_endpoint::{
7 tx_begin, tx_commit, tx_execute, tx_query, tx_rollback,
8};
9
10pub use crate::wit::trailbase::runtime::host_endpoint::{TxError, Value};
11pub use trailbase_wasm_common::{SqliteRequest, SqliteResponse};
12
13pub struct Transaction {
14 committed: bool,
15}
16
17impl Transaction {
18 pub fn begin() -> Result<Self, TxError> {
19 tx_begin()?;
20 return Ok(Self { committed: false });
21 }
22
23 pub fn query(&mut self, query: &str, params: &[Value]) -> Result<Vec<Vec<Value>>, TxError> {
24 return tx_query(query, params);
25 }
26
27 pub fn execute(&mut self, query: &str, params: &[Value]) -> Result<u64, TxError> {
28 return tx_execute(query, params);
29 }
30
31 pub fn commit(&mut self) -> Result<(), TxError> {
32 if !self.committed {
33 self.committed = true;
34 tx_commit()?;
35 }
36 return Ok(());
37 }
38}
39
40impl Drop for Transaction {
41 fn drop(&mut self) {
42 if !self.committed {
43 if let Err(err) = tx_rollback() {
44 log::warn!("TX rollback failed: {err}");
45 }
46 }
47 }
48}
49
50#[derive(Debug, thiserror::Error)]
51pub enum Error {
52 #[error("Unexpected Type")]
53 UnexpectedType,
54 #[error("Not a Number")]
55 NotANumber,
56 #[error("Decoding")]
57 Decording(#[from] base64::DecodeError),
58 #[error("Other: {0}")]
59 Other(String),
60}
61
62pub async fn query(
63 query: impl std::string::ToString,
64 params: impl Into<Vec<Value>>,
65) -> Result<Vec<Vec<Value>>, Error> {
66 let r = SqliteRequest {
67 query: query.to_string(),
68 params: params.into().into_iter().map(to_json_value).collect(),
69 };
70 let request = Request::builder()
71 .uri("http://__sqlite/query")
72 .method("POST")
73 .body(
74 serde_json::to_vec(&r)
75 .map_err(|_| Error::UnexpectedType)?
76 .into_body(),
77 )
78 .map_err(|err| Error::Other(err.to_string()))?;
79
80 let client = Client::new();
81 let (_parts, mut body) = client
82 .send(request)
83 .await
84 .map_err(|err| Error::Other(err.to_string()))?
85 .into_parts();
86
87 let bytes = body
88 .bytes()
89 .await
90 .map_err(|err| Error::Other(err.to_string()))?;
91
92 return match serde_json::from_slice(&bytes) {
93 Ok(SqliteResponse::Query { rows }) => Ok(
94 rows
95 .into_iter()
96 .map(|row| {
97 row
98 .into_iter()
99 .map(from_json_value)
100 .collect::<Result<Vec<_>, _>>()
101 })
102 .collect::<Result<Vec<_>, _>>()?,
103 ),
104 Ok(_) => Err(Error::UnexpectedType),
105 Err(err) => Err(Error::Other(err.to_string())),
106 };
107}
108
109pub async fn execute(
110 query: impl std::string::ToString,
111 params: impl Into<Vec<Value>>,
112) -> Result<usize, Error> {
113 let r = SqliteRequest {
114 query: query.to_string(),
115 params: params.into().into_iter().map(to_json_value).collect(),
116 };
117 let request = Request::builder()
118 .uri("http://__sqlite/execute")
119 .method("POST")
120 .body(
121 serde_json::to_vec(&r)
122 .map_err(|_| Error::UnexpectedType)?
123 .into_body(),
124 )
125 .map_err(|err| Error::Other(err.to_string()))?;
126
127 let client = Client::new();
128 let (_parts, mut body) = client
129 .send(request)
130 .await
131 .map_err(|err| Error::Other(err.to_string()))?
132 .into_parts();
133
134 let bytes = body
135 .bytes()
136 .await
137 .map_err(|err| Error::Other(err.to_string()))?;
138
139 return match serde_json::from_slice(&bytes) {
140 Ok(SqliteResponse::Execute { rows_affected }) => Ok(rows_affected),
141 Ok(_) => Err(Error::UnexpectedType),
142 Err(err) => Err(Error::Other(err.to_string())),
143 };
144}
145
146fn from_json_value(value: serde_json::Value) -> Result<Value, Error> {
147 return match value {
148 serde_json::Value::Null => Ok(Value::Null),
149 serde_json::Value::String(s) => Ok(Value::Text(s)),
150 serde_json::Value::Object(mut map) => match map.remove("blob") {
151 Some(serde_json::Value::String(str)) => Ok(Value::Blob(BASE64_URL_SAFE.decode(&str)?)),
152 _ => Err(Error::UnexpectedType),
153 },
154 serde_json::Value::Number(n) => {
155 if let Some(n) = n.as_i64() {
156 Ok(Value::Integer(n))
157 } else if let Some(n) = n.as_u64() {
158 Ok(Value::Integer(n as i64))
159 } else if let Some(n) = n.as_f64() {
160 Ok(Value::Real(n))
161 } else {
162 Err(Error::NotANumber)
163 }
164 }
165 _ => Err(Error::UnexpectedType),
166 };
167}
168
169#[derive(Serialize)]
170struct Blob {
171 blob: String,
172}
173
174impl serde::ser::Serialize for Value {
175 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176 where
177 S: serde::ser::Serializer,
178 {
179 return match self {
180 Value::Null => serializer.serialize_unit(),
181 Value::Text(s) => serializer.serialize_str(s),
182 Value::Integer(i) => serializer.serialize_i64(*i),
183 Value::Real(f) => serializer.serialize_f64(*f),
184 Value::Blob(blob) => serializer.serialize_some(&Blob {
185 blob: BASE64_URL_SAFE.encode(blob),
186 }),
187 };
188 }
189}
190
191pub fn to_json_value(value: Value) -> serde_json::Value {
192 return match value {
193 Value::Null => serde_json::Value::Null,
194 Value::Text(s) => serde_json::Value::String(s),
195 Value::Integer(i) => serde_json::Value::Number(serde_json::Number::from(i)),
196 Value::Real(f) => match serde_json::Number::from_f64(f) {
197 Some(n) => serde_json::Value::Number(n),
198 None => serde_json::Value::Null,
199 },
200 Value::Blob(blob) => serde_json::json!({
201 "blob": BASE64_URL_SAFE.encode(blob)
202 }),
203 };
204}