Skip to main content

systemprompt_database/lifecycle/
installation.rs

1use super::migrations::MigrationService;
2use crate::services::{DatabaseProvider, SqlExecutor};
3use anyhow::Result;
4use std::path::Path;
5use systemprompt_extension::{Extension, ExtensionRegistry, LoaderError, SchemaSource, SeedSource};
6use systemprompt_models::modules::{Module, ModuleSchema};
7use tracing::{debug, info, warn};
8
9#[derive(Debug, Clone, Copy)]
10pub struct ModuleInstaller;
11
12impl ModuleInstaller {
13    pub async fn install(module: &Module, db: &dyn DatabaseProvider) -> Result<()> {
14        install_module_schemas_from_source(module, db).await?;
15        install_module_seeds_from_path(module, db).await?;
16        Ok(())
17    }
18}
19
20pub async fn install_module_schemas_from_source(
21    module: &Module,
22    db: &dyn DatabaseProvider,
23) -> Result<()> {
24    let Some(schemas) = &module.schemas else {
25        return Ok(());
26    };
27
28    for schema in schemas {
29        if schema.table.is_empty() {
30            let sql = read_module_schema_sql(module, schema)?;
31            SqlExecutor::execute_statements_parsed(db, &sql).await?;
32            continue;
33        }
34
35        if !table_exists(db, &schema.table).await? {
36            let sql = read_module_schema_sql(module, schema)?;
37            SqlExecutor::execute_statements_parsed(db, &sql).await?;
38        }
39    }
40
41    Ok(())
42}
43
44fn read_module_schema_sql(module: &Module, schema: &ModuleSchema) -> Result<String> {
45    match &schema.sql {
46        SchemaSource::Inline(sql) => Ok(sql.clone()),
47        SchemaSource::File(relative_path) => {
48            let full_path = module.path.join(relative_path);
49            std::fs::read_to_string(&full_path).map_err(|e| {
50                anyhow::anyhow!(
51                    "Failed to read schema file '{}' for module '{}': {e}",
52                    full_path.display(),
53                    module.name
54                )
55            })
56        },
57    }
58}
59
60pub async fn install_module_seeds_from_path(
61    module: &Module,
62    db: &dyn DatabaseProvider,
63) -> Result<()> {
64    let Some(seeds) = &module.seeds else {
65        return Ok(());
66    };
67
68    for seed in seeds {
69        let sql = match &seed.sql {
70            SeedSource::Inline(sql) => sql.clone(),
71            SeedSource::File(relative_path) => {
72                let seed_path = module.path.join(relative_path);
73                if !seed_path.exists() {
74                    anyhow::bail!(
75                        "Seed file not found for module '{}': {}",
76                        module.name,
77                        seed_path.display()
78                    );
79                }
80                std::fs::read_to_string(&seed_path)?
81            },
82        };
83        SqlExecutor::execute_statements_parsed(db, &sql).await?;
84    }
85
86    Ok(())
87}
88
89pub async fn install_schema(db: &dyn DatabaseProvider, schema_path: &Path) -> Result<()> {
90    let schema_content = std::fs::read_to_string(schema_path)?;
91    SqlExecutor::execute_statements_parsed(db, &schema_content).await
92}
93
94pub async fn install_seed(db: &dyn DatabaseProvider, seed_path: &Path) -> Result<()> {
95    let seed_content = std::fs::read_to_string(seed_path)?;
96    SqlExecutor::execute_statements_parsed(db, &seed_content).await
97}
98
99async fn table_exists(db: &dyn DatabaseProvider, table_name: &str) -> Result<bool> {
100    let result = db
101        .query_raw_with(
102            &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1) as exists",
103            vec![serde_json::Value::String(table_name.to_string())],
104        )
105        .await
106        .map_err(|e| {
107            tracing::error!(error = %e, table = %table_name, "Database error checking table existence");
108            anyhow::anyhow!("Database error checking table '{}': {}", table_name, e)
109        })?;
110
111    let exists = result
112        .rows
113        .first()
114        .and_then(|row| row.get("exists"))
115        .and_then(serde_json::Value::as_bool)
116        .unwrap_or(false);
117
118    Ok(exists)
119}
120
121pub async fn install_extension_schemas(
122    registry: &ExtensionRegistry,
123    db: &dyn DatabaseProvider,
124) -> std::result::Result<(), LoaderError> {
125    install_extension_schemas_with_config(registry, db, &[]).await
126}
127
128pub async fn install_extension_schemas_with_config(
129    registry: &ExtensionRegistry,
130    db: &dyn DatabaseProvider,
131    disabled_extensions: &[String],
132) -> std::result::Result<(), LoaderError> {
133    let schema_extensions = registry.enabled_schema_extensions(disabled_extensions);
134
135    if schema_extensions.is_empty() {
136        log_no_schemas();
137        return Ok(());
138    }
139
140    log_installing_schemas(schema_extensions.len());
141
142    let migration_service = MigrationService::new(db);
143
144    for ext in schema_extensions {
145        install_extension_schema(ext.as_ref(), db).await?;
146
147        if ext.has_migrations() {
148            debug!(
149                extension = %ext.id(),
150                "Running pending migrations"
151            );
152            migration_service
153                .run_pending_migrations(ext.as_ref())
154                .await?;
155        }
156    }
157
158    log_installation_complete();
159    Ok(())
160}
161
162fn log_no_schemas() {
163    info!("No extension schemas to install");
164}
165
166fn log_installing_schemas(count: usize) {
167    info!("Installing schemas for {} extensions", count);
168}
169
170fn log_installation_complete() {
171    info!("Extension schema installation complete");
172}
173
174async fn install_extension_schema(
175    ext: &dyn Extension,
176    db: &dyn DatabaseProvider,
177) -> std::result::Result<(), LoaderError> {
178    let schemas = ext.schemas();
179    let extension_id = ext.metadata().id.to_string();
180
181    if schemas.is_empty() {
182        return Ok(());
183    }
184
185    debug!(
186        "Installing {} schema(s) for extension '{}' (weight: {})",
187        schemas.len(),
188        extension_id,
189        ext.migration_weight()
190    );
191
192    for schema in &schemas {
193        install_single_schema(db, schema, &extension_id).await?;
194    }
195
196    Ok(())
197}
198
199async fn install_single_schema(
200    db: &dyn DatabaseProvider,
201    schema: &systemprompt_extension::SchemaDefinition,
202    extension_id: &str,
203) -> std::result::Result<(), LoaderError> {
204    if check_table_exists_for_extension(db, &schema.table, extension_id).await? {
205        debug!("  Table '{}' already exists, skipping", schema.table);
206        return Ok(());
207    }
208
209    let sql = read_schema_sql(schema, extension_id)?;
210    execute_schema_sql(db, &sql, &schema.table, extension_id).await?;
211
212    if !schema.required_columns.is_empty() {
213        validate_extension_columns(db, &schema.table, &schema.required_columns, extension_id)
214            .await?;
215    }
216
217    Ok(())
218}
219
220async fn check_table_exists_for_extension(
221    db: &dyn DatabaseProvider,
222    table: &str,
223    extension_id: &str,
224) -> std::result::Result<bool, LoaderError> {
225    table_exists(db, table)
226        .await
227        .map_err(|e| LoaderError::SchemaInstallationFailed {
228            extension: extension_id.to_string(),
229            message: format!("Failed to check table existence: {e}"),
230        })
231}
232
233fn read_schema_sql(
234    schema: &systemprompt_extension::SchemaDefinition,
235    extension_id: &str,
236) -> std::result::Result<String, LoaderError> {
237    match &schema.sql {
238        SchemaSource::Inline(sql) => Ok(sql.clone()),
239        SchemaSource::File(path) => {
240            std::fs::read_to_string(path).map_err(|e| LoaderError::SchemaInstallationFailed {
241                extension: extension_id.to_string(),
242                message: format!("Failed to read schema file '{}': {e}", path.display()),
243            })
244        },
245    }
246}
247
248async fn execute_schema_sql(
249    db: &dyn DatabaseProvider,
250    sql: &str,
251    table: &str,
252    extension_id: &str,
253) -> std::result::Result<(), LoaderError> {
254    debug!("  Creating table '{}'", table);
255    SqlExecutor::execute_statements_parsed(db, sql)
256        .await
257        .map_err(|e| LoaderError::SchemaInstallationFailed {
258            extension: extension_id.to_string(),
259            message: format!("Failed to create table '{}': {e}", table),
260        })
261}
262
263async fn validate_extension_columns(
264    db: &dyn DatabaseProvider,
265    table: &str,
266    required_columns: &[String],
267    extension_id: &str,
268) -> std::result::Result<(), LoaderError> {
269    for column in required_columns {
270        validate_single_column(db, table, column, extension_id).await?;
271    }
272    Ok(())
273}
274
275async fn validate_single_column(
276    db: &dyn DatabaseProvider,
277    table: &str,
278    column: &str,
279    extension_id: &str,
280) -> std::result::Result<(), LoaderError> {
281    let result = db
282        .query_raw_with(
283            &"SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND \
284              table_name = $1 AND column_name = $2",
285            vec![
286                serde_json::Value::String(table.to_string()),
287                serde_json::Value::String(column.to_string()),
288            ],
289        )
290        .await
291        .map_err(|e| LoaderError::SchemaInstallationFailed {
292            extension: extension_id.to_string(),
293            message: format!("Failed to validate column '{column}': {e}"),
294        })?;
295
296    if result.rows.is_empty() {
297        warn!(
298            "Extension '{}': Required column '{}' not found in table '{}'",
299            extension_id, column, table
300        );
301        return Err(LoaderError::SchemaInstallationFailed {
302            extension: extension_id.to_string(),
303            message: format!("Required column '{column}' not found in table '{table}'"),
304        });
305    }
306
307    Ok(())
308}