systemprompt_database/lifecycle/
installation.rs1use super::migrations::MigrationService;
2use crate::services::{DatabaseProvider, SqlExecutor};
3use anyhow::Result;
4use std::path::Path;
5use systemprompt_extension::{Extension, ExtensionRegistry, LoaderError, SchemaSource, SeedSource};
6use systemprompt_models::modules::{Module, ModuleSchema};
7use tracing::{debug, info, warn};
8
9#[derive(Debug, Clone, Copy)]
10pub struct ModuleInstaller;
11
12impl ModuleInstaller {
13 pub async fn install(module: &Module, db: &dyn DatabaseProvider) -> Result<()> {
14 install_module_schemas_from_source(module, db).await?;
15 install_module_seeds_from_path(module, db).await?;
16 Ok(())
17 }
18}
19
20pub async fn install_module_schemas_from_source(
21 module: &Module,
22 db: &dyn DatabaseProvider,
23) -> Result<()> {
24 let Some(schemas) = &module.schemas else {
25 return Ok(());
26 };
27
28 for schema in schemas {
29 if schema.table.is_empty() {
30 let sql = read_module_schema_sql(module, schema)?;
31 SqlExecutor::execute_statements_parsed(db, &sql).await?;
32 continue;
33 }
34
35 if !table_exists(db, &schema.table).await? {
36 let sql = read_module_schema_sql(module, schema)?;
37 SqlExecutor::execute_statements_parsed(db, &sql).await?;
38 }
39 }
40
41 Ok(())
42}
43
44fn read_module_schema_sql(module: &Module, schema: &ModuleSchema) -> Result<String> {
45 match &schema.sql {
46 SchemaSource::Inline(sql) => Ok(sql.clone()),
47 SchemaSource::File(relative_path) => {
48 let full_path = module.path.join(relative_path);
49 std::fs::read_to_string(&full_path).map_err(|e| {
50 anyhow::anyhow!(
51 "Failed to read schema file '{}' for module '{}': {e}",
52 full_path.display(),
53 module.name
54 )
55 })
56 },
57 }
58}
59
60pub async fn install_module_seeds_from_path(
61 module: &Module,
62 db: &dyn DatabaseProvider,
63) -> Result<()> {
64 let Some(seeds) = &module.seeds else {
65 return Ok(());
66 };
67
68 for seed in seeds {
69 let sql = match &seed.sql {
70 SeedSource::Inline(sql) => sql.clone(),
71 SeedSource::File(relative_path) => {
72 let seed_path = module.path.join(relative_path);
73 if !seed_path.exists() {
74 anyhow::bail!(
75 "Seed file not found for module '{}': {}",
76 module.name,
77 seed_path.display()
78 );
79 }
80 std::fs::read_to_string(&seed_path)?
81 },
82 };
83 SqlExecutor::execute_statements_parsed(db, &sql).await?;
84 }
85
86 Ok(())
87}
88
89pub async fn install_schema(db: &dyn DatabaseProvider, schema_path: &Path) -> Result<()> {
90 let schema_content = std::fs::read_to_string(schema_path)?;
91 SqlExecutor::execute_statements_parsed(db, &schema_content).await
92}
93
94pub async fn install_seed(db: &dyn DatabaseProvider, seed_path: &Path) -> Result<()> {
95 let seed_content = std::fs::read_to_string(seed_path)?;
96 SqlExecutor::execute_statements_parsed(db, &seed_content).await
97}
98
99async fn table_exists(db: &dyn DatabaseProvider, table_name: &str) -> Result<bool> {
100 let result = db
101 .query_raw_with(
102 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1) as exists",
103 vec![serde_json::Value::String(table_name.to_string())],
104 )
105 .await
106 .map_err(|e| {
107 tracing::error!(error = %e, table = %table_name, "Database error checking table existence");
108 anyhow::anyhow!("Database error checking table '{}': {}", table_name, e)
109 })?;
110
111 let exists = result
112 .rows
113 .first()
114 .and_then(|row| row.get("exists"))
115 .and_then(serde_json::Value::as_bool)
116 .unwrap_or(false);
117
118 Ok(exists)
119}
120
121pub async fn install_extension_schemas(
122 registry: &ExtensionRegistry,
123 db: &dyn DatabaseProvider,
124) -> std::result::Result<(), LoaderError> {
125 install_extension_schemas_with_config(registry, db, &[]).await
126}
127
128pub async fn install_extension_schemas_with_config(
129 registry: &ExtensionRegistry,
130 db: &dyn DatabaseProvider,
131 disabled_extensions: &[String],
132) -> std::result::Result<(), LoaderError> {
133 let schema_extensions = registry.enabled_schema_extensions(disabled_extensions);
134
135 if schema_extensions.is_empty() {
136 log_no_schemas();
137 return Ok(());
138 }
139
140 log_installing_schemas(schema_extensions.len());
141
142 let migration_service = MigrationService::new(db);
143
144 for ext in schema_extensions {
145 install_extension_schema(ext.as_ref(), db).await?;
146
147 if ext.has_migrations() {
148 debug!(
149 extension = %ext.id(),
150 "Running pending migrations"
151 );
152 migration_service
153 .run_pending_migrations(ext.as_ref())
154 .await?;
155 }
156 }
157
158 log_installation_complete();
159 Ok(())
160}
161
162fn log_no_schemas() {
163 info!("No extension schemas to install");
164}
165
166fn log_installing_schemas(count: usize) {
167 info!("Installing schemas for {} extensions", count);
168}
169
170fn log_installation_complete() {
171 info!("Extension schema installation complete");
172}
173
174async fn install_extension_schema(
175 ext: &dyn Extension,
176 db: &dyn DatabaseProvider,
177) -> std::result::Result<(), LoaderError> {
178 let schemas = ext.schemas();
179 let extension_id = ext.metadata().id.to_string();
180
181 if schemas.is_empty() {
182 return Ok(());
183 }
184
185 debug!(
186 "Installing {} schema(s) for extension '{}' (weight: {})",
187 schemas.len(),
188 extension_id,
189 ext.migration_weight()
190 );
191
192 let mut all_sql = Vec::new();
193 let mut schemas_to_validate = Vec::new();
194
195 for schema in &schemas {
196 if !schema.table.is_empty()
197 && check_table_exists_for_extension(db, &schema.table, &extension_id).await?
198 {
199 debug!(" Table '{}' already exists, skipping", schema.table);
200 continue;
201 }
202
203 let sql = read_schema_sql(schema, &extension_id)?;
204 all_sql.push(sql);
205
206 if !schema.required_columns.is_empty() {
207 schemas_to_validate.push(schema);
208 }
209 }
210
211 if all_sql.is_empty() {
212 return Ok(());
213 }
214
215 let combined = all_sql.join("\n");
216 let statements = SqlExecutor::parse_sql_statements(&combined);
217
218 if !statements.is_empty() {
219 let batch = statements.join("\n");
220 if let Err(batch_err) = db.execute_raw(&batch).await {
221 debug!(
222 extension = %extension_id,
223 error = %batch_err,
224 "Batch execution failed, falling back to per-statement execution"
225 );
226 for statement in &statements {
227 db.execute_raw(statement).await.map_err(|e| {
228 LoaderError::SchemaInstallationFailed {
229 extension: extension_id.clone(),
230 message: format!("Failed to execute SQL statement: {e}\n{statement}"),
231 }
232 })?;
233 }
234 }
235 }
236
237 for schema in schemas_to_validate {
238 validate_extension_columns(db, &schema.table, &schema.required_columns, &extension_id)
239 .await?;
240 }
241
242 Ok(())
243}
244
245
246async fn check_table_exists_for_extension(
247 db: &dyn DatabaseProvider,
248 table: &str,
249 extension_id: &str,
250) -> std::result::Result<bool, LoaderError> {
251 table_exists(db, table)
252 .await
253 .map_err(|e| LoaderError::SchemaInstallationFailed {
254 extension: extension_id.to_string(),
255 message: format!("Failed to check table existence: {e}"),
256 })
257}
258
259fn read_schema_sql(
260 schema: &systemprompt_extension::SchemaDefinition,
261 extension_id: &str,
262) -> std::result::Result<String, LoaderError> {
263 match &schema.sql {
264 SchemaSource::Inline(sql) => Ok(sql.clone()),
265 SchemaSource::File(path) => {
266 std::fs::read_to_string(path).map_err(|e| LoaderError::SchemaInstallationFailed {
267 extension: extension_id.to_string(),
268 message: format!("Failed to read schema file '{}': {e}", path.display()),
269 })
270 },
271 }
272}
273
274
275async fn validate_extension_columns(
276 db: &dyn DatabaseProvider,
277 table: &str,
278 required_columns: &[String],
279 extension_id: &str,
280) -> std::result::Result<(), LoaderError> {
281 for column in required_columns {
282 validate_single_column(db, table, column, extension_id).await?;
283 }
284 Ok(())
285}
286
287async fn validate_single_column(
288 db: &dyn DatabaseProvider,
289 table: &str,
290 column: &str,
291 extension_id: &str,
292) -> std::result::Result<(), LoaderError> {
293 let result = db
294 .query_raw_with(
295 &"SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND \
296 table_name = $1 AND column_name = $2",
297 vec![
298 serde_json::Value::String(table.to_string()),
299 serde_json::Value::String(column.to_string()),
300 ],
301 )
302 .await
303 .map_err(|e| LoaderError::SchemaInstallationFailed {
304 extension: extension_id.to_string(),
305 message: format!("Failed to validate column '{column}': {e}"),
306 })?;
307
308 if result.rows.is_empty() {
309 warn!(
310 "Extension '{}': Required column '{}' not found in table '{}'",
311 extension_id, column, table
312 );
313 return Err(LoaderError::SchemaInstallationFailed {
314 extension: extension_id.to_string(),
315 message: format!("Required column '{column}' not found in table '{table}'"),
316 });
317 }
318
319 Ok(())
320}