Skip to main content

streamling_state/
lib.rs

1use async_trait::async_trait;
2use in_memory::InMemoryStateOperatorBackendFactory;
3use postgres::PostgresStateOperatorBackendFactory;
4use serde::{Deserialize, Serialize};
5use sqlite::SqliteStateOperatorBackendFactory;
6use std::convert::From;
7use std::error::Error as StdError;
8use std::fmt;
9use std::fmt::Debug;
10use std::sync::Arc;
11use streamling_config::app_config::{StateBackendConfig, StateBackendType};
12
13pub mod in_memory;
14pub mod postgres;
15pub mod sqlite;
16
17#[cfg(feature = "test-utils")]
18pub mod testing;
19
/// Key under which a value is stored in a state backend.
///
/// All backends share this single string-based key format. Should backends ever need different
/// key representations, a generic key parameter could be introduced next to the `V` value
/// parameter of [`StateOperatorBackend`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StateKey(pub String);

impl From<&str> for StateKey {
    /// Copies the borrowed string into a new owned key.
    fn from(raw: &str) -> Self {
        Self(String::from(raw))
    }
}

impl From<String> for StateKey {
    /// Wraps the owned string as a key without copying it.
    fn from(raw: String) -> Self {
        Self(raw)
    }
}
37
/// Category of a [`StateBackendError`], exposed through [`StateBackendError::kind`] so callers
/// can branch on the class of failure without parsing the message.
///
/// NOTE(review): only the variant names are visible in this file; which operations raise which
/// kind is decided by the backend implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StateBackendErrorKind {
    Initialization,
    Connection,
    Query,
    Serialization,
}

/// Error type shared by all state backend operations.
///
/// Holds a [`StateBackendErrorKind`], a human-readable message and, optionally, the lower-level
/// error that caused it (returned from [`std::error::Error::source`]).
#[derive(Debug)]
pub struct StateBackendError {
    kind: StateBackendErrorKind,
    message: String,
    source: Option<Box<dyn StdError + Send + Sync>>,
}

impl StateBackendError {
    /// Creates an error of the given `kind` with no underlying cause.
    pub fn new<M: Into<String>>(kind: StateBackendErrorKind, message: M) -> Self {
        let message = message.into();
        Self {
            kind,
            message,
            source: None,
        }
    }

    /// Creates an error of the given `kind` that wraps `source` as its cause.
    pub fn with_source<M, E>(kind: StateBackendErrorKind, message: M, source: E) -> Self
    where
        M: Into<String>,
        E: StdError + Send + Sync + 'static,
    {
        let boxed: Box<dyn StdError + Send + Sync> = Box::new(source);
        Self {
            source: Some(boxed),
            ..Self::new(kind, message)
        }
    }

    /// Returns the category of this error.
    pub fn kind(&self) -> StateBackendErrorKind {
        self.kind
    }
}

impl fmt::Display for StateBackendError {
    /// Renders `<Kind>: <message>` (the kind via its `Debug` name), followed by a
    /// "Caused by" section holding the source error's `Display` text when one is attached.
    // NOTE(review): the source is also returned from `Error::source`, so reporters that walk the
    // error chain will print it twice. Kept as-is for callers that log with a plain `{}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}: {}", self.kind, self.message)?;
        match self.source.as_deref() {
            Some(cause) => write!(f, "\n\nCaused by:\n    {}", cause),
            None => Ok(()),
        }
    }
}

impl StdError for StateBackendError {
    /// Exposes the wrapped lower-level error, if any.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // The `let` annotation coerces away the `Send + Sync` bounds of the stored trait object.
        let cause: &(dyn StdError + 'static) = self.source.as_deref()?;
        Some(cause)
    }
}
96
97#[async_trait]
98pub trait StateOperatorBackend<V>: Debug + Sync + Send
99where
100    V: Serialize + for<'de> Deserialize<'de>,
101{
102    async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError>;
103    async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError>;
104    async fn remove(&self, key: StateKey) -> Result<(), StateBackendError>;
105    async fn clear(&self) -> Result<(), StateBackendError>;
106}
107
108pub enum StateBackendFactories {
109    InMemory(InMemoryStateOperatorBackendFactory),
110    Postgres(PostgresStateOperatorBackendFactory),
111    Sqlite(SqliteStateOperatorBackendFactory),
112}
113
114impl StateBackendFactories {
115    pub fn new(config: StateBackendConfig) -> Result<Self, StateBackendError> {
116        let init_future = async move {
117            match config.backend_type {
118                StateBackendType::InMemory => Ok(StateBackendFactories::InMemory(
119                    InMemoryStateOperatorBackendFactory::new()?,
120                )),
121                StateBackendType::Postgres => {
122                    let postgres_config = config
123                        .postgres
124                        .expect("Postgres JSON backend config is required");
125                    Ok(StateBackendFactories::Postgres(
126                        PostgresStateOperatorBackendFactory::new(
127                            postgres_config.connection_url(),
128                            postgres_config.max_connections,
129                            postgres_config.state_schema_name,
130                            postgres_config.state_table_name,
131                        )
132                        .await?,
133                    ))
134                }
135                StateBackendType::Sqlite => {
136                    let sqlite_config = config
137                        .sqlite
138                        .expect("SQLite JSON backend config is required");
139                    Ok(StateBackendFactories::Sqlite(
140                        SqliteStateOperatorBackendFactory::new(
141                            sqlite_config.database_path,
142                            sqlite_config.max_connections,
143                            sqlite_config.state_table_name,
144                        )
145                        .await?,
146                    ))
147                }
148            }
149        };
150
151        tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(init_future))
152    }
153}
154
155pub trait StateOperatorBackendFactory {
156    fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
157    where
158        V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static;
159}
160
161impl StateOperatorBackendFactory for StateBackendFactories {
162    fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
163    where
164        V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
165    {
166        match self {
167            StateBackendFactories::InMemory(factory) => factory.create(namespace),
168            StateBackendFactories::Postgres(factory) => factory.create(namespace),
169            StateBackendFactories::Sqlite(factory) => factory.create(namespace),
170        }
171    }
172}