Skip to main content

systemprompt_sync/database/
mod.rs

//! Database push / pull: export the user and context tables as
//! serde-serialisable records and round-trip them between a local Postgres
//! and a cloud Postgres using compile-time-checked `sqlx` queries and upserts.
4
5mod upsert;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use sqlx::PgPool;
10use sqlx::prelude::FromRow;
11use systemprompt_identifiers::{ContextId, SessionId, UserId};
12
13use crate::error::SyncResult;
14use crate::{SyncDirection, SyncOperationResult};
15
16use upsert::{upsert_context, upsert_user};
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct DatabaseExport {
20    pub users: Vec<UserExport>,
21    pub contexts: Vec<ContextExport>,
22    pub timestamp: DateTime<Utc>,
23}
24
25#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
26pub struct UserExport {
27    pub id: UserId,
28    pub name: String,
29    pub email: String,
30    pub full_name: Option<String>,
31    pub display_name: Option<String>,
32    pub status: String,
33    pub email_verified: bool,
34    pub roles: Vec<String>,
35    pub is_bot: bool,
36    pub is_scanner: bool,
37    pub avatar_url: Option<String>,
38    pub created_at: DateTime<Utc>,
39    pub updated_at: DateTime<Utc>,
40}
41
42#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
43pub struct ContextExport {
44    pub context_id: ContextId,
45    pub user_id: UserId,
46    pub session_id: Option<SessionId>,
47    pub name: String,
48    pub created_at: DateTime<Utc>,
49    pub updated_at: DateTime<Utc>,
50}
51
52#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
53pub struct ImportResult {
54    pub created: usize,
55    pub updated: usize,
56    pub skipped: usize,
57}
58
59#[derive(Debug)]
60pub struct DatabaseSyncService {
61    direction: SyncDirection,
62    dry_run: bool,
63    local_database_url: String,
64    cloud_database_url: String,
65}
66
67impl DatabaseSyncService {
68    pub fn new(
69        direction: SyncDirection,
70        dry_run: bool,
71        local_database_url: &str,
72        cloud_database_url: &str,
73    ) -> Self {
74        Self {
75            direction,
76            dry_run,
77            local_database_url: local_database_url.to_string(),
78            cloud_database_url: cloud_database_url.to_string(),
79        }
80    }
81
82    pub async fn sync(&self) -> SyncResult<SyncOperationResult> {
83        match self.direction {
84            SyncDirection::Push => self.push().await,
85            SyncDirection::Pull => self.pull().await,
86        }
87    }
88
89    async fn push(&self) -> SyncResult<SyncOperationResult> {
90        let export = export_from_database(&self.local_database_url).await?;
91        let count = export.users.len() + export.contexts.len();
92
93        if self.dry_run {
94            return Ok(SyncOperationResult::dry_run(
95                "database_push",
96                count,
97                serde_json::json!({
98                    "users": export.users.len(),
99                    "contexts": export.contexts.len(),
100                }),
101            ));
102        }
103
104        import_to_database(&self.cloud_database_url, &export).await?;
105        Ok(SyncOperationResult::success("database_push", count))
106    }
107
108    async fn pull(&self) -> SyncResult<SyncOperationResult> {
109        let export = export_from_database(&self.cloud_database_url).await?;
110        let count = export.users.len() + export.contexts.len();
111
112        if self.dry_run {
113            return Ok(SyncOperationResult::dry_run(
114                "database_pull",
115                count,
116                serde_json::json!({
117                    "users": export.users.len(),
118                    "contexts": export.contexts.len(),
119                }),
120            ));
121        }
122
123        import_to_database(&self.local_database_url, &export).await?;
124        Ok(SyncOperationResult::success("database_pull", count))
125    }
126}
127
128async fn export_from_database(database_url: &str) -> SyncResult<DatabaseExport> {
129    let pool = PgPool::connect(database_url).await?;
130
131    let users = sqlx::query_as!(
132        UserExport,
133        r#"SELECT id, name, email, full_name, display_name, status, email_verified,
134                  roles, is_bot, is_scanner, avatar_url, created_at, updated_at
135           FROM users"#
136    )
137    .fetch_all(&pool)
138    .await?;
139
140    let contexts = sqlx::query_as!(
141        ContextExport,
142        r#"SELECT context_id as "context_id!: ContextId",
143                  user_id as "user_id!: UserId",
144                  session_id as "session_id: SessionId",
145                  name, created_at, updated_at
146           FROM user_contexts"#
147    )
148    .fetch_all(&pool)
149    .await?;
150
151    Ok(DatabaseExport {
152        users,
153        contexts,
154        timestamp: Utc::now(),
155    })
156}
157
158async fn import_to_database(
159    database_url: &str,
160    export: &DatabaseExport,
161) -> SyncResult<ImportResult> {
162    let pool = PgPool::connect(database_url).await?;
163    let mut created = 0;
164    let mut updated = 0;
165    let mut completed = 0usize;
166    let total = export.users.len() + export.contexts.len();
167
168    for user in &export.users {
169        match upsert_user(&pool, user).await {
170            Ok((c, u)) => {
171                created += c;
172                updated += u;
173                completed += 1;
174            },
175            Err(err) => {
176                return Err(crate::error::SyncError::PartialImport {
177                    completed,
178                    total,
179                    message: format!("user upsert failed: {err}"),
180                });
181            },
182        }
183    }
184
185    for context in &export.contexts {
186        match upsert_context(&pool, context).await {
187            Ok((c, u)) => {
188                created += c;
189                updated += u;
190                completed += 1;
191            },
192            Err(err) => {
193                return Err(crate::error::SyncError::PartialImport {
194                    completed,
195                    total,
196                    message: format!("context upsert failed: {err}"),
197                });
198            },
199        }
200    }
201
202    Ok(ImportResult {
203        created,
204        updated,
205        skipped: 0,
206    })
207}