Skip to main content

waypoint_core/
multi.rs

1//! Multi-database orchestration.
2//!
3//! Allows managing migrations across multiple named databases
4//! with dependency ordering between them.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::Serialize;
9use tokio_postgres::Client;
10
11use crate::config::{DatabaseConfig, HooksConfig, MigrationSettings, WaypointConfig};
12use crate::error::{Result, WaypointError};
13
14/// Configuration for a single named database within a multi-db setup.
15#[derive(Debug, Clone)]
16pub struct NamedDatabaseConfig {
17    /// Unique logical name identifying this database.
18    pub name: String,
19    /// Database connection configuration.
20    pub database: DatabaseConfig,
21    /// Migration settings for this database.
22    pub migrations: MigrationSettings,
23    /// Hook configuration for this database.
24    pub hooks: HooksConfig,
25    /// Placeholder key-value pairs for SQL template substitution.
26    pub placeholders: HashMap<String, String>,
27    /// Names of other databases that must be migrated before this one.
28    pub depends_on: Vec<String>,
29}
30
31impl NamedDatabaseConfig {
32    /// Convert to a standalone WaypointConfig for running commands.
33    pub fn to_waypoint_config(&self) -> WaypointConfig {
34        WaypointConfig {
35            database: self.database.clone(),
36            migrations: self.migrations.clone(),
37            hooks: self.hooks.clone(),
38            placeholders: self.placeholders.clone(),
39            ..WaypointConfig::default()
40        }
41    }
42}
43
44/// Multi-database orchestration entry point.
45pub struct MultiWaypoint {
46    /// List of all database configurations to orchestrate.
47    pub databases: Vec<NamedDatabaseConfig>,
48}
49
50/// Result from a multi-db operation on a single database.
51#[derive(Debug, Serialize)]
52pub struct DatabaseResult {
53    /// Logical name of the database.
54    pub name: String,
55    /// Whether the operation succeeded on this database.
56    pub success: bool,
57    /// Human-readable summary of the operation result.
58    pub message: String,
59}
60
61/// Aggregate result from a multi-db operation.
62#[derive(Debug, Serialize)]
63pub struct MultiResult {
64    /// Per-database operation results.
65    pub results: Vec<DatabaseResult>,
66    /// Whether every database operation succeeded.
67    pub all_succeeded: bool,
68}
69
70impl MultiWaypoint {
71    /// Determine execution order based on depends_on relationships (Kahn's algorithm).
72    ///
73    /// Uses borrowed `&str` references internally to avoid cloning database names
74    /// during the topological sort; only clones into owned `String`s for the output.
75    pub fn execution_order(databases: &[NamedDatabaseConfig]) -> Result<Vec<String>> {
76        let all_names: HashSet<&str> = databases.iter().map(|d| d.name.as_str()).collect();
77
78        // Build in-degree map using borrowed names
79        let mut in_degree: HashMap<&str, usize> = HashMap::new();
80        let mut reverse_edges: HashMap<&str, Vec<&str>> = HashMap::new();
81
82        for db in databases {
83            in_degree.entry(db.name.as_str()).or_insert(0);
84            for dep in &db.depends_on {
85                if !all_names.contains(dep.as_str()) {
86                    return Err(WaypointError::DatabaseNotFound {
87                        name: dep.clone(),
88                        available: all_names.iter().copied().collect::<Vec<_>>().join(", "),
89                    });
90                }
91                *in_degree.entry(db.name.as_str()).or_insert(0) += 1;
92                reverse_edges
93                    .entry(dep.as_str())
94                    .or_default()
95                    .push(db.name.as_str());
96            }
97        }
98
99        let mut queue: VecDeque<&str> = VecDeque::new();
100        for (&name, &deg) in &in_degree {
101            if deg == 0 {
102                queue.push_back(name);
103            }
104        }
105
106        let mut sorted = Vec::new();
107        while let Some(name) = queue.pop_front() {
108            sorted.push(name.to_string());
109            if let Some(dependents) = reverse_edges.get(name) {
110                for &dep in dependents {
111                    let deg = in_degree
112                        .get_mut(dep)
113                        .expect("dependency not found in in_degree map");
114                    *deg -= 1;
115                    if *deg == 0 {
116                        queue.push_back(dep);
117                    }
118                }
119            }
120        }
121
122        if sorted.len() != databases.len() {
123            let in_cycle: Vec<&str> = in_degree
124                .iter()
125                .filter(|(_, deg)| **deg > 0)
126                .map(|(&name, _)| name)
127                .collect();
128            return Err(WaypointError::MultiDbDependencyCycle {
129                path: in_cycle.join(" -> "),
130            });
131        }
132
133        Ok(sorted)
134    }
135
136    /// Connect to all databases (or a filtered subset).
137    pub async fn connect(
138        databases: &[NamedDatabaseConfig],
139        filter: Option<&str>,
140    ) -> Result<HashMap<String, Client>> {
141        let mut clients = HashMap::new();
142
143        for db in databases {
144            if let Some(name_filter) = filter {
145                if db.name != name_filter {
146                    continue;
147                }
148            }
149
150            let config = db.to_waypoint_config();
151            let conn_string = config.connection_string()?;
152            let client = crate::db::connect_with_full_config(
153                &conn_string,
154                &config.database.ssl_mode,
155                config.database.connect_retries,
156                config.database.connect_timeout_secs,
157                config.database.statement_timeout_secs,
158                config.database.keepalive_secs,
159            )
160            .await?;
161            clients.insert(db.name.clone(), client);
162        }
163
164        if let Some(name_filter) = filter {
165            if !clients.contains_key(name_filter) {
166                let available = databases
167                    .iter()
168                    .map(|d| d.name.clone())
169                    .collect::<Vec<_>>()
170                    .join(", ");
171                return Err(WaypointError::DatabaseNotFound {
172                    name: name_filter.to_string(),
173                    available,
174                });
175            }
176        }
177
178        Ok(clients)
179    }
180
181    /// Run migrate on all databases in dependency order.
182    pub async fn migrate(
183        databases: &[NamedDatabaseConfig],
184        clients: &HashMap<String, Client>,
185        order: &[String],
186        target_version: Option<&str>,
187        fail_fast: bool,
188    ) -> Result<MultiResult> {
189        let mut results = Vec::new();
190
191        for name in order {
192            let db = databases.iter().find(|d| &d.name == name);
193            let client = clients.get(name);
194
195            match (db, client) {
196                (Some(db), Some(client)) => {
197                    let config = db.to_waypoint_config();
198                    match crate::commands::migrate::execute(client, &config, target_version).await {
199                        Ok(report) => {
200                            results.push(DatabaseResult {
201                                name: name.clone(),
202                                success: true,
203                                message: format!(
204                                    "Applied {} migration(s) ({}ms)",
205                                    report.migrations_applied, report.total_time_ms
206                                ),
207                            });
208                        }
209                        Err(e) => {
210                            results.push(DatabaseResult {
211                                name: name.clone(),
212                                success: false,
213                                message: format!("{}", e),
214                            });
215                            if fail_fast {
216                                break;
217                            }
218                        }
219                    }
220                }
221                _ => {
222                    results.push(DatabaseResult {
223                        name: name.clone(),
224                        success: false,
225                        message: "Database not connected".to_string(),
226                    });
227                    if fail_fast {
228                        break;
229                    }
230                }
231            }
232        }
233
234        let all_succeeded = results.iter().all(|r| r.success);
235        Ok(MultiResult {
236            results,
237            all_succeeded,
238        })
239    }
240
241    /// Run info on all databases in dependency order.
242    pub async fn info(
243        databases: &[NamedDatabaseConfig],
244        clients: &HashMap<String, Client>,
245        order: &[String],
246    ) -> Result<HashMap<String, Vec<crate::commands::info::MigrationInfo>>> {
247        let mut all_info = HashMap::new();
248
249        for name in order {
250            let db = databases.iter().find(|d| &d.name == name);
251            let client = clients.get(name);
252
253            if let (Some(db), Some(client)) = (db, client) {
254                let config = db.to_waypoint_config();
255                let info = crate::commands::info::execute(client, &config).await?;
256                all_info.insert(name.clone(), info);
257            }
258        }
259
260        Ok(all_info)
261    }
262}