Skip to main content

systemprompt_database/lifecycle/
installation.rs

1use super::migrations::MigrationService;
2use crate::services::{DatabaseProvider, SqlExecutor};
3use anyhow::Result;
4use std::path::Path;
5use systemprompt_extension::{Extension, ExtensionRegistry, LoaderError, SchemaSource, SeedSource};
6use systemprompt_models::modules::{Module, ModuleSchema};
7use tracing::{debug, info, warn};
8
9#[derive(Debug, Clone, Copy)]
10pub struct ModuleInstaller;
11
12impl ModuleInstaller {
13    pub async fn install(module: &Module, db: &dyn DatabaseProvider) -> Result<()> {
14        install_module_schemas_from_source(module, db).await?;
15        install_module_seeds_from_path(module, db).await?;
16        Ok(())
17    }
18}
19
20pub async fn install_module_schemas_from_source(
21    module: &Module,
22    db: &dyn DatabaseProvider,
23) -> Result<()> {
24    let Some(schemas) = &module.schemas else {
25        return Ok(());
26    };
27
28    for schema in schemas {
29        if schema.table.is_empty() {
30            let sql = read_module_schema_sql(module, schema)?;
31            SqlExecutor::execute_statements_parsed(db, &sql).await?;
32            continue;
33        }
34
35        if !table_exists(db, &schema.table).await? {
36            let sql = read_module_schema_sql(module, schema)?;
37            SqlExecutor::execute_statements_parsed(db, &sql).await?;
38        }
39    }
40
41    Ok(())
42}
43
44fn read_module_schema_sql(module: &Module, schema: &ModuleSchema) -> Result<String> {
45    match &schema.sql {
46        SchemaSource::Inline(sql) => Ok(sql.clone()),
47        SchemaSource::File(relative_path) => {
48            let full_path = module.path.join(relative_path);
49            std::fs::read_to_string(&full_path).map_err(|e| {
50                anyhow::anyhow!(
51                    "Failed to read schema file '{}' for module '{}': {e}",
52                    full_path.display(),
53                    module.name
54                )
55            })
56        },
57    }
58}
59
60pub async fn install_module_seeds_from_path(
61    module: &Module,
62    db: &dyn DatabaseProvider,
63) -> Result<()> {
64    let Some(seeds) = &module.seeds else {
65        return Ok(());
66    };
67
68    for seed in seeds {
69        let sql = match &seed.sql {
70            SeedSource::Inline(sql) => sql.clone(),
71            SeedSource::File(relative_path) => {
72                let seed_path = module.path.join(relative_path);
73                if !seed_path.exists() {
74                    anyhow::bail!(
75                        "Seed file not found for module '{}': {}",
76                        module.name,
77                        seed_path.display()
78                    );
79                }
80                std::fs::read_to_string(&seed_path)?
81            },
82        };
83        SqlExecutor::execute_statements_parsed(db, &sql).await?;
84    }
85
86    Ok(())
87}
88
89pub async fn install_schema(db: &dyn DatabaseProvider, schema_path: &Path) -> Result<()> {
90    let schema_content = std::fs::read_to_string(schema_path)?;
91    SqlExecutor::execute_statements_parsed(db, &schema_content).await
92}
93
94pub async fn install_seed(db: &dyn DatabaseProvider, seed_path: &Path) -> Result<()> {
95    let seed_content = std::fs::read_to_string(seed_path)?;
96    SqlExecutor::execute_statements_parsed(db, &seed_content).await
97}
98
99async fn table_exists(db: &dyn DatabaseProvider, table_name: &str) -> Result<bool> {
100    let result = db
101        .query_raw_with(
102            &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1) as exists",
103            vec![serde_json::Value::String(table_name.to_string())],
104        )
105        .await
106        .map_err(|e| {
107            tracing::error!(error = %e, table = %table_name, "Database error checking table existence");
108            anyhow::anyhow!("Database error checking table '{}': {}", table_name, e)
109        })?;
110
111    let exists = result
112        .rows
113        .first()
114        .and_then(|row| row.get("exists"))
115        .and_then(serde_json::Value::as_bool)
116        .unwrap_or(false);
117
118    Ok(exists)
119}
120
121pub async fn install_extension_schemas(
122    registry: &ExtensionRegistry,
123    db: &dyn DatabaseProvider,
124) -> std::result::Result<(), LoaderError> {
125    install_extension_schemas_with_config(registry, db, &[]).await
126}
127
128pub async fn install_extension_schemas_with_config(
129    registry: &ExtensionRegistry,
130    db: &dyn DatabaseProvider,
131    disabled_extensions: &[String],
132) -> std::result::Result<(), LoaderError> {
133    let schema_extensions = registry.enabled_schema_extensions(disabled_extensions);
134
135    if schema_extensions.is_empty() {
136        log_no_schemas();
137        return Ok(());
138    }
139
140    log_installing_schemas(schema_extensions.len());
141
142    let migration_service = MigrationService::new(db);
143
144    for ext in schema_extensions {
145        install_extension_schema(ext.as_ref(), db).await?;
146
147        if ext.has_migrations() {
148            debug!(
149                extension = %ext.id(),
150                "Running pending migrations"
151            );
152            migration_service
153                .run_pending_migrations(ext.as_ref())
154                .await?;
155        }
156    }
157
158    log_installation_complete();
159    Ok(())
160}
161
162fn log_no_schemas() {
163    info!("No extension schemas to install");
164}
165
166fn log_installing_schemas(count: usize) {
167    info!("Installing schemas for {} extensions", count);
168}
169
170fn log_installation_complete() {
171    info!("Extension schema installation complete");
172}
173
174async fn install_extension_schema(
175    ext: &dyn Extension,
176    db: &dyn DatabaseProvider,
177) -> std::result::Result<(), LoaderError> {
178    let schemas = ext.schemas();
179    let extension_id = ext.metadata().id.to_string();
180
181    if schemas.is_empty() {
182        return Ok(());
183    }
184
185    debug!(
186        "Installing {} schema(s) for extension '{}' (weight: {})",
187        schemas.len(),
188        extension_id,
189        ext.migration_weight()
190    );
191
192    let mut all_sql = Vec::new();
193    let mut schemas_to_validate = Vec::new();
194
195    for schema in &schemas {
196        if !schema.table.is_empty()
197            && check_table_exists_for_extension(db, &schema.table, &extension_id).await?
198        {
199            debug!("  Table '{}' already exists, skipping", schema.table);
200            continue;
201        }
202
203        let sql = read_schema_sql(schema, &extension_id)?;
204        all_sql.push(sql);
205
206        if !schema.required_columns.is_empty() {
207            schemas_to_validate.push(schema);
208        }
209    }
210
211    if all_sql.is_empty() {
212        return Ok(());
213    }
214
215    let combined = all_sql.join("\n");
216    let statements = SqlExecutor::parse_sql_statements(&combined);
217
218    if !statements.is_empty() {
219        let batch = statements.join("\n");
220        if let Err(batch_err) = db.execute_raw(&batch).await {
221            debug!(
222                extension = %extension_id,
223                error = %batch_err,
224                "Batch execution failed, falling back to per-statement execution"
225            );
226            for statement in &statements {
227                db.execute_raw(statement).await.map_err(|e| {
228                    LoaderError::SchemaInstallationFailed {
229                        extension: extension_id.clone(),
230                        message: format!("Failed to execute SQL statement: {e}\n{statement}"),
231                    }
232                })?;
233            }
234        }
235    }
236
237    for schema in schemas_to_validate {
238        validate_extension_columns(db, &schema.table, &schema.required_columns, &extension_id)
239            .await?;
240    }
241
242    Ok(())
243}
244
245
246async fn check_table_exists_for_extension(
247    db: &dyn DatabaseProvider,
248    table: &str,
249    extension_id: &str,
250) -> std::result::Result<bool, LoaderError> {
251    table_exists(db, table)
252        .await
253        .map_err(|e| LoaderError::SchemaInstallationFailed {
254            extension: extension_id.to_string(),
255            message: format!("Failed to check table existence: {e}"),
256        })
257}
258
259fn read_schema_sql(
260    schema: &systemprompt_extension::SchemaDefinition,
261    extension_id: &str,
262) -> std::result::Result<String, LoaderError> {
263    match &schema.sql {
264        SchemaSource::Inline(sql) => Ok(sql.clone()),
265        SchemaSource::File(path) => {
266            std::fs::read_to_string(path).map_err(|e| LoaderError::SchemaInstallationFailed {
267                extension: extension_id.to_string(),
268                message: format!("Failed to read schema file '{}': {e}", path.display()),
269            })
270        },
271    }
272}
273
274
275async fn validate_extension_columns(
276    db: &dyn DatabaseProvider,
277    table: &str,
278    required_columns: &[String],
279    extension_id: &str,
280) -> std::result::Result<(), LoaderError> {
281    for column in required_columns {
282        validate_single_column(db, table, column, extension_id).await?;
283    }
284    Ok(())
285}
286
287async fn validate_single_column(
288    db: &dyn DatabaseProvider,
289    table: &str,
290    column: &str,
291    extension_id: &str,
292) -> std::result::Result<(), LoaderError> {
293    let result = db
294        .query_raw_with(
295            &"SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND \
296              table_name = $1 AND column_name = $2",
297            vec![
298                serde_json::Value::String(table.to_string()),
299                serde_json::Value::String(column.to_string()),
300            ],
301        )
302        .await
303        .map_err(|e| LoaderError::SchemaInstallationFailed {
304            extension: extension_id.to_string(),
305            message: format!("Failed to validate column '{column}': {e}"),
306        })?;
307
308    if result.rows.is_empty() {
309        warn!(
310            "Extension '{}': Required column '{}' not found in table '{}'",
311            extension_id, column, table
312        );
313        return Err(LoaderError::SchemaInstallationFailed {
314            extension: extension_id.to_string(),
315            message: format!("Required column '{column}' not found in table '{table}'"),
316        });
317    }
318
319    Ok(())
320}