Skip to main content

tianshu_postgres/
state_store.rs

// Copyright 2026 Desicool
//
// SPDX-License-Identifier: Apache-2.0

use anyhow::{Context, Result};
use async_trait::async_trait;
use chrono::Utc;
use deadpool_postgres::Pool;
use tracing::{debug, info};

use tianshu::store::{SessionStateEntry, StateEntry, StateStore};
12
13pub struct PostgresStateStore {
14    pool: Pool,
15}
16
17impl PostgresStateStore {
18    pub fn new(pool: Pool) -> Self {
19        Self { pool }
20    }
21}
22
23#[async_trait]
24impl StateStore for PostgresStateStore {
25    async fn save(&self, case_key: &str, step: &str, data: &str) -> Result<()> {
26        let client = self.pool.get().await?;
27        debug!("Saving state: case_key={}, step={}", case_key, step);
28
29        let now = Utc::now();
30        client
31            .execute(
32                r#"
33                INSERT INTO wf_state (case_key, step, data, updated_at)
34                VALUES ($1, $2, $3, $4)
35                ON CONFLICT (case_key, step) DO UPDATE SET
36                    data       = EXCLUDED.data,
37                    updated_at = EXCLUDED.updated_at
38                "#,
39                &[&case_key, &step, &data, &now],
40            )
41            .await?;
42
43        info!("Saved state: case_key={}, step={}", case_key, step);
44        Ok(())
45    }
46
47    async fn get(&self, case_key: &str, step: &str) -> Result<Option<StateEntry>> {
48        let client = self.pool.get().await?;
49        debug!("Getting state: case_key={}, step={}", case_key, step);
50
51        let row_opt = client
52            .query_opt(
53                "SELECT case_key, step, data, updated_at FROM wf_state WHERE case_key = $1 AND step = $2",
54                &[&case_key, &step],
55            )
56            .await?;
57
58        Ok(row_opt.map(|row| StateEntry {
59            case_key: row.get(0),
60            step: row.get(1),
61            data: row.get(2),
62            updated_at: row.get(3),
63        }))
64    }
65
66    async fn get_all(&self, case_key: &str) -> Result<Vec<StateEntry>> {
67        let client = self.pool.get().await?;
68        debug!("Getting all state for case_key={}", case_key);
69
70        let rows = client
71            .query(
72                "SELECT case_key, step, data, updated_at FROM wf_state WHERE case_key = $1",
73                &[&case_key],
74            )
75            .await?;
76
77        Ok(rows
78            .iter()
79            .map(|row| StateEntry {
80                case_key: row.get(0),
81                step: row.get(1),
82                data: row.get(2),
83                updated_at: row.get(3),
84            })
85            .collect())
86    }
87
88    async fn delete_by_case(&self, case_key: &str) -> Result<()> {
89        let client = self.pool.get().await?;
90        let count = client
91            .execute("DELETE FROM wf_state WHERE case_key = $1", &[&case_key])
92            .await?;
93        info!("Deleted {} state entries for case_key={}", count, case_key);
94        Ok(())
95    }
96
97    async fn save_session(&self, session_id: &str, step: &str, data: &str) -> Result<()> {
98        let client = self.pool.get().await?;
99        debug!(
100            "Saving session state: session_id={}, step={}",
101            session_id, step
102        );
103
104        let now = Utc::now();
105        client
106            .execute(
107                r#"
108                INSERT INTO wf_session_state (session_id, step, data, updated_at)
109                VALUES ($1, $2, $3, $4)
110                ON CONFLICT (session_id, step) DO UPDATE SET
111                    data       = EXCLUDED.data,
112                    updated_at = EXCLUDED.updated_at
113                "#,
114                &[&session_id, &step, &data, &now],
115            )
116            .await?;
117
118        info!(
119            "Saved session state: session_id={}, step={}",
120            session_id, step
121        );
122        Ok(())
123    }
124
125    async fn get_session(&self, session_id: &str, step: &str) -> Result<Option<SessionStateEntry>> {
126        let client = self.pool.get().await?;
127        let row_opt = client
128            .query_opt(
129                "SELECT session_id, step, data, updated_at FROM wf_session_state WHERE session_id = $1 AND step = $2",
130                &[&session_id, &step],
131            )
132            .await?;
133
134        Ok(row_opt.map(|row| SessionStateEntry {
135            session_id: row.get(0),
136            step: row.get(1),
137            data: row.get(2),
138            updated_at: row.get(3),
139        }))
140    }
141
142    async fn get_all_session(&self, session_id: &str) -> Result<Vec<SessionStateEntry>> {
143        let client = self.pool.get().await?;
144        let rows = client
145            .query(
146                "SELECT session_id, step, data, updated_at FROM wf_session_state WHERE session_id = $1",
147                &[&session_id],
148            )
149            .await?;
150
151        Ok(rows
152            .iter()
153            .map(|row| SessionStateEntry {
154                session_id: row.get(0),
155                step: row.get(1),
156                data: row.get(2),
157                updated_at: row.get(3),
158            })
159            .collect())
160    }
161
162    async fn delete_by_session(&self, session_id: &str) -> Result<()> {
163        let client = self.pool.get().await?;
164        let count = client
165            .execute(
166                "DELETE FROM wf_session_state WHERE session_id = $1",
167                &[&session_id],
168            )
169            .await?;
170        info!(
171            "Deleted {} session state entries for session_id={}",
172            count, session_id
173        );
174        Ok(())
175    }
176
177    async fn setup(&self) -> Result<()> {
178        let client = self.pool.get().await?;
179        client
180            .execute(
181                r#"
182                CREATE TABLE IF NOT EXISTS wf_state (
183                    case_key   TEXT NOT NULL,
184                    step       TEXT NOT NULL,
185                    data       TEXT NOT NULL,
186                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
187                    PRIMARY KEY (case_key, step)
188                )
189                "#,
190                &[],
191            )
192            .await?;
193        client
194            .execute(
195                "CREATE INDEX IF NOT EXISTS wf_state_case_key_idx ON wf_state (case_key)",
196                &[],
197            )
198            .await?;
199        client
200            .execute(
201                r#"
202                CREATE TABLE IF NOT EXISTS wf_session_state (
203                    session_id TEXT NOT NULL,
204                    step       TEXT NOT NULL,
205                    data       TEXT NOT NULL,
206                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
207                    PRIMARY KEY (session_id, step)
208                )
209                "#,
210                &[],
211            )
212            .await?;
213        client
214            .execute(
215                "CREATE INDEX IF NOT EXISTS wf_session_state_session_id_idx ON wf_session_state (session_id)",
216                &[],
217            )
218            .await?;
219        info!("wf_state and wf_session_state tables ready");
220        Ok(())
221    }
222}