1use anyhow::Result;
6use async_trait::async_trait;
7use chrono::Utc;
8use deadpool_postgres::Pool;
9use tracing::{debug, info};
10
11use tianshu::store::{SessionStateEntry, StateEntry, StateStore};
12
13pub struct PostgresStateStore {
14 pool: Pool,
15}
16
17impl PostgresStateStore {
18 pub fn new(pool: Pool) -> Self {
19 Self { pool }
20 }
21}
22
23#[async_trait]
24impl StateStore for PostgresStateStore {
25 async fn save(&self, case_key: &str, step: &str, data: &str) -> Result<()> {
26 let client = self.pool.get().await?;
27 debug!("Saving state: case_key={}, step={}", case_key, step);
28
29 let now = Utc::now();
30 client
31 .execute(
32 r#"
33 INSERT INTO wf_state (case_key, step, data, updated_at)
34 VALUES ($1, $2, $3, $4)
35 ON CONFLICT (case_key, step) DO UPDATE SET
36 data = EXCLUDED.data,
37 updated_at = EXCLUDED.updated_at
38 "#,
39 &[&case_key, &step, &data, &now],
40 )
41 .await?;
42
43 info!("Saved state: case_key={}, step={}", case_key, step);
44 Ok(())
45 }
46
47 async fn get(&self, case_key: &str, step: &str) -> Result<Option<StateEntry>> {
48 let client = self.pool.get().await?;
49 debug!("Getting state: case_key={}, step={}", case_key, step);
50
51 let row_opt = client
52 .query_opt(
53 "SELECT case_key, step, data, updated_at FROM wf_state WHERE case_key = $1 AND step = $2",
54 &[&case_key, &step],
55 )
56 .await?;
57
58 Ok(row_opt.map(|row| StateEntry {
59 case_key: row.get(0),
60 step: row.get(1),
61 data: row.get(2),
62 updated_at: row.get(3),
63 }))
64 }
65
66 async fn get_all(&self, case_key: &str) -> Result<Vec<StateEntry>> {
67 let client = self.pool.get().await?;
68 debug!("Getting all state for case_key={}", case_key);
69
70 let rows = client
71 .query(
72 "SELECT case_key, step, data, updated_at FROM wf_state WHERE case_key = $1",
73 &[&case_key],
74 )
75 .await?;
76
77 Ok(rows
78 .iter()
79 .map(|row| StateEntry {
80 case_key: row.get(0),
81 step: row.get(1),
82 data: row.get(2),
83 updated_at: row.get(3),
84 })
85 .collect())
86 }
87
88 async fn delete_by_case(&self, case_key: &str) -> Result<()> {
89 let client = self.pool.get().await?;
90 let count = client
91 .execute("DELETE FROM wf_state WHERE case_key = $1", &[&case_key])
92 .await?;
93 info!("Deleted {} state entries for case_key={}", count, case_key);
94 Ok(())
95 }
96
97 async fn save_session(&self, session_id: &str, step: &str, data: &str) -> Result<()> {
98 let client = self.pool.get().await?;
99 debug!(
100 "Saving session state: session_id={}, step={}",
101 session_id, step
102 );
103
104 let now = Utc::now();
105 client
106 .execute(
107 r#"
108 INSERT INTO wf_session_state (session_id, step, data, updated_at)
109 VALUES ($1, $2, $3, $4)
110 ON CONFLICT (session_id, step) DO UPDATE SET
111 data = EXCLUDED.data,
112 updated_at = EXCLUDED.updated_at
113 "#,
114 &[&session_id, &step, &data, &now],
115 )
116 .await?;
117
118 info!(
119 "Saved session state: session_id={}, step={}",
120 session_id, step
121 );
122 Ok(())
123 }
124
125 async fn get_session(&self, session_id: &str, step: &str) -> Result<Option<SessionStateEntry>> {
126 let client = self.pool.get().await?;
127 let row_opt = client
128 .query_opt(
129 "SELECT session_id, step, data, updated_at FROM wf_session_state WHERE session_id = $1 AND step = $2",
130 &[&session_id, &step],
131 )
132 .await?;
133
134 Ok(row_opt.map(|row| SessionStateEntry {
135 session_id: row.get(0),
136 step: row.get(1),
137 data: row.get(2),
138 updated_at: row.get(3),
139 }))
140 }
141
142 async fn get_all_session(&self, session_id: &str) -> Result<Vec<SessionStateEntry>> {
143 let client = self.pool.get().await?;
144 let rows = client
145 .query(
146 "SELECT session_id, step, data, updated_at FROM wf_session_state WHERE session_id = $1",
147 &[&session_id],
148 )
149 .await?;
150
151 Ok(rows
152 .iter()
153 .map(|row| SessionStateEntry {
154 session_id: row.get(0),
155 step: row.get(1),
156 data: row.get(2),
157 updated_at: row.get(3),
158 })
159 .collect())
160 }
161
162 async fn delete_by_session(&self, session_id: &str) -> Result<()> {
163 let client = self.pool.get().await?;
164 let count = client
165 .execute(
166 "DELETE FROM wf_session_state WHERE session_id = $1",
167 &[&session_id],
168 )
169 .await?;
170 info!(
171 "Deleted {} session state entries for session_id={}",
172 count, session_id
173 );
174 Ok(())
175 }
176
177 async fn setup(&self) -> Result<()> {
178 let client = self.pool.get().await?;
179 client
180 .execute(
181 r#"
182 CREATE TABLE IF NOT EXISTS wf_state (
183 case_key TEXT NOT NULL,
184 step TEXT NOT NULL,
185 data TEXT NOT NULL,
186 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
187 PRIMARY KEY (case_key, step)
188 )
189 "#,
190 &[],
191 )
192 .await?;
193 client
194 .execute(
195 "CREATE INDEX IF NOT EXISTS wf_state_case_key_idx ON wf_state (case_key)",
196 &[],
197 )
198 .await?;
199 client
200 .execute(
201 r#"
202 CREATE TABLE IF NOT EXISTS wf_session_state (
203 session_id TEXT NOT NULL,
204 step TEXT NOT NULL,
205 data TEXT NOT NULL,
206 updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
207 PRIMARY KEY (session_id, step)
208 )
209 "#,
210 &[],
211 )
212 .await?;
213 client
214 .execute(
215 "CREATE INDEX IF NOT EXISTS wf_session_state_session_id_idx ON wf_session_state (session_id)",
216 &[],
217 )
218 .await?;
219 info!("wf_state and wf_session_state tables ready");
220 Ok(())
221 }
222}