// systemprompt_sync/database/mod.rs

//! Database push / pull: serialise the user, skill, and context tables to
//! JSON and round-trip them between a local Postgres and a cloud Postgres
//! using compile-time-checked `sqlx` upserts.

mod upsert;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::prelude::FromRow;
use systemprompt_identifiers::{ContextId, SessionId, SkillId, SourceId, UserId};

use crate::error::SyncResult;
use crate::{SyncDirection, SyncOperationResult};

use upsert::{upsert_context, upsert_skill, upsert_user};
17
/// Snapshot of the three syncable tables plus the moment it was taken.
///
/// Serialised to JSON (via `serde`) when moved between the local and cloud
/// databases by the push/pull operations in this module.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DatabaseExport {
    /// Rows exported from the `users` table.
    pub users: Vec<UserExport>,
    /// Rows exported from the `agent_skills` table.
    pub skills: Vec<SkillExport>,
    /// Rows exported from the `user_contexts` table.
    pub contexts: Vec<ContextExport>,
    /// When the export was produced (`Utc::now()` at export time).
    pub timestamp: DateTime<Utc>,
}

/// One row of the `users` table, as selected by `export_from_database`.
#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
pub struct UserExport {
    pub id: UserId,
    pub name: String,
    pub email: String,
    pub full_name: Option<String>,
    pub display_name: Option<String>,
    // Free-form status string — NOTE(review): allowed values not visible here.
    pub status: String,
    pub email_verified: bool,
    pub roles: Vec<String>,
    pub is_bot: bool,
    pub is_scanner: bool,
    pub avatar_url: Option<String>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

/// One row of the `agent_skills` table, as selected by `export_from_database`.
#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
pub struct SkillExport {
    pub skill_id: SkillId,
    pub file_path: String,
    pub name: String,
    pub description: String,
    pub instructions: String,
    pub enabled: bool,
    pub tags: Option<Vec<String>>,
    pub category_id: Option<String>,
    pub source_id: SourceId,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

/// One row of the `user_contexts` table, as selected by `export_from_database`.
#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
pub struct ContextExport {
    pub context_id: ContextId,
    pub user_id: UserId,
    // Nullable in the database, hence `Option`.
    pub session_id: Option<SessionId>,
    pub name: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

/// Tally of what `import_to_database` did with the exported records.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct ImportResult {
    /// Records newly inserted by an upsert.
    pub created: usize,
    /// Records that already existed and were updated by an upsert.
    pub updated: usize,
    /// Always 0 in the current implementation; reserved for future skip logic.
    pub skipped: usize,
}

/// Orchestrates a one-shot push or pull between a local and a cloud Postgres.
#[derive(Debug)]
pub struct DatabaseSyncService {
    /// Push (local → cloud) or pull (cloud → local).
    direction: SyncDirection,
    /// When true, `sync` reports counts without writing to the target.
    dry_run: bool,
    local_database_url: String,
    cloud_database_url: String,
}

83impl DatabaseSyncService {
84    pub fn new(
85        direction: SyncDirection,
86        dry_run: bool,
87        local_database_url: &str,
88        cloud_database_url: &str,
89    ) -> Self {
90        Self {
91            direction,
92            dry_run,
93            local_database_url: local_database_url.to_string(),
94            cloud_database_url: cloud_database_url.to_string(),
95        }
96    }
97
98    pub async fn sync(&self) -> SyncResult<SyncOperationResult> {
99        match self.direction {
100            SyncDirection::Push => self.push().await,
101            SyncDirection::Pull => self.pull().await,
102        }
103    }
104
105    async fn push(&self) -> SyncResult<SyncOperationResult> {
106        let export = export_from_database(&self.local_database_url).await?;
107        let count = export.users.len() + export.skills.len() + export.contexts.len();
108
109        if self.dry_run {
110            return Ok(SyncOperationResult::dry_run(
111                "database_push",
112                count,
113                serde_json::json!({
114                    "users": export.users.len(),
115                    "skills": export.skills.len(),
116                    "contexts": export.contexts.len(),
117                }),
118            ));
119        }
120
121        import_to_database(&self.cloud_database_url, &export).await?;
122        Ok(SyncOperationResult::success("database_push", count))
123    }
124
125    async fn pull(&self) -> SyncResult<SyncOperationResult> {
126        let export = export_from_database(&self.cloud_database_url).await?;
127        let count = export.users.len() + export.skills.len() + export.contexts.len();
128
129        if self.dry_run {
130            return Ok(SyncOperationResult::dry_run(
131                "database_pull",
132                count,
133                serde_json::json!({
134                    "users": export.users.len(),
135                    "skills": export.skills.len(),
136                    "contexts": export.contexts.len(),
137                }),
138            ));
139        }
140
141        import_to_database(&self.local_database_url, &export).await?;
142        Ok(SyncOperationResult::success("database_pull", count))
143    }
144}
145
/// Connect to `database_url` and snapshot the `users`, `agent_skills`, and
/// `user_contexts` tables into a [`DatabaseExport`] stamped with the
/// current UTC time.
///
/// The queries are compile-time checked by `sqlx`; the `as "col!: Type"`
/// annotations override the inferred column types so raw database columns
/// deserialise into the newtype identifiers (`SkillId`, `UserId`, …).
///
/// # Errors
/// Fails if the connection cannot be established or any query errors.
async fn export_from_database(database_url: &str) -> SyncResult<DatabaseExport> {
    // NOTE(review): a fresh pool per call is fine for a one-shot sync, but
    // consider accepting a shared pool if this ever runs repeatedly.
    let pool = PgPool::connect(database_url).await?;

    let users = sqlx::query_as!(
        UserExport,
        r#"SELECT id, name, email, full_name, display_name, status, email_verified,
                  roles, is_bot, is_scanner, avatar_url, created_at, updated_at
           FROM users"#
    )
    .fetch_all(&pool)
    .await?;

    let skills = sqlx::query_as!(
        SkillExport,
        r#"SELECT skill_id as "skill_id!: SkillId",
                  file_path, name, description, instructions, enabled,
                  tags, category_id,
                  source_id as "source_id!: SourceId",
                  created_at, updated_at
           FROM agent_skills"#
    )
    .fetch_all(&pool)
    .await?;

    let contexts = sqlx::query_as!(
        ContextExport,
        r#"SELECT context_id as "context_id!: ContextId",
                  user_id as "user_id!: UserId",
                  session_id as "session_id: SessionId",
                  name, created_at, updated_at
           FROM user_contexts"#
    )
    .fetch_all(&pool)
    .await?;

    Ok(DatabaseExport {
        users,
        skills,
        contexts,
        timestamp: Utc::now(),
    })
}

189async fn import_to_database(
190    database_url: &str,
191    export: &DatabaseExport,
192) -> SyncResult<ImportResult> {
193    let pool = PgPool::connect(database_url).await?;
194    let mut created = 0;
195    let mut updated = 0;
196    let mut completed = 0usize;
197    let total = export.users.len() + export.skills.len() + export.contexts.len();
198
199    for user in &export.users {
200        match upsert_user(&pool, user).await {
201            Ok((c, u)) => {
202                created += c;
203                updated += u;
204                completed += 1;
205            },
206            Err(err) => {
207                return Err(crate::error::SyncError::PartialImport {
208                    completed,
209                    total,
210                    message: format!("user upsert failed: {err}"),
211                });
212            },
213        }
214    }
215
216    for skill in &export.skills {
217        match upsert_skill(&pool, skill).await {
218            Ok((c, u)) => {
219                created += c;
220                updated += u;
221                completed += 1;
222            },
223            Err(err) => {
224                return Err(crate::error::SyncError::PartialImport {
225                    completed,
226                    total,
227                    message: format!("skill upsert failed: {err}"),
228                });
229            },
230        }
231    }
232
233    for context in &export.contexts {
234        match upsert_context(&pool, context).await {
235            Ok((c, u)) => {
236                created += c;
237                updated += u;
238                completed += 1;
239            },
240            Err(err) => {
241                return Err(crate::error::SyncError::PartialImport {
242                    completed,
243                    total,
244                    message: format!("context upsert failed: {err}"),
245                });
246            },
247        }
248    }
249
250    Ok(ImportResult {
251        created,
252        updated,
253        skipped: 0,
254    })
255}