1use serde::Serialize;
7use tokio_postgres::Client;
8
9use crate::error::Result;
10
/// Outcome of a single preflight check.
#[derive(Debug, Clone, Serialize)]
pub struct PreflightCheck {
    /// Human-readable check name (e.g. "Recovery Mode", "Lock Contention").
    pub name: String,
    /// Verdict for this check. Only `CheckStatus::Fail` makes the overall
    /// report fail (see `run_preflight`); warnings are informational.
    pub status: CheckStatus,
    /// Human-readable explanation of the result. For most checks a query
    /// error is reported here as `Could not check: <error>`.
    pub detail: String,
}
21
22#[derive(Debug, Clone, Serialize, PartialEq)]
24pub enum CheckStatus {
25 Pass,
27 Warn,
29 Fail,
31}
32
33impl std::fmt::Display for CheckStatus {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 CheckStatus::Pass => write!(f, "PASS"),
37 CheckStatus::Warn => write!(f, "WARN"),
38 CheckStatus::Fail => write!(f, "FAIL"),
39 }
40 }
41}
42
/// Aggregated result of all preflight checks, as returned by `run_preflight`.
#[derive(Debug, Serialize)]
pub struct PreflightReport {
    /// Individual check results, in the order the checks were run.
    pub checks: Vec<PreflightCheck>,
    /// `true` when no entry in `checks` has status `CheckStatus::Fail`;
    /// `Warn` results do not affect it.
    pub passed: bool,
}
51
/// Tunables for `run_preflight`; default values come from the `Default` impl.
#[derive(Debug, Clone)]
pub struct PreflightConfig {
    /// Presumably toggles whether preflight runs at all.
    /// NOTE(review): `run_preflight` does not read this flag; presumably the
    /// caller checks it before calling — confirm.
    pub enabled: bool,
    /// Replication lag, in whole megabytes, above which the replication-lag
    /// check reports `CheckStatus::Warn`.
    pub max_replication_lag_mb: i64,
    /// Active queries running longer than this many seconds are counted as
    /// long-running and produce a `CheckStatus::Warn`.
    pub long_query_threshold_secs: i64,
}
62
63impl Default for PreflightConfig {
64 fn default() -> Self {
65 Self {
66 enabled: true,
67 max_replication_lag_mb: 100,
68 long_query_threshold_secs: 300,
69 }
70 }
71}
72
73pub async fn run_preflight(client: &Client, config: &PreflightConfig) -> Result<PreflightReport> {
75 let mut checks = Vec::new();
76
77 checks.push(check_recovery_mode(client).await);
78 checks.push(check_active_connections(client).await);
79 checks.push(check_long_running_queries(client, config.long_query_threshold_secs).await);
80 checks.push(check_replication_lag(client, config.max_replication_lag_mb).await);
81 checks.push(check_database_size(client).await);
82 checks.push(check_lock_contention(client).await);
83
84 let passed = !checks.iter().any(|c| c.status == CheckStatus::Fail);
85
86 Ok(PreflightReport { checks, passed })
87}
88
89async fn check_recovery_mode(client: &Client) -> PreflightCheck {
90 match client.query_one("SELECT pg_is_in_recovery()", &[]).await {
91 Ok(row) => {
92 let in_recovery: bool = row.get(0);
93 if in_recovery {
94 PreflightCheck {
95 name: "Recovery Mode".to_string(),
96 status: CheckStatus::Fail,
97 detail: "Database is in recovery mode (read-only replica)".to_string(),
98 }
99 } else {
100 PreflightCheck {
101 name: "Recovery Mode".to_string(),
102 status: CheckStatus::Pass,
103 detail: "Not in recovery mode".to_string(),
104 }
105 }
106 }
107 Err(e) => PreflightCheck {
108 name: "Recovery Mode".to_string(),
109 status: CheckStatus::Warn,
110 detail: format!("Could not check: {}", e),
111 },
112 }
113}
114
115async fn check_active_connections(client: &Client) -> PreflightCheck {
116 let query = "SELECT count(*)::int as active,
117 (SELECT setting::int FROM pg_settings WHERE name = 'max_connections') as max_conn
118 FROM pg_stat_activity";
119 match client.query_one(query, &[]).await {
120 Ok(row) => {
121 let active: i32 = row.get(0);
122 let max_conn: i32 = row.get(1);
123 let pct = (active as f64 / max_conn as f64) * 100.0;
124 let status = if pct >= 80.0 {
125 CheckStatus::Warn
126 } else {
127 CheckStatus::Pass
128 };
129 PreflightCheck {
130 name: "Active Connections".to_string(),
131 status,
132 detail: format!("{}/{} ({:.0}%)", active, max_conn, pct),
133 }
134 }
135 Err(e) => PreflightCheck {
136 name: "Active Connections".to_string(),
137 status: CheckStatus::Warn,
138 detail: format!("Could not check: {}", e),
139 },
140 }
141}
142
143async fn check_long_running_queries(client: &Client, threshold_secs: i64) -> PreflightCheck {
144 let query = format!(
145 "SELECT count(*)::int FROM pg_stat_activity
146 WHERE state = 'active' AND now() - query_start > interval '{} seconds'",
147 threshold_secs
148 );
149 match client.query_one(&query, &[]).await {
150 Ok(row) => {
151 let count: i32 = row.get(0);
152 if count > 0 {
153 PreflightCheck {
154 name: "Long-Running Queries".to_string(),
155 status: CheckStatus::Warn,
156 detail: format!(
157 "{} query(ies) running longer than {}s",
158 count, threshold_secs
159 ),
160 }
161 } else {
162 PreflightCheck {
163 name: "Long-Running Queries".to_string(),
164 status: CheckStatus::Pass,
165 detail: format!("No queries running longer than {}s", threshold_secs),
166 }
167 }
168 }
169 Err(e) => PreflightCheck {
170 name: "Long-Running Queries".to_string(),
171 status: CheckStatus::Warn,
172 detail: format!("Could not check: {}", e),
173 },
174 }
175}
176
177async fn check_replication_lag(client: &Client, max_lag_mb: i64) -> PreflightCheck {
178 let query = "SELECT pg_wal_lsn_diff(pg_current_wal_lsn(), replay_lsn)
179 FROM pg_stat_replication
180 ORDER BY replay_lsn ASC LIMIT 1";
181 match client.query_opt(query, &[]).await {
182 Ok(Some(row)) => {
183 let lag_bytes: Option<i64> = row.get(0);
184 let lag_mb = lag_bytes.unwrap_or(0) / (1024 * 1024);
185 let status = if lag_mb > max_lag_mb {
186 CheckStatus::Warn
187 } else {
188 CheckStatus::Pass
189 };
190 PreflightCheck {
191 name: "Replication Lag".to_string(),
192 status,
193 detail: format!("{}MB (threshold: {}MB)", lag_mb, max_lag_mb),
194 }
195 }
196 Ok(None) => PreflightCheck {
197 name: "Replication Lag".to_string(),
198 status: CheckStatus::Pass,
199 detail: "No replicas connected".to_string(),
200 },
201 Err(_) => PreflightCheck {
202 name: "Replication Lag".to_string(),
203 status: CheckStatus::Pass,
204 detail: "Not a primary or no replication configured".to_string(),
205 },
206 }
207}
208
209async fn check_database_size(client: &Client) -> PreflightCheck {
210 match client
211 .query_one("SELECT pg_database_size(current_database())", &[])
212 .await
213 {
214 Ok(row) => {
215 let size_bytes: i64 = row.get(0);
216 let size_mb = size_bytes / (1024 * 1024);
217 let detail = if size_mb > 1024 {
218 format!("{:.1}GB", size_mb as f64 / 1024.0)
219 } else {
220 format!("{}MB", size_mb)
221 };
222 PreflightCheck {
223 name: "Database Size".to_string(),
224 status: CheckStatus::Pass,
225 detail,
226 }
227 }
228 Err(e) => PreflightCheck {
229 name: "Database Size".to_string(),
230 status: CheckStatus::Warn,
231 detail: format!("Could not check: {}", e),
232 },
233 }
234}
235
236async fn check_lock_contention(client: &Client) -> PreflightCheck {
237 match client
238 .query_one("SELECT count(*)::int FROM pg_locks WHERE NOT granted", &[])
239 .await
240 {
241 Ok(row) => {
242 let blocked: i32 = row.get(0);
243 if blocked > 0 {
244 PreflightCheck {
245 name: "Lock Contention".to_string(),
246 status: CheckStatus::Warn,
247 detail: format!("{} blocked lock request(s)", blocked),
248 }
249 } else {
250 PreflightCheck {
251 name: "Lock Contention".to_string(),
252 status: CheckStatus::Pass,
253 detail: "No blocked locks".to_string(),
254 }
255 }
256 }
257 Err(e) => PreflightCheck {
258 name: "Lock Contention".to_string(),
259 status: CheckStatus::Warn,
260 detail: format!("Could not check: {}", e),
261 },
262 }
263}