1use std::collections::{HashMap, HashSet, VecDeque};
7
8use serde::Serialize;
9use tokio_postgres::Client;
10
11use crate::config::{DatabaseConfig, HooksConfig, MigrationSettings, WaypointConfig};
12use crate::error::{Result, WaypointError};
13
14#[derive(Debug, Clone)]
16pub struct NamedDatabaseConfig {
17 pub name: String,
19 pub database: DatabaseConfig,
21 pub migrations: MigrationSettings,
23 pub hooks: HooksConfig,
25 pub placeholders: HashMap<String, String>,
27 pub depends_on: Vec<String>,
29}
30
31impl NamedDatabaseConfig {
32 pub fn to_waypoint_config(&self) -> WaypointConfig {
34 WaypointConfig {
35 database: self.database.clone(),
36 migrations: self.migrations.clone(),
37 hooks: self.hooks.clone(),
38 placeholders: self.placeholders.clone(),
39 ..WaypointConfig::default()
40 }
41 }
42}
43
44pub struct MultiWaypoint {
46 pub databases: Vec<NamedDatabaseConfig>,
48}
49
50#[derive(Debug, Serialize)]
52pub struct DatabaseResult {
53 pub name: String,
55 pub success: bool,
57 pub message: String,
59}
60
61#[derive(Debug, Serialize)]
63pub struct MultiResult {
64 pub results: Vec<DatabaseResult>,
66 pub all_succeeded: bool,
68}
69
70impl MultiWaypoint {
71 pub fn execution_order(databases: &[NamedDatabaseConfig]) -> Result<Vec<String>> {
76 let all_names: HashSet<&str> = databases.iter().map(|d| d.name.as_str()).collect();
77
78 let mut in_degree: HashMap<&str, usize> = HashMap::new();
80 let mut reverse_edges: HashMap<&str, Vec<&str>> = HashMap::new();
81
82 for db in databases {
83 in_degree.entry(db.name.as_str()).or_insert(0);
84 for dep in &db.depends_on {
85 if !all_names.contains(dep.as_str()) {
86 return Err(WaypointError::DatabaseNotFound {
87 name: dep.clone(),
88 available: all_names.iter().copied().collect::<Vec<_>>().join(", "),
89 });
90 }
91 *in_degree.entry(db.name.as_str()).or_insert(0) += 1;
92 reverse_edges
93 .entry(dep.as_str())
94 .or_default()
95 .push(db.name.as_str());
96 }
97 }
98
99 let mut queue: VecDeque<&str> = VecDeque::new();
100 for (&name, °) in &in_degree {
101 if deg == 0 {
102 queue.push_back(name);
103 }
104 }
105
106 let mut sorted = Vec::new();
107 while let Some(name) = queue.pop_front() {
108 sorted.push(name.to_string());
109 if let Some(dependents) = reverse_edges.get(name) {
110 for &dep in dependents {
111 let deg = in_degree
112 .get_mut(dep)
113 .expect("dependency not found in in_degree map");
114 *deg -= 1;
115 if *deg == 0 {
116 queue.push_back(dep);
117 }
118 }
119 }
120 }
121
122 if sorted.len() != databases.len() {
123 let in_cycle: Vec<&str> = in_degree
124 .iter()
125 .filter(|(_, deg)| **deg > 0)
126 .map(|(&name, _)| name)
127 .collect();
128 return Err(WaypointError::MultiDbDependencyCycle {
129 path: in_cycle.join(" -> "),
130 });
131 }
132
133 Ok(sorted)
134 }
135
136 pub async fn connect(
138 databases: &[NamedDatabaseConfig],
139 filter: Option<&str>,
140 ) -> Result<HashMap<String, Client>> {
141 let mut clients = HashMap::new();
142
143 for db in databases {
144 if let Some(name_filter) = filter {
145 if db.name != name_filter {
146 continue;
147 }
148 }
149
150 let config = db.to_waypoint_config();
151 let conn_string = config.connection_string()?;
152 let client = crate::db::connect_with_full_config(
153 &conn_string,
154 &config.database.ssl_mode,
155 config.database.connect_retries,
156 config.database.connect_timeout_secs,
157 config.database.statement_timeout_secs,
158 config.database.keepalive_secs,
159 )
160 .await?;
161 clients.insert(db.name.clone(), client);
162 }
163
164 if let Some(name_filter) = filter {
165 if !clients.contains_key(name_filter) {
166 let available = databases
167 .iter()
168 .map(|d| d.name.clone())
169 .collect::<Vec<_>>()
170 .join(", ");
171 return Err(WaypointError::DatabaseNotFound {
172 name: name_filter.to_string(),
173 available,
174 });
175 }
176 }
177
178 Ok(clients)
179 }
180
181 pub async fn migrate(
183 databases: &[NamedDatabaseConfig],
184 clients: &HashMap<String, Client>,
185 order: &[String],
186 target_version: Option<&str>,
187 fail_fast: bool,
188 ) -> Result<MultiResult> {
189 let mut results = Vec::new();
190
191 for name in order {
192 let db = databases.iter().find(|d| &d.name == name);
193 let client = clients.get(name);
194
195 match (db, client) {
196 (Some(db), Some(client)) => {
197 let config = db.to_waypoint_config();
198 match crate::commands::migrate::execute(client, &config, target_version).await {
199 Ok(report) => {
200 results.push(DatabaseResult {
201 name: name.clone(),
202 success: true,
203 message: format!(
204 "Applied {} migration(s) ({}ms)",
205 report.migrations_applied, report.total_time_ms
206 ),
207 });
208 }
209 Err(e) => {
210 results.push(DatabaseResult {
211 name: name.clone(),
212 success: false,
213 message: format!("{}", e),
214 });
215 if fail_fast {
216 break;
217 }
218 }
219 }
220 }
221 _ => {
222 results.push(DatabaseResult {
223 name: name.clone(),
224 success: false,
225 message: "Database not connected".to_string(),
226 });
227 if fail_fast {
228 break;
229 }
230 }
231 }
232 }
233
234 let all_succeeded = results.iter().all(|r| r.success);
235 Ok(MultiResult {
236 results,
237 all_succeeded,
238 })
239 }
240
241 pub async fn info(
243 databases: &[NamedDatabaseConfig],
244 clients: &HashMap<String, Client>,
245 order: &[String],
246 ) -> Result<HashMap<String, Vec<crate::commands::info::MigrationInfo>>> {
247 let mut all_info = HashMap::new();
248
249 for name in order {
250 let db = databases.iter().find(|d| &d.name == name);
251 let client = clients.get(name);
252
253 if let (Some(db), Some(client)) = (db, client) {
254 let config = db.to_waypoint_config();
255 let info = crate::commands::info::execute(client, &config).await?;
256 all_info.insert(name.clone(), info);
257 }
258 }
259
260 Ok(all_info)
261 }
262}