Skip to main content

systemprompt_sync/database/
mod.rs

1mod upsert;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::PgPool;
6use sqlx::prelude::FromRow;
7use systemprompt_identifiers::{ContextId, SessionId, SkillId, SourceId, UserId};
8
9use crate::error::SyncResult;
10use crate::{SyncDirection, SyncOperationResult};
11
12use upsert::{upsert_context, upsert_skill, upsert_user};
13
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct DatabaseExport {
16    pub users: Vec<UserExport>,
17    pub skills: Vec<SkillExport>,
18    pub contexts: Vec<ContextExport>,
19    pub timestamp: DateTime<Utc>,
20}
21
22#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
23pub struct UserExport {
24    pub id: UserId,
25    pub name: String,
26    pub email: String,
27    pub full_name: Option<String>,
28    pub display_name: Option<String>,
29    pub status: String,
30    pub email_verified: bool,
31    pub roles: Vec<String>,
32    pub is_bot: bool,
33    pub is_scanner: bool,
34    pub avatar_url: Option<String>,
35    pub created_at: DateTime<Utc>,
36    pub updated_at: DateTime<Utc>,
37}
38
39#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
40pub struct SkillExport {
41    pub skill_id: SkillId,
42    pub file_path: String,
43    pub name: String,
44    pub description: String,
45    pub instructions: String,
46    pub enabled: bool,
47    pub tags: Option<Vec<String>>,
48    pub category_id: Option<String>,
49    pub source_id: SourceId,
50    pub created_at: DateTime<Utc>,
51    pub updated_at: DateTime<Utc>,
52}
53
54#[derive(Clone, Debug, Serialize, Deserialize, FromRow)]
55pub struct ContextExport {
56    pub context_id: ContextId,
57    pub user_id: UserId,
58    pub session_id: Option<SessionId>,
59    pub name: String,
60    pub created_at: DateTime<Utc>,
61    pub updated_at: DateTime<Utc>,
62}
63
64#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
65pub struct ImportResult {
66    pub created: usize,
67    pub updated: usize,
68    pub skipped: usize,
69}
70
71#[derive(Debug)]
72pub struct DatabaseSyncService {
73    direction: SyncDirection,
74    dry_run: bool,
75    local_database_url: String,
76    cloud_database_url: String,
77}
78
79impl DatabaseSyncService {
80    pub fn new(
81        direction: SyncDirection,
82        dry_run: bool,
83        local_database_url: &str,
84        cloud_database_url: &str,
85    ) -> Self {
86        Self {
87            direction,
88            dry_run,
89            local_database_url: local_database_url.to_string(),
90            cloud_database_url: cloud_database_url.to_string(),
91        }
92    }
93
94    pub async fn sync(&self) -> SyncResult<SyncOperationResult> {
95        match self.direction {
96            SyncDirection::Push => self.push().await,
97            SyncDirection::Pull => self.pull().await,
98        }
99    }
100
101    async fn push(&self) -> SyncResult<SyncOperationResult> {
102        let export = export_from_database(&self.local_database_url).await?;
103        let count = export.users.len() + export.skills.len() + export.contexts.len();
104
105        if self.dry_run {
106            return Ok(SyncOperationResult::dry_run(
107                "database_push",
108                count,
109                serde_json::json!({
110                    "users": export.users.len(),
111                    "skills": export.skills.len(),
112                    "contexts": export.contexts.len(),
113                }),
114            ));
115        }
116
117        import_to_database(&self.cloud_database_url, &export).await?;
118        Ok(SyncOperationResult::success("database_push", count))
119    }
120
121    async fn pull(&self) -> SyncResult<SyncOperationResult> {
122        let export = export_from_database(&self.cloud_database_url).await?;
123        let count = export.users.len() + export.skills.len() + export.contexts.len();
124
125        if self.dry_run {
126            return Ok(SyncOperationResult::dry_run(
127                "database_pull",
128                count,
129                serde_json::json!({
130                    "users": export.users.len(),
131                    "skills": export.skills.len(),
132                    "contexts": export.contexts.len(),
133                }),
134            ));
135        }
136
137        import_to_database(&self.local_database_url, &export).await?;
138        Ok(SyncOperationResult::success("database_pull", count))
139    }
140}
141
142async fn export_from_database(database_url: &str) -> SyncResult<DatabaseExport> {
143    let pool = PgPool::connect(database_url).await?;
144
145    let users = sqlx::query_as!(
146        UserExport,
147        r#"SELECT id, name, email, full_name, display_name, status, email_verified,
148                  roles, is_bot, is_scanner, avatar_url, created_at, updated_at
149           FROM users"#
150    )
151    .fetch_all(&pool)
152    .await?;
153
154    let skills = sqlx::query_as!(
155        SkillExport,
156        r#"SELECT skill_id as "skill_id!: SkillId",
157                  file_path, name, description, instructions, enabled,
158                  tags, category_id,
159                  source_id as "source_id!: SourceId",
160                  created_at, updated_at
161           FROM agent_skills"#
162    )
163    .fetch_all(&pool)
164    .await?;
165
166    let contexts = sqlx::query_as!(
167        ContextExport,
168        r#"SELECT context_id as "context_id!: ContextId",
169                  user_id as "user_id!: UserId",
170                  session_id as "session_id: SessionId",
171                  name, created_at, updated_at
172           FROM user_contexts"#
173    )
174    .fetch_all(&pool)
175    .await?;
176
177    Ok(DatabaseExport {
178        users,
179        skills,
180        contexts,
181        timestamp: Utc::now(),
182    })
183}
184
185async fn import_to_database(
186    database_url: &str,
187    export: &DatabaseExport,
188) -> SyncResult<ImportResult> {
189    let pool = PgPool::connect(database_url).await?;
190    let mut created = 0;
191    let mut updated = 0;
192
193    for user in &export.users {
194        let (c, u) = upsert_user(&pool, user).await?;
195        created += c;
196        updated += u;
197    }
198
199    for skill in &export.skills {
200        let (c, u) = upsert_skill(&pool, skill).await?;
201        created += c;
202        updated += u;
203    }
204
205    for context in &export.contexts {
206        let (c, u) = upsert_context(&pool, context).await?;
207        created += c;
208        updated += u;
209    }
210
211    Ok(ImportResult {
212        created,
213        updated,
214        skipped: 0,
215    })
216}