// wesichain_checkpoint_postgres/lib.rs
use std::convert::TryFrom;
use std::fmt;

use wesichain_checkpoint_sql::error::CheckpointSqlError;
use wesichain_checkpoint_sql::migrations::run_migrations;
use wesichain_checkpoint_sql::ops::{
    load_latest_checkpoint, save_checkpoint_with_projections_and_queue,
};
use wesichain_core::checkpoint::{Checkpoint, Checkpointer};
use wesichain_core::state::{GraphState, StateSchema};
use wesichain_core::WesichainError;

12#[derive(Debug, Clone)]
13pub struct PostgresCheckpointer {
14 pool: sqlx::PgPool,
15 enable_projections: bool,
16}
17
/// Configuration for a [`PostgresCheckpointer`], created by
/// [`PostgresCheckpointer::builder`] and consumed by its async `build` method.
///
/// `Debug` is implemented by hand rather than derived: the connection string may embed
/// credentials (`user:password@host` or a `password=` query parameter), and `Debug`
/// output routinely ends up in logs, so the URL is always printed as `<redacted>`.
#[derive(Clone)]
pub struct PostgresCheckpointerBuilder {
    /// Connection string handed to `PgPoolOptions::connect`.
    database_url: String,
    /// Forwarded to `PgPoolOptions::max_connections`.
    max_connections: u32,
    /// Forwarded to `PgPoolOptions::min_connections`.
    min_connections: u32,
    /// Copied into the built checkpointer and passed to the SQL save routine.
    enable_projections: bool,
}

impl fmt::Debug for PostgresCheckpointerBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PostgresCheckpointerBuilder")
            // Never print the URL itself: it may contain a password.
            .field("database_url", &"<redacted>")
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("enable_projections", &self.enable_projections)
            .finish()
    }
}

26impl PostgresCheckpointer {
27 pub fn builder(database_url: impl Into<String>) -> PostgresCheckpointerBuilder {
28 PostgresCheckpointerBuilder {
29 database_url: database_url.into(),
30 max_connections: 5,
31 min_connections: 0,
32 enable_projections: false,
33 }
34 }
35
36 pub fn projections_enabled(&self) -> bool {
37 self.enable_projections
38 }
39}
40
41impl PostgresCheckpointerBuilder {
42 pub fn max_connections(mut self, max_connections: u32) -> Self {
43 self.max_connections = max_connections;
44 self
45 }
46
47 pub fn min_connections(mut self, min_connections: u32) -> Self {
48 self.min_connections = min_connections;
49 self
50 }
51
52 pub fn enable_projections(mut self, enable_projections: bool) -> Self {
53 self.enable_projections = enable_projections;
54 self
55 }
56
57 pub async fn build(self) -> Result<PostgresCheckpointer, CheckpointSqlError> {
58 let pool = sqlx::postgres::PgPoolOptions::new()
59 .max_connections(self.max_connections)
60 .min_connections(self.min_connections)
61 .connect(&self.database_url)
62 .await
63 .map_err(CheckpointSqlError::Connection)?;
64
65 run_migrations(&pool).await?;
66
67 Ok(PostgresCheckpointer {
68 pool,
69 enable_projections: self.enable_projections,
70 })
71 }
72}
73
74fn graph_checkpoint_error(message: impl Into<String>) -> WesichainError {
75 WesichainError::CheckpointFailed(message.into())
76}
77
78fn map_sql_error(error: CheckpointSqlError) -> WesichainError {
79 graph_checkpoint_error(error.to_string())
80}
81
82impl<S: StateSchema> Checkpointer<S> for PostgresCheckpointer {
83 fn save<'life0, 'life1, 'async_trait>(
84 &'life0 self,
85 checkpoint: &'life1 Checkpoint<S>,
86 ) -> core::pin::Pin<
87 Box<dyn core::future::Future<Output = Result<(), WesichainError>> + Send + 'async_trait>,
88 >
89 where
90 'life0: 'async_trait,
91 'life1: 'async_trait,
92 Self: 'async_trait,
93 {
94 Box::pin(async move {
95 let step = i64::try_from(checkpoint.step)
96 .map_err(|_| graph_checkpoint_error("checkpoint step does not fit into i64"))?;
97
98 save_checkpoint_with_projections_and_queue(
99 &self.pool,
100 &checkpoint.thread_id,
101 &checkpoint.node,
102 step,
103 &checkpoint.created_at,
104 &checkpoint.state,
105 &checkpoint.queue,
106 self.enable_projections,
107 )
108 .await
109 .map_err(map_sql_error)?;
110
111 Ok(())
112 })
113 }
114
115 fn load<'life0, 'life1, 'async_trait>(
116 &'life0 self,
117 thread_id: &'life1 str,
118 ) -> core::pin::Pin<
119 Box<
120 dyn core::future::Future<Output = Result<Option<Checkpoint<S>>, WesichainError>>
121 + Send
122 + 'async_trait,
123 >,
124 >
125 where
126 'life0: 'async_trait,
127 'life1: 'async_trait,
128 Self: 'async_trait,
129 {
130 Box::pin(async move {
131 let stored = load_latest_checkpoint(&self.pool, thread_id)
132 .await
133 .map_err(map_sql_error)?;
134
135 let Some(stored) = stored else {
136 return Ok(None);
137 };
138
139 let step_i64 = stored
140 .step
141 .ok_or_else(|| graph_checkpoint_error("checkpoint step is missing"))?;
142 let step = u64::try_from(step_i64)
143 .map_err(|_| graph_checkpoint_error("checkpoint step is negative"))?;
144
145 let node = stored
146 .node
147 .ok_or_else(|| graph_checkpoint_error("checkpoint node is missing"))?;
148
149 let state: GraphState<S> =
150 serde_json::from_value(stored.state_json).map_err(|error| {
151 graph_checkpoint_error(format!(
152 "failed to deserialize checkpoint state: {error}"
153 ))
154 })?;
155
156 let queue: Vec<(String, u64)> =
157 serde_json::from_value(stored.queue_json).map_err(|error| {
158 graph_checkpoint_error(format!(
159 "failed to deserialize checkpoint queue: {error}"
160 ))
161 })?;
162
163 Ok(Some(Checkpoint {
164 thread_id: stored.thread_id,
165 state,
166 step,
167 node,
168 queue,
169 created_at: stored.created_at,
170 }))
171 })
172 }
173}