systemprompt_database/lifecycle/
installation.rs1use super::migrations::MigrationService;
2use crate::services::{DatabaseProvider, SqlExecutor};
3use anyhow::Result;
4use std::path::Path;
5use systemprompt_extension::{Extension, ExtensionRegistry, LoaderError, SchemaSource, SeedSource};
6use systemprompt_models::modules::{Module, ModuleSchema};
7use tracing::{debug, info, warn};
8
9#[derive(Debug, Clone, Copy)]
10pub struct ModuleInstaller;
11
12impl ModuleInstaller {
13 pub async fn install(module: &Module, db: &dyn DatabaseProvider) -> Result<()> {
14 install_module_schemas_from_source(module, db).await?;
15 install_module_seeds_from_path(module, db).await?;
16 Ok(())
17 }
18}
19
20pub async fn install_module_schemas_from_source(
21 module: &Module,
22 db: &dyn DatabaseProvider,
23) -> Result<()> {
24 let Some(schemas) = &module.schemas else {
25 return Ok(());
26 };
27
28 for schema in schemas {
29 if schema.table.is_empty() {
30 let sql = read_module_schema_sql(module, schema)?;
31 SqlExecutor::execute_statements_parsed(db, &sql).await?;
32 continue;
33 }
34
35 if !table_exists(db, &schema.table).await? {
36 let sql = read_module_schema_sql(module, schema)?;
37 SqlExecutor::execute_statements_parsed(db, &sql).await?;
38 }
39 }
40
41 Ok(())
42}
43
44fn read_module_schema_sql(module: &Module, schema: &ModuleSchema) -> Result<String> {
45 match &schema.sql {
46 SchemaSource::Inline(sql) => Ok(sql.clone()),
47 SchemaSource::File(relative_path) => {
48 let full_path = module.path.join(relative_path);
49 std::fs::read_to_string(&full_path).map_err(|e| {
50 anyhow::anyhow!(
51 "Failed to read schema file '{}' for module '{}': {e}",
52 full_path.display(),
53 module.name
54 )
55 })
56 },
57 }
58}
59
60pub async fn install_module_seeds_from_path(
61 module: &Module,
62 db: &dyn DatabaseProvider,
63) -> Result<()> {
64 let Some(seeds) = &module.seeds else {
65 return Ok(());
66 };
67
68 for seed in seeds {
69 let sql = match &seed.sql {
70 SeedSource::Inline(sql) => sql.clone(),
71 SeedSource::File(relative_path) => {
72 let seed_path = module.path.join(relative_path);
73 if !seed_path.exists() {
74 anyhow::bail!(
75 "Seed file not found for module '{}': {}",
76 module.name,
77 seed_path.display()
78 );
79 }
80 std::fs::read_to_string(&seed_path)?
81 },
82 };
83 SqlExecutor::execute_statements_parsed(db, &sql).await?;
84 }
85
86 Ok(())
87}
88
89pub async fn install_schema(db: &dyn DatabaseProvider, schema_path: &Path) -> Result<()> {
90 let schema_content = std::fs::read_to_string(schema_path)?;
91 SqlExecutor::execute_statements_parsed(db, &schema_content).await
92}
93
94pub async fn install_seed(db: &dyn DatabaseProvider, seed_path: &Path) -> Result<()> {
95 let seed_content = std::fs::read_to_string(seed_path)?;
96 SqlExecutor::execute_statements_parsed(db, &seed_content).await
97}
98
99async fn table_exists(db: &dyn DatabaseProvider, table_name: &str) -> Result<bool> {
100 let result = db
101 .query_raw_with(
102 &"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1) as exists",
103 vec![serde_json::Value::String(table_name.to_string())],
104 )
105 .await
106 .map_err(|e| {
107 tracing::error!(error = %e, table = %table_name, "Database error checking table existence");
108 anyhow::anyhow!("Database error checking table '{}': {}", table_name, e)
109 })?;
110
111 let exists = result
112 .rows
113 .first()
114 .and_then(|row| row.get("exists"))
115 .and_then(serde_json::Value::as_bool)
116 .unwrap_or(false);
117
118 Ok(exists)
119}
120
121pub async fn install_extension_schemas(
122 registry: &ExtensionRegistry,
123 db: &dyn DatabaseProvider,
124) -> std::result::Result<(), LoaderError> {
125 install_extension_schemas_with_config(registry, db, &[]).await
126}
127
128pub async fn install_extension_schemas_with_config(
129 registry: &ExtensionRegistry,
130 db: &dyn DatabaseProvider,
131 disabled_extensions: &[String],
132) -> std::result::Result<(), LoaderError> {
133 let schema_extensions = registry.enabled_schema_extensions(disabled_extensions);
134
135 if schema_extensions.is_empty() {
136 log_no_schemas();
137 return Ok(());
138 }
139
140 log_installing_schemas(schema_extensions.len());
141
142 let migration_service = MigrationService::new(db);
143
144 for ext in schema_extensions {
145 install_extension_schema(ext.as_ref(), db).await?;
146
147 if ext.has_migrations() {
148 debug!(
149 extension = %ext.id(),
150 "Running pending migrations"
151 );
152 migration_service
153 .run_pending_migrations(ext.as_ref())
154 .await?;
155 }
156 }
157
158 log_installation_complete();
159 Ok(())
160}
161
162fn log_no_schemas() {
163 info!("No extension schemas to install");
164}
165
166fn log_installing_schemas(count: usize) {
167 info!("Installing schemas for {} extensions", count);
168}
169
170fn log_installation_complete() {
171 info!("Extension schema installation complete");
172}
173
174async fn install_extension_schema(
175 ext: &dyn Extension,
176 db: &dyn DatabaseProvider,
177) -> std::result::Result<(), LoaderError> {
178 let schemas = ext.schemas();
179 let extension_id = ext.metadata().id.to_string();
180
181 if schemas.is_empty() {
182 return Ok(());
183 }
184
185 debug!(
186 "Installing {} schema(s) for extension '{}' (weight: {})",
187 schemas.len(),
188 extension_id,
189 ext.migration_weight()
190 );
191
192 for schema in &schemas {
193 install_single_schema(db, schema, &extension_id).await?;
194 }
195
196 Ok(())
197}
198
199async fn install_single_schema(
200 db: &dyn DatabaseProvider,
201 schema: &systemprompt_extension::SchemaDefinition,
202 extension_id: &str,
203) -> std::result::Result<(), LoaderError> {
204 if check_table_exists_for_extension(db, &schema.table, extension_id).await? {
205 debug!(" Table '{}' already exists, skipping", schema.table);
206 return Ok(());
207 }
208
209 let sql = read_schema_sql(schema, extension_id)?;
210 execute_schema_sql(db, &sql, &schema.table, extension_id).await?;
211
212 if !schema.required_columns.is_empty() {
213 validate_extension_columns(db, &schema.table, &schema.required_columns, extension_id)
214 .await?;
215 }
216
217 Ok(())
218}
219
220async fn check_table_exists_for_extension(
221 db: &dyn DatabaseProvider,
222 table: &str,
223 extension_id: &str,
224) -> std::result::Result<bool, LoaderError> {
225 table_exists(db, table)
226 .await
227 .map_err(|e| LoaderError::SchemaInstallationFailed {
228 extension: extension_id.to_string(),
229 message: format!("Failed to check table existence: {e}"),
230 })
231}
232
233fn read_schema_sql(
234 schema: &systemprompt_extension::SchemaDefinition,
235 extension_id: &str,
236) -> std::result::Result<String, LoaderError> {
237 match &schema.sql {
238 SchemaSource::Inline(sql) => Ok(sql.clone()),
239 SchemaSource::File(path) => {
240 std::fs::read_to_string(path).map_err(|e| LoaderError::SchemaInstallationFailed {
241 extension: extension_id.to_string(),
242 message: format!("Failed to read schema file '{}': {e}", path.display()),
243 })
244 },
245 }
246}
247
248async fn execute_schema_sql(
249 db: &dyn DatabaseProvider,
250 sql: &str,
251 table: &str,
252 extension_id: &str,
253) -> std::result::Result<(), LoaderError> {
254 debug!(" Creating table '{}'", table);
255 SqlExecutor::execute_statements_parsed(db, sql)
256 .await
257 .map_err(|e| LoaderError::SchemaInstallationFailed {
258 extension: extension_id.to_string(),
259 message: format!("Failed to create table '{}': {e}", table),
260 })
261}
262
263async fn validate_extension_columns(
264 db: &dyn DatabaseProvider,
265 table: &str,
266 required_columns: &[String],
267 extension_id: &str,
268) -> std::result::Result<(), LoaderError> {
269 for column in required_columns {
270 validate_single_column(db, table, column, extension_id).await?;
271 }
272 Ok(())
273}
274
275async fn validate_single_column(
276 db: &dyn DatabaseProvider,
277 table: &str,
278 column: &str,
279 extension_id: &str,
280) -> std::result::Result<(), LoaderError> {
281 let result = db
282 .query_raw_with(
283 &"SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND \
284 table_name = $1 AND column_name = $2",
285 vec![
286 serde_json::Value::String(table.to_string()),
287 serde_json::Value::String(column.to_string()),
288 ],
289 )
290 .await
291 .map_err(|e| LoaderError::SchemaInstallationFailed {
292 extension: extension_id.to_string(),
293 message: format!("Failed to validate column '{column}': {e}"),
294 })?;
295
296 if result.rows.is_empty() {
297 warn!(
298 "Extension '{}': Required column '{}' not found in table '{}'",
299 extension_id, column, table
300 );
301 return Err(LoaderError::SchemaInstallationFailed {
302 extension: extension_id.to_string(),
303 message: format!("Required column '{column}' not found in table '{table}'"),
304 });
305 }
306
307 Ok(())
308}