1use crate::{
18 StateBackendError, StateBackendErrorKind, StateKey, StateOperatorBackend,
19 StateOperatorBackendFactory,
20};
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use sqlx::pool::PoolOptions;
24use sqlx::sqlite::SqliteConnectOptions;
25use sqlx::{Row, SqlitePool};
26use std::fmt::Debug;
27use std::str::FromStr;
28use std::sync::Arc;
29use tracing::info;
30
/// Name of the table that stores state rows when the caller does not supply one.
const DEFAULT_TABLE_NAME: &str = "state";
/// Connection-pool size used when the caller does not supply `max_connections`.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;
33
34pub struct SqliteStateOperatorBackendFactory {
35 pool: Arc<SqlitePool>,
36 state_table_name: String,
37}
38
39impl SqliteStateOperatorBackendFactory {
40 pub async fn new(
41 database_path: String,
42 max_connections: Option<u32>,
43 state_table_name: Option<String>,
44 ) -> Result<Self, StateBackendError> {
45 let state_table_name = state_table_name.unwrap_or_else(|| DEFAULT_TABLE_NAME.to_string());
46
47 let options = SqliteConnectOptions::from_str(format!("sqlite:{}", database_path).as_str())
48 .unwrap()
49 .create_if_missing(true);
50
51 let pool = PoolOptions::<sqlx::Sqlite>::new()
52 .max_connections(max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS))
53 .connect_with(options)
54 .await
55 .map_err(|e| {
56 StateBackendError::with_source(
57 StateBackendErrorKind::Connection,
58 "failed to create SQLite connection pool",
59 e,
60 )
61 })?;
62
63 let pool = Arc::new(pool);
64
65 Self::initialize(pool.clone(), &state_table_name).await?;
66
67 Ok(Self {
68 pool,
69 state_table_name,
70 })
71 }
72
73 async fn initialize(
74 pool: Arc<SqlitePool>,
75 state_table_name: &str,
76 ) -> Result<(), StateBackendError> {
77 sqlx::query(
78 format!(
79 r#"
80 CREATE TABLE IF NOT EXISTS {} (
81 namespace TEXT,
82 key TEXT,
83 data TEXT NOT NULL,
84 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
85 PRIMARY KEY(namespace, key)
86 );
87 "#,
88 state_table_name
89 )
90 .as_str(),
91 )
92 .execute(pool.as_ref())
93 .await
94 .map(|_| ())
95 .map_err(|e| {
96 StateBackendError::with_source(
97 StateBackendErrorKind::Initialization,
98 "failed to create state table",
99 e,
100 )
101 })
102 }
103}
104
105impl StateOperatorBackendFactory for SqliteStateOperatorBackendFactory {
106 fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
107 where
108 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
109 {
110 Arc::new(SqliteStateOperatorBackend::new(
111 self.pool.clone(),
112 self.state_table_name.clone(),
113 namespace,
114 ))
115 }
116}
117
118#[derive(Debug)]
119struct SqliteStateOperatorBackend {
120 pool: Arc<SqlitePool>,
121 state_table_name: String,
122 namespace: String,
123}
124
125impl SqliteStateOperatorBackend {
126 fn new(pool: Arc<SqlitePool>, state_table_name: String, namespace: &str) -> Self {
127 info!(
128 "Creating a new SQLite JSON state backend for namespace: {}",
129 namespace
130 );
131
132 Self {
133 pool,
134 state_table_name,
135 namespace: namespace.to_string(),
136 }
137 }
138}
139
140#[async_trait]
141impl<V> StateOperatorBackend<V> for SqliteStateOperatorBackend
142where
143 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
144{
145 async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
146 let result = sqlx::query(
147 format!(
148 r#"
149 SELECT data
150 FROM {}
151 WHERE namespace = ? AND key = ?
152 "#,
153 self.state_table_name
154 )
155 .as_str(),
156 )
157 .bind(&self.namespace)
158 .bind(&key.0)
159 .fetch_optional(self.pool.as_ref())
160 .await
161 .map_err(|e| {
162 StateBackendError::with_source(StateBackendErrorKind::Query, "failed to fetch state", e)
163 })?;
164
165 if result.is_none() {
166 return Ok(None);
167 }
168
169 let data = result.unwrap();
170 let json_str: String = data.try_get(0).map_err(|e| {
171 StateBackendError::with_source(
172 StateBackendErrorKind::Query,
173 "failed to read data column",
174 e,
175 )
176 })?;
177
178 serde_json::from_str(&json_str).map(Some).map_err(|e| {
179 StateBackendError::with_source(
180 StateBackendErrorKind::Serialization,
181 "failed to deserialize state",
182 e,
183 )
184 })
185 }
186
187 async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
188 let json_str = serde_json::to_string(&value).unwrap();
189 sqlx::query(
190 format!(
191 r#"
192 INSERT INTO {} (namespace, key, data, created_at)
193 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
194 ON CONFLICT(namespace, key) DO UPDATE SET data = excluded.data
195 "#,
196 self.state_table_name
197 )
198 .as_str(),
199 )
200 .bind(&self.namespace)
201 .bind(&key.0)
202 .bind(&json_str)
203 .execute(self.pool.as_ref())
204 .await
205 .map(|_| ())
206 .map_err(|e| {
207 StateBackendError::with_source(
208 StateBackendErrorKind::Query,
209 "failed to update state",
210 e,
211 )
212 })
213 }
214
215 async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
216 sqlx::query(
217 format!(
218 r#"
219 DELETE FROM {}
220 WHERE namespace = ? AND key = ?
221 "#,
222 self.state_table_name
223 )
224 .as_str(),
225 )
226 .bind(&self.namespace)
227 .bind(&key.0)
228 .execute(self.pool.as_ref())
229 .await
230 .map(|_| ())
231 .map_err(|e| {
232 StateBackendError::with_source(
233 StateBackendErrorKind::Query,
234 "failed to remove state",
235 e,
236 )
237 })
238 }
239
240 async fn clear(&self) -> Result<(), StateBackendError> {
241 sqlx::query(
242 format!(
243 r#"
244 DELETE FROM {}
245 WHERE namespace = ?
246 "#,
247 self.state_table_name
248 )
249 .as_str(),
250 )
251 .bind(&self.namespace)
252 .execute(self.pool.as_ref())
253 .await
254 .map(|_| ())
255 .map_err(|e| {
256 StateBackendError::with_source(StateBackendErrorKind::Query, "failed to clear state", e)
257 })
258 }
259}