Skip to main content

systemprompt_sync/database/
mod.rs

1mod upsert;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::prelude::FromRow;
6use sqlx::PgPool;
7
8use crate::error::SyncResult;
9use crate::{SyncDirection, SyncOperationResult};
10
11use upsert::{upsert_context, upsert_skill, upsert_user};
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct DatabaseExport {
15    pub users: Vec<UserExport>,
16    pub skills: Vec<SkillExport>,
17    pub contexts: Vec<ContextExport>,
18    pub timestamp: DateTime<Utc>,
19}
20
21#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
22pub struct UserExport {
23    pub id: String,
24    pub name: String,
25    pub email: String,
26    pub full_name: Option<String>,
27    pub display_name: Option<String>,
28    pub status: String,
29    pub email_verified: bool,
30    pub roles: Vec<String>,
31    pub is_bot: bool,
32    pub is_scanner: bool,
33    pub avatar_url: Option<String>,
34    pub created_at: DateTime<Utc>,
35    pub updated_at: DateTime<Utc>,
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
39pub struct SkillExport {
40    pub skill_id: String,
41    pub file_path: String,
42    pub name: String,
43    pub description: String,
44    pub instructions: String,
45    pub enabled: bool,
46    pub tags: Option<Vec<String>>,
47    pub category_id: Option<String>,
48    pub source_id: String,
49    pub created_at: DateTime<Utc>,
50    pub updated_at: DateTime<Utc>,
51}
52
53#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
54pub struct ContextExport {
55    pub context_id: String,
56    pub user_id: String,
57    pub session_id: Option<String>,
58    pub name: String,
59    pub created_at: DateTime<Utc>,
60    pub updated_at: DateTime<Utc>,
61}
62
63#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
64pub struct ImportResult {
65    pub created: usize,
66    pub updated: usize,
67    pub skipped: usize,
68}
69
70#[derive(Debug)]
71pub struct DatabaseSyncService {
72    direction: SyncDirection,
73    dry_run: bool,
74    local_database_url: String,
75    cloud_database_url: String,
76}
77
78impl DatabaseSyncService {
79    pub fn new(
80        direction: SyncDirection,
81        dry_run: bool,
82        local_database_url: &str,
83        cloud_database_url: &str,
84    ) -> Self {
85        Self {
86            direction,
87            dry_run,
88            local_database_url: local_database_url.to_string(),
89            cloud_database_url: cloud_database_url.to_string(),
90        }
91    }
92
93    pub async fn sync(&self) -> SyncResult<SyncOperationResult> {
94        match self.direction {
95            SyncDirection::Push => self.push().await,
96            SyncDirection::Pull => self.pull().await,
97        }
98    }
99
100    async fn push(&self) -> SyncResult<SyncOperationResult> {
101        let export = export_from_database(&self.local_database_url).await?;
102        let count = export.users.len() + export.skills.len() + export.contexts.len();
103
104        if self.dry_run {
105            return Ok(SyncOperationResult::dry_run(
106                "database_push",
107                count,
108                serde_json::json!({
109                    "users": export.users.len(),
110                    "skills": export.skills.len(),
111                    "contexts": export.contexts.len(),
112                }),
113            ));
114        }
115
116        import_to_database(&self.cloud_database_url, &export).await?;
117        Ok(SyncOperationResult::success("database_push", count))
118    }
119
120    async fn pull(&self) -> SyncResult<SyncOperationResult> {
121        let export = export_from_database(&self.cloud_database_url).await?;
122        let count = export.users.len() + export.skills.len() + export.contexts.len();
123
124        if self.dry_run {
125            return Ok(SyncOperationResult::dry_run(
126                "database_pull",
127                count,
128                serde_json::json!({
129                    "users": export.users.len(),
130                    "skills": export.skills.len(),
131                    "contexts": export.contexts.len(),
132                }),
133            ));
134        }
135
136        import_to_database(&self.local_database_url, &export).await?;
137        Ok(SyncOperationResult::success("database_pull", count))
138    }
139}
140
141async fn export_from_database(database_url: &str) -> SyncResult<DatabaseExport> {
142    let pool = PgPool::connect(database_url).await?;
143
144    let users = sqlx::query_as!(
145        UserExport,
146        r#"SELECT id, name, email, full_name, display_name, status, email_verified,
147                  roles, is_bot, is_scanner, avatar_url, created_at, updated_at
148           FROM users"#
149    )
150    .fetch_all(&pool)
151    .await?;
152
153    let skills = sqlx::query_as!(
154        SkillExport,
155        r#"SELECT skill_id, file_path, name, description, instructions, enabled,
156                  tags, category_id, source_id, created_at, updated_at
157           FROM agent_skills"#
158    )
159    .fetch_all(&pool)
160    .await?;
161
162    let contexts = sqlx::query_as!(
163        ContextExport,
164        r#"SELECT context_id, user_id, session_id, name, created_at, updated_at
165           FROM user_contexts"#
166    )
167    .fetch_all(&pool)
168    .await?;
169
170    Ok(DatabaseExport {
171        users,
172        skills,
173        contexts,
174        timestamp: Utc::now(),
175    })
176}
177
178async fn import_to_database(
179    database_url: &str,
180    export: &DatabaseExport,
181) -> SyncResult<ImportResult> {
182    let pool = PgPool::connect(database_url).await?;
183    let mut created = 0;
184    let mut updated = 0;
185
186    for user in &export.users {
187        let (c, u) = upsert_user(&pool, user).await?;
188        created += c;
189        updated += u;
190    }
191
192    for skill in &export.skills {
193        let (c, u) = upsert_skill(&pool, skill).await?;
194        created += c;
195        updated += u;
196    }
197
198    for context in &export.contexts {
199        let (c, u) = upsert_context(&pool, context).await?;
200        created += c;
201        updated += u;
202    }
203
204    Ok(ImportResult {
205        created,
206        updated,
207        skipped: 0,
208    })
209}