Skip to main content

waypoint_core/
schema.rs

1//! PostgreSQL schema introspection, diff, and DDL generation.
2//!
3//! Used by diff, drift, and snapshot commands.
4
5use std::collections::{HashMap, HashSet};
6
7use serde::Serialize;
8use tokio_postgres::Client;
9
10use crate::db::quote_ident;
11use crate::error::Result;
12
/// Complete snapshot of a PostgreSQL schema.
///
/// Built by [`introspect`] for a single schema and compared pairwise by
/// [`diff`]. Every list holds objects of the introspected schema, except
/// `extensions`, which is queried database-wide (no schema filter).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct SchemaSnapshot {
    /// All base tables in the schema, each with its columns.
    pub tables: Vec<TableDef>,
    /// All views (regular and materialized) in the schema.
    pub views: Vec<ViewDef>,
    /// All indexes in the schema.
    pub indexes: Vec<IndexDef>,
    /// All sequences in the schema.
    pub sequences: Vec<SequenceDef>,
    /// All functions and procedures in the schema.
    pub functions: Vec<FunctionDef>,
    /// All enum types in the schema.
    pub enums: Vec<EnumDef>,
    /// All table constraints in the schema.
    pub constraints: Vec<ConstraintDef>,
    /// All triggers in the schema.
    pub triggers: Vec<TriggerDef>,
    /// Names of installed extensions (excluding plpgsql). Not limited to the
    /// introspected schema.
    pub extensions: Vec<String>,
}
35
/// Definition of a database table.
///
/// Tables are matched across snapshots by `name` alone; `schema` is
/// informational.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct TableDef {
    /// Schema the table belongs to.
    pub schema: String,
    /// Unqualified name of the table.
    pub name: String,
    /// Columns belonging to this table, ordered by ordinal position.
    pub columns: Vec<ColumnDef>,
}
46
/// Definition of a table column.
///
/// Columns are matched across snapshots by `name` within their table.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ColumnDef {
    /// Name of the column.
    pub name: String,
    /// SQL data type of the column, emitted verbatim into generated DDL.
    pub data_type: String,
    /// Whether the column allows NULL values.
    pub is_nullable: bool,
    /// Default value expression, if any, emitted verbatim into generated DDL.
    pub default: Option<String>,
    /// Position of the column within its table (1-based).
    pub ordinal_position: i32,
}
61
/// Definition of a database view.
///
/// Views are matched across snapshots by `name`; a changed `definition`
/// is reported as [`SchemaDiff::ViewAltered`].
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ViewDef {
    /// Schema the view belongs to.
    pub schema: String,
    /// Unqualified name of the view.
    pub name: String,
    /// SQL definition body of the view (the query after `AS`); empty when the
    /// catalog returned no definition.
    pub definition: String,
    /// Whether this is a materialized view.
    pub is_materialized: bool,
}
74
/// Definition of a database index.
///
/// Indexes are compared across snapshots by `name` only (existence check).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct IndexDef {
    /// Schema the index belongs to.
    pub schema: String,
    /// Unqualified name of the index.
    pub name: String,
    /// Name of the table the index is built on.
    pub table_name: String,
    /// Full CREATE INDEX DDL statement, as read from `pg_indexes.indexdef`.
    pub definition: String,
    /// Whether this is a unique index (derived from `definition`).
    pub is_unique: bool,
}
89
/// Definition of a database sequence.
///
/// Sequences are compared across snapshots by `name` only (existence check).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct SequenceDef {
    /// Schema the sequence belongs to.
    pub schema: String,
    /// Unqualified name of the sequence.
    pub name: String,
    /// Data type of the sequence (e.g. bigint).
    pub data_type: String,
}
100
/// Definition of a database function or procedure.
///
/// A changed `definition` is reported as [`SchemaDiff::FunctionAltered`].
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct FunctionDef {
    /// Schema the function belongs to.
    pub schema: String,
    /// Unqualified name of the function (shared by all of its overloads).
    pub name: String,
    /// Function argument signature, as rendered by `pg_get_function_arguments`.
    pub arguments: String,
    /// Return type of the function; empty when the catalog reports none.
    pub return_type: String,
    /// Implementation language (e.g. plpgsql, sql).
    pub language: String,
    /// Full function definition, as rendered by `pg_get_functiondef`; empty
    /// when the catalog returned NULL.
    pub definition: String,
}
117
/// Definition of a PostgreSQL enum type.
///
/// Enums are compared across snapshots by `name` only; changes to `values`
/// are not reported as a diff.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct EnumDef {
    /// Schema the enum belongs to.
    pub schema: String,
    /// Unqualified name of the enum type.
    pub name: String,
    /// Enum labels in declaration order (`enumsortorder`).
    pub values: Vec<String>,
}
128
/// Definition of a table constraint.
///
/// Constraints are compared across snapshots by `(table_name, name)` only
/// (existence check).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ConstraintDef {
    /// Schema the constraint belongs to.
    pub schema: String,
    /// Name of the table the constraint is on.
    pub table_name: String,
    /// Name of the constraint.
    pub name: String,
    /// Type of constraint (e.g. PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK).
    pub constraint_type: String,
    /// Constraint definition as rendered by `pg_get_constraintdef`; empty
    /// when the catalog returned NULL.
    pub definition: String,
}
143
/// Definition of a database trigger.
///
/// Triggers are compared across snapshots by `(table_name, name)` only
/// (existence check).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct TriggerDef {
    /// Schema the trigger belongs to.
    pub schema: String,
    /// Name of the table the trigger is attached to.
    pub table_name: String,
    /// Name of the trigger.
    pub name: String,
    /// Action statement executed by the trigger (not a full CREATE TRIGGER
    /// statement).
    pub definition: String,
}
156
/// Differences between two schema snapshots.
///
/// Produced by [`diff`]. Variants carrying a bare `String` hold the
/// unqualified name of the affected object, and `table` fields hold the
/// name of the owning table. The `Display` impl renders each variant as a
/// one-line summary prefixed with `+` (added), `-` (dropped) or `~` (altered).
#[derive(Debug, Clone, Serialize)]
pub enum SchemaDiff {
    /// A table was added in the target schema.
    TableAdded(TableDef),
    /// A table was dropped from the target schema; holds the table name.
    TableDropped(String),
    /// A column was added to an existing table.
    ColumnAdded { table: String, column: ColumnDef },
    /// A column was dropped from an existing table; `column` is its name.
    ColumnDropped { table: String, column: String },
    /// A column definition was altered in an existing table.
    ///
    /// `from` and `to` are the complete column definitions taken from the
    /// `before` and `after` snapshots respectively.
    ColumnAltered {
        table: String,
        column: String,
        from: ColumnDef,
        to: ColumnDef,
    },
    /// An index was added in the target schema.
    IndexAdded(IndexDef),
    /// An index was dropped from the target schema; holds the index name.
    IndexDropped(String),
    /// A view was added in the target schema.
    ViewAdded(ViewDef),
    /// A view was dropped from the target schema; holds the view name.
    ViewDropped(String),
    /// A view definition was altered.
    ///
    /// `from` and `to` are the old and new view definition bodies.
    ViewAltered {
        name: String,
        from: String,
        to: String,
    },
    /// A sequence was added in the target schema.
    SequenceAdded(SequenceDef),
    /// A sequence was dropped from the target schema; holds the sequence name.
    SequenceDropped(String),
    /// A function was added in the target schema.
    FunctionAdded(FunctionDef),
    /// A function was dropped from the target schema; holds the function name.
    FunctionDropped(String),
    /// A function definition was altered; only the name is carried.
    FunctionAltered { name: String },
    /// An enum type was added in the target schema.
    EnumAdded(EnumDef),
    /// An enum type was dropped from the target schema; holds the type name.
    EnumDropped(String),
    /// A constraint was added in the target schema.
    ConstraintAdded(ConstraintDef),
    /// A constraint was dropped from the target schema.
    ConstraintDropped { table: String, name: String },
    /// A trigger was added in the target schema.
    TriggerAdded(TriggerDef),
    /// A trigger was dropped from the target schema.
    TriggerDropped { table: String, name: String },
    /// A PostgreSQL extension was added; holds the extension name.
    ExtensionAdded(String),
    /// A PostgreSQL extension was dropped; holds the extension name.
    ExtensionDropped(String),
}
216
217impl std::fmt::Display for SchemaDiff {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        match self {
220            SchemaDiff::TableAdded(t) => write!(f, "+ TABLE {}", t.name),
221            SchemaDiff::TableDropped(n) => write!(f, "- TABLE {}", n),
222            SchemaDiff::ColumnAdded { table, column } => {
223                write!(
224                    f,
225                    "+ COLUMN {}.{} ({})",
226                    table, column.name, column.data_type
227                )
228            }
229            SchemaDiff::ColumnDropped { table, column } => {
230                write!(f, "- COLUMN {}.{}", table, column)
231            }
232            SchemaDiff::ColumnAltered { table, column, .. } => {
233                write!(f, "~ COLUMN {}.{}", table, column)
234            }
235            SchemaDiff::IndexAdded(idx) => write!(f, "+ INDEX {}", idx.name),
236            SchemaDiff::IndexDropped(n) => write!(f, "- INDEX {}", n),
237            SchemaDiff::ViewAdded(v) => write!(f, "+ VIEW {}", v.name),
238            SchemaDiff::ViewDropped(n) => write!(f, "- VIEW {}", n),
239            SchemaDiff::ViewAltered { name, .. } => write!(f, "~ VIEW {}", name),
240            SchemaDiff::SequenceAdded(s) => write!(f, "+ SEQUENCE {}", s.name),
241            SchemaDiff::SequenceDropped(n) => write!(f, "- SEQUENCE {}", n),
242            SchemaDiff::FunctionAdded(func) => write!(f, "+ FUNCTION {}", func.name),
243            SchemaDiff::FunctionDropped(n) => write!(f, "- FUNCTION {}", n),
244            SchemaDiff::FunctionAltered { name } => write!(f, "~ FUNCTION {}", name),
245            SchemaDiff::EnumAdded(e) => write!(f, "+ TYPE {} (enum)", e.name),
246            SchemaDiff::EnumDropped(n) => write!(f, "- TYPE {} (enum)", n),
247            SchemaDiff::ConstraintAdded(c) => {
248                write!(f, "+ CONSTRAINT {} ON {}", c.name, c.table_name)
249            }
250            SchemaDiff::ConstraintDropped { table, name } => {
251                write!(f, "- CONSTRAINT {} ON {}", name, table)
252            }
253            SchemaDiff::TriggerAdded(t) => write!(f, "+ TRIGGER {} ON {}", t.name, t.table_name),
254            SchemaDiff::TriggerDropped { table, name } => {
255                write!(f, "- TRIGGER {} ON {}", name, table)
256            }
257            SchemaDiff::ExtensionAdded(n) => write!(f, "+ EXTENSION {}", n),
258            SchemaDiff::ExtensionDropped(n) => write!(f, "- EXTENSION {}", n),
259        }
260    }
261}
262
263/// Introspect the current state of a PostgreSQL schema.
264pub async fn introspect(client: &Client, schema: &str) -> Result<SchemaSnapshot> {
265    let (tables, views, indexes, sequences, functions, enums, constraints, triggers, extensions) =
266        tokio::try_join!(
267            introspect_tables(client, schema),
268            introspect_views(client, schema),
269            introspect_indexes(client, schema),
270            introspect_sequences(client, schema),
271            introspect_functions(client, schema),
272            introspect_enums(client, schema),
273            introspect_constraints(client, schema),
274            introspect_triggers(client, schema),
275            introspect_extensions(client),
276        )?;
277
278    Ok(SchemaSnapshot {
279        tables,
280        views,
281        indexes,
282        sequences,
283        functions,
284        enums,
285        constraints,
286        triggers,
287        extensions,
288    })
289}
290
291async fn introspect_tables(client: &Client, schema: &str) -> Result<Vec<TableDef>> {
292    let rows = client
293        .query(
294            "SELECT t.table_name, c.column_name, c.data_type, c.is_nullable, c.column_default, c.ordinal_position
295             FROM information_schema.tables t
296             LEFT JOIN information_schema.columns c
297               ON t.table_schema = c.table_schema AND t.table_name = c.table_name
298             WHERE t.table_schema = $1 AND t.table_type = 'BASE TABLE'
299             ORDER BY t.table_name, c.ordinal_position",
300            &[&schema],
301        )
302        .await?;
303
304    let mut tables: Vec<TableDef> = Vec::new();
305    let mut current_table: Option<String> = None;
306    let mut columns: Vec<ColumnDef> = Vec::new();
307
308    for row in &rows {
309        let table_name: String = row.get(0);
310        let col_name: Option<String> = row.get(1);
311
312        if current_table.as_ref() != Some(&table_name) {
313            if let Some(prev_name) = current_table.take() {
314                tables.push(TableDef {
315                    schema: schema.to_string(),
316                    name: prev_name,
317                    columns: std::mem::take(&mut columns),
318                });
319            }
320            current_table = Some(table_name.clone());
321        }
322
323        if let Some(name) = col_name {
324            columns.push(ColumnDef {
325                name,
326                data_type: row.get(2),
327                is_nullable: row.get::<_, String>(3) == "YES",
328                default: row.get(4),
329                ordinal_position: row.get(5),
330            });
331        }
332    }
333
334    // Don't forget the last table
335    if let Some(name) = current_table {
336        tables.push(TableDef {
337            schema: schema.to_string(),
338            name,
339            columns,
340        });
341    }
342
343    Ok(tables)
344}
345
346async fn introspect_views(client: &Client, schema: &str) -> Result<Vec<ViewDef>> {
347    // Regular views
348    let rows = client
349        .query(
350            "SELECT table_name, view_definition
351             FROM information_schema.views
352             WHERE table_schema = $1
353             ORDER BY table_name",
354            &[&schema],
355        )
356        .await?;
357
358    let mut views: Vec<ViewDef> = rows
359        .iter()
360        .map(|r| ViewDef {
361            schema: schema.to_string(),
362            name: r.get(0),
363            definition: r.get::<_, Option<String>>(1).unwrap_or_default(),
364            is_materialized: false,
365        })
366        .collect();
367
368    // Materialized views
369    let mat_rows = client
370        .query(
371            "SELECT c.relname, pg_get_viewdef(c.oid)
372             FROM pg_class c
373             JOIN pg_namespace n ON n.oid = c.relnamespace
374             WHERE n.nspname = $1 AND c.relkind = 'm'
375             ORDER BY c.relname",
376            &[&schema],
377        )
378        .await?;
379
380    for r in &mat_rows {
381        views.push(ViewDef {
382            schema: schema.to_string(),
383            name: r.get(0),
384            definition: r.get::<_, Option<String>>(1).unwrap_or_default(),
385            is_materialized: true,
386        });
387    }
388
389    Ok(views)
390}
391
392async fn introspect_indexes(client: &Client, schema: &str) -> Result<Vec<IndexDef>> {
393    let rows = client
394        .query(
395            "SELECT indexname, tablename, indexdef
396             FROM pg_indexes
397             WHERE schemaname = $1
398             ORDER BY indexname",
399            &[&schema],
400        )
401        .await?;
402
403    Ok(rows
404        .iter()
405        .map(|r| {
406            let definition: String = r.get(2);
407            IndexDef {
408                schema: schema.to_string(),
409                name: r.get(0),
410                table_name: r.get(1),
411                is_unique: definition.to_uppercase().contains("UNIQUE"),
412                definition,
413            }
414        })
415        .collect())
416}
417
418async fn introspect_sequences(client: &Client, schema: &str) -> Result<Vec<SequenceDef>> {
419    let rows = client
420        .query(
421            "SELECT sequence_name, data_type
422             FROM information_schema.sequences
423             WHERE sequence_schema = $1
424             ORDER BY sequence_name",
425            &[&schema],
426        )
427        .await?;
428
429    Ok(rows
430        .iter()
431        .map(|r| SequenceDef {
432            schema: schema.to_string(),
433            name: r.get(0),
434            data_type: r.get(1),
435        })
436        .collect())
437}
438
439async fn introspect_functions(client: &Client, schema: &str) -> Result<Vec<FunctionDef>> {
440    let rows = client
441        .query(
442            "SELECT p.proname,
443                    pg_get_function_arguments(p.oid),
444                    pg_get_function_result(p.oid),
445                    l.lanname,
446                    pg_get_functiondef(p.oid)
447             FROM pg_proc p
448             JOIN pg_namespace n ON n.oid = p.pronamespace
449             JOIN pg_language l ON l.oid = p.prolang
450             WHERE n.nspname = $1
451               AND p.prokind IN ('f', 'p')
452             ORDER BY p.proname",
453            &[&schema],
454        )
455        .await?;
456
457    Ok(rows
458        .iter()
459        .map(|r| FunctionDef {
460            schema: schema.to_string(),
461            name: r.get(0),
462            arguments: r.get(1),
463            return_type: r.get::<_, Option<String>>(2).unwrap_or_default(),
464            language: r.get(3),
465            definition: r.get::<_, Option<String>>(4).unwrap_or_default(),
466        })
467        .collect())
468}
469
470async fn introspect_enums(client: &Client, schema: &str) -> Result<Vec<EnumDef>> {
471    let rows = client
472        .query(
473            "SELECT t.typname, array_agg(e.enumlabel ORDER BY e.enumsortorder)::text[]
474             FROM pg_type t
475             JOIN pg_enum e ON e.enumtypid = t.oid
476             JOIN pg_namespace n ON n.oid = t.typnamespace
477             WHERE n.nspname = $1
478             GROUP BY t.typname
479             ORDER BY t.typname",
480            &[&schema],
481        )
482        .await?;
483
484    Ok(rows
485        .iter()
486        .map(|r| EnumDef {
487            schema: schema.to_string(),
488            name: r.get(0),
489            values: r.get(1),
490        })
491        .collect())
492}
493
494async fn introspect_constraints(client: &Client, schema: &str) -> Result<Vec<ConstraintDef>> {
495    let rows = client
496        .query(
497            "SELECT tc.table_name, tc.constraint_name, tc.constraint_type,
498                    pg_get_constraintdef(c.oid)
499             FROM information_schema.table_constraints tc
500             JOIN pg_constraint c ON c.conname = tc.constraint_name
501             JOIN pg_namespace n ON n.oid = c.connamespace
502             WHERE tc.constraint_schema = $1 AND n.nspname = $1
503             ORDER BY tc.table_name, tc.constraint_name",
504            &[&schema],
505        )
506        .await?;
507
508    Ok(rows
509        .iter()
510        .map(|r| ConstraintDef {
511            schema: schema.to_string(),
512            table_name: r.get(0),
513            name: r.get(1),
514            constraint_type: r.get(2),
515            definition: r.get::<_, Option<String>>(3).unwrap_or_default(),
516        })
517        .collect())
518}
519
520async fn introspect_triggers(client: &Client, schema: &str) -> Result<Vec<TriggerDef>> {
521    let rows = client
522        .query(
523            "SELECT event_object_table, trigger_name, action_statement
524             FROM information_schema.triggers
525             WHERE trigger_schema = $1
526             ORDER BY event_object_table, trigger_name",
527            &[&schema],
528        )
529        .await?;
530
531    Ok(rows
532        .iter()
533        .map(|r| TriggerDef {
534            schema: schema.to_string(),
535            table_name: r.get(0),
536            name: r.get(1),
537            definition: r.get(2),
538        })
539        .collect())
540}
541
542async fn introspect_extensions(client: &Client) -> Result<Vec<String>> {
543    let rows = client
544        .query(
545            "SELECT extname FROM pg_extension WHERE extname != 'plpgsql' ORDER BY extname",
546            &[],
547        )
548        .await?;
549
550    Ok(rows.iter().map(|r| r.get(0)).collect())
551}
552
553/// Compare two schema snapshots and return the differences.
554pub fn diff(before: &SchemaSnapshot, after: &SchemaSnapshot) -> Vec<SchemaDiff> {
555    let mut diffs = Vec::new();
556
557    // Build lookup maps for O(1) access
558
559    // Tables - keyed by name, value is reference to TableDef
560    let before_tables: HashMap<&str, &TableDef> =
561        before.tables.iter().map(|t| (t.name.as_str(), t)).collect();
562    let after_tables: HashMap<&str, &TableDef> =
563        after.tables.iter().map(|t| (t.name.as_str(), t)).collect();
564
565    // Views - keyed by name, value is reference to ViewDef
566    let before_views: HashMap<&str, &ViewDef> =
567        before.views.iter().map(|v| (v.name.as_str(), v)).collect();
568    let after_views: HashMap<&str, &ViewDef> =
569        after.views.iter().map(|v| (v.name.as_str(), v)).collect();
570
571    // Indexes - existence check only, keyed by name
572    let before_indexes: HashSet<&str> = before.indexes.iter().map(|i| i.name.as_str()).collect();
573    let after_indexes: HashSet<&str> = after.indexes.iter().map(|i| i.name.as_str()).collect();
574
575    // Sequences - existence check only, keyed by name
576    let before_sequences: HashSet<&str> =
577        before.sequences.iter().map(|s| s.name.as_str()).collect();
578    let after_sequences: HashSet<&str> = after.sequences.iter().map(|s| s.name.as_str()).collect();
579
580    // Functions - keyed by name, value is reference to FunctionDef
581    let before_functions: HashMap<&str, &FunctionDef> = before
582        .functions
583        .iter()
584        .map(|f| (f.name.as_str(), f))
585        .collect();
586    let after_functions: HashMap<&str, &FunctionDef> = after
587        .functions
588        .iter()
589        .map(|f| (f.name.as_str(), f))
590        .collect();
591
592    // Enums - existence check only, keyed by name
593    let before_enums: HashSet<&str> = before.enums.iter().map(|e| e.name.as_str()).collect();
594    let after_enums: HashSet<&str> = after.enums.iter().map(|e| e.name.as_str()).collect();
595
596    // Constraints - compound key (table_name, name)
597    let before_constraints: HashSet<(&str, &str)> = before
598        .constraints
599        .iter()
600        .map(|c| (c.table_name.as_str(), c.name.as_str()))
601        .collect();
602    let after_constraints: HashSet<(&str, &str)> = after
603        .constraints
604        .iter()
605        .map(|c| (c.table_name.as_str(), c.name.as_str()))
606        .collect();
607
608    // Triggers - compound key (table_name, name)
609    let before_triggers: HashSet<(&str, &str)> = before
610        .triggers
611        .iter()
612        .map(|t| (t.table_name.as_str(), t.name.as_str()))
613        .collect();
614    let after_triggers: HashSet<(&str, &str)> = after
615        .triggers
616        .iter()
617        .map(|t| (t.table_name.as_str(), t.name.as_str()))
618        .collect();
619
620    // Extensions - existence check only
621    let before_extensions: HashSet<&str> = before.extensions.iter().map(|e| e.as_str()).collect();
622    let after_extensions: HashSet<&str> = after.extensions.iter().map(|e| e.as_str()).collect();
623
624    // Tables: check dropped/altered then added
625    for bt in &before.tables {
626        if let Some(at) = after_tables.get(bt.name.as_str()) {
627            diff_columns(&mut diffs, &bt.name, &bt.columns, &at.columns);
628        } else {
629            diffs.push(SchemaDiff::TableDropped(bt.name.clone()));
630        }
631    }
632    for at in &after.tables {
633        if !before_tables.contains_key(at.name.as_str()) {
634            diffs.push(SchemaDiff::TableAdded(at.clone()));
635        }
636    }
637
638    // Views: check dropped/altered then added
639    for bv in &before.views {
640        if let Some(av) = after_views.get(bv.name.as_str()) {
641            if bv.definition != av.definition {
642                diffs.push(SchemaDiff::ViewAltered {
643                    name: bv.name.clone(),
644                    from: bv.definition.clone(),
645                    to: av.definition.clone(),
646                });
647            }
648        } else {
649            diffs.push(SchemaDiff::ViewDropped(bv.name.clone()));
650        }
651    }
652    for av in &after.views {
653        if !before_views.contains_key(av.name.as_str()) {
654            diffs.push(SchemaDiff::ViewAdded(av.clone()));
655        }
656    }
657
658    // Indexes: check dropped then added
659    for bi in &before.indexes {
660        if !after_indexes.contains(bi.name.as_str()) {
661            diffs.push(SchemaDiff::IndexDropped(bi.name.clone()));
662        }
663    }
664    for ai in &after.indexes {
665        if !before_indexes.contains(ai.name.as_str()) {
666            diffs.push(SchemaDiff::IndexAdded(ai.clone()));
667        }
668    }
669
670    // Sequences: check dropped then added
671    for bs in &before.sequences {
672        if !after_sequences.contains(bs.name.as_str()) {
673            diffs.push(SchemaDiff::SequenceDropped(bs.name.clone()));
674        }
675    }
676    for a_s in &after.sequences {
677        if !before_sequences.contains(a_s.name.as_str()) {
678            diffs.push(SchemaDiff::SequenceAdded(a_s.clone()));
679        }
680    }
681
682    // Functions: check dropped/altered then added
683    for bf in &before.functions {
684        if let Some(af) = after_functions.get(bf.name.as_str()) {
685            if bf.definition != af.definition {
686                diffs.push(SchemaDiff::FunctionAltered {
687                    name: bf.name.clone(),
688                });
689            }
690        } else {
691            diffs.push(SchemaDiff::FunctionDropped(bf.name.clone()));
692        }
693    }
694    for af in &after.functions {
695        if !before_functions.contains_key(af.name.as_str()) {
696            diffs.push(SchemaDiff::FunctionAdded(af.clone()));
697        }
698    }
699
700    // Enums: check dropped then added
701    for be in &before.enums {
702        if !after_enums.contains(be.name.as_str()) {
703            diffs.push(SchemaDiff::EnumDropped(be.name.clone()));
704        }
705    }
706    for ae in &after.enums {
707        if !before_enums.contains(ae.name.as_str()) {
708            diffs.push(SchemaDiff::EnumAdded(ae.clone()));
709        }
710    }
711
712    // Constraints: check dropped then added
713    for bc in &before.constraints {
714        if !after_constraints.contains(&(bc.table_name.as_str(), bc.name.as_str())) {
715            diffs.push(SchemaDiff::ConstraintDropped {
716                table: bc.table_name.clone(),
717                name: bc.name.clone(),
718            });
719        }
720    }
721    for ac in &after.constraints {
722        if !before_constraints.contains(&(ac.table_name.as_str(), ac.name.as_str())) {
723            diffs.push(SchemaDiff::ConstraintAdded(ac.clone()));
724        }
725    }
726
727    // Triggers: check dropped then added
728    for bt in &before.triggers {
729        if !after_triggers.contains(&(bt.table_name.as_str(), bt.name.as_str())) {
730            diffs.push(SchemaDiff::TriggerDropped {
731                table: bt.table_name.clone(),
732                name: bt.name.clone(),
733            });
734        }
735    }
736    for at in &after.triggers {
737        if !before_triggers.contains(&(at.table_name.as_str(), at.name.as_str())) {
738            diffs.push(SchemaDiff::TriggerAdded(at.clone()));
739        }
740    }
741
742    // Extensions: check dropped then added
743    for ext in &before.extensions {
744        if !after_extensions.contains(ext.as_str()) {
745            diffs.push(SchemaDiff::ExtensionDropped(ext.clone()));
746        }
747    }
748    for ext in &after.extensions {
749        if !before_extensions.contains(ext.as_str()) {
750            diffs.push(SchemaDiff::ExtensionAdded(ext.clone()));
751        }
752    }
753
754    diffs
755}
756
757fn diff_columns(
758    diffs: &mut Vec<SchemaDiff>,
759    table: &str,
760    before: &[ColumnDef],
761    after: &[ColumnDef],
762) {
763    let before_cols: HashMap<&str, &ColumnDef> =
764        before.iter().map(|c| (c.name.as_str(), c)).collect();
765    let after_cols: HashMap<&str, &ColumnDef> =
766        after.iter().map(|c| (c.name.as_str(), c)).collect();
767
768    for bc in before {
769        if let Some(ac) = after_cols.get(bc.name.as_str()) {
770            if bc != *ac {
771                diffs.push(SchemaDiff::ColumnAltered {
772                    table: table.to_string(),
773                    column: bc.name.clone(),
774                    from: bc.clone(),
775                    to: (*ac).clone(),
776                });
777            }
778        } else {
779            diffs.push(SchemaDiff::ColumnDropped {
780                table: table.to_string(),
781                column: bc.name.clone(),
782            });
783        }
784    }
785    for ac in after {
786        if !before_cols.contains_key(ac.name.as_str()) {
787            diffs.push(SchemaDiff::ColumnAdded {
788                table: table.to_string(),
789                column: ac.clone(),
790            });
791        }
792    }
793}
794
795/// Generate DDL statements from schema diffs.
796pub fn generate_ddl(diffs: &[SchemaDiff]) -> String {
797    let mut statements = Vec::new();
798
799    for d in diffs {
800        match d {
801            SchemaDiff::TableAdded(t) => {
802                let cols: Vec<String> = t
803                    .columns
804                    .iter()
805                    .map(|c| {
806                        let mut col = format!("    {} {}", quote_ident(&c.name), c.data_type);
807                        if !c.is_nullable {
808                            col.push_str(" NOT NULL");
809                        }
810                        if let Some(ref default) = c.default {
811                            col.push_str(&format!(" DEFAULT {}", default));
812                        }
813                        col
814                    })
815                    .collect();
816                statements.push(format!(
817                    "CREATE TABLE {} (\n{}\n);",
818                    quote_ident(&t.name),
819                    cols.join(",\n")
820                ));
821            }
822            SchemaDiff::TableDropped(name) => {
823                statements.push(format!(
824                    "DROP TABLE IF EXISTS {} CASCADE;",
825                    quote_ident(name)
826                ));
827            }
828            SchemaDiff::ColumnAdded { table, column } => {
829                let mut stmt = format!(
830                    "ALTER TABLE {} ADD COLUMN {} {}",
831                    quote_ident(table),
832                    quote_ident(&column.name),
833                    column.data_type
834                );
835                if !column.is_nullable {
836                    stmt.push_str(" NOT NULL");
837                }
838                if let Some(ref default) = column.default {
839                    stmt.push_str(&format!(" DEFAULT {}", default));
840                }
841                stmt.push(';');
842                statements.push(stmt);
843            }
844            SchemaDiff::ColumnDropped { table, column } => {
845                statements.push(format!(
846                    "ALTER TABLE {} DROP COLUMN {};",
847                    quote_ident(table),
848                    quote_ident(column)
849                ));
850            }
851            SchemaDiff::ColumnAltered {
852                table, column, to, ..
853            } => {
854                statements.push(format!(
855                    "ALTER TABLE {} ALTER COLUMN {} TYPE {};",
856                    quote_ident(table),
857                    quote_ident(column),
858                    to.data_type
859                ));
860                if to.is_nullable {
861                    statements.push(format!(
862                        "ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL;",
863                        quote_ident(table),
864                        quote_ident(column)
865                    ));
866                } else {
867                    statements.push(format!(
868                        "ALTER TABLE {} ALTER COLUMN {} SET NOT NULL;",
869                        quote_ident(table),
870                        quote_ident(column)
871                    ));
872                }
873                match &to.default {
874                    Some(default) => {
875                        statements.push(format!(
876                            "ALTER TABLE {} ALTER COLUMN {} SET DEFAULT {};",
877                            quote_ident(table),
878                            quote_ident(column),
879                            default
880                        ));
881                    }
882                    None => {
883                        statements.push(format!(
884                            "ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT;",
885                            quote_ident(table),
886                            quote_ident(column)
887                        ));
888                    }
889                }
890            }
891            SchemaDiff::IndexAdded(idx) => {
892                statements.push(format!("{};", idx.definition));
893            }
894            SchemaDiff::IndexDropped(name) => {
895                statements.push(format!("DROP INDEX IF EXISTS {};", quote_ident(name)));
896            }
897            SchemaDiff::ViewAdded(v) => {
898                let keyword = if v.is_materialized {
899                    "MATERIALIZED VIEW"
900                } else {
901                    "VIEW"
902                };
903                statements.push(format!(
904                    "CREATE {} {} AS {};",
905                    keyword,
906                    quote_ident(&v.name),
907                    v.definition.trim_end_matches(';').trim()
908                ));
909            }
910            SchemaDiff::ViewDropped(name) => {
911                statements.push(format!(
912                    "DROP VIEW IF EXISTS {} CASCADE;",
913                    quote_ident(name)
914                ));
915            }
916            SchemaDiff::ViewAltered { name, to, .. } => {
917                statements.push(format!(
918                    "CREATE OR REPLACE VIEW {} AS {};",
919                    quote_ident(name),
920                    to.trim_end_matches(';').trim()
921                ));
922            }
923            SchemaDiff::SequenceAdded(s) => {
924                statements.push(format!("CREATE SEQUENCE {};", quote_ident(&s.name)));
925            }
926            SchemaDiff::SequenceDropped(name) => {
927                statements.push(format!("DROP SEQUENCE IF EXISTS {};", quote_ident(name)));
928            }
929            SchemaDiff::FunctionAdded(func) => {
930                statements.push(format!("{};", func.definition.trim_end_matches(';')));
931            }
932            SchemaDiff::FunctionDropped(name) => {
933                statements.push(format!(
934                    "DROP FUNCTION IF EXISTS {} CASCADE;",
935                    quote_ident(name)
936                ));
937            }
938            SchemaDiff::FunctionAltered { name } => {
939                // For altered functions we'd need the full definition; leave a comment
940                statements.push(format!(
941                    "-- Function {} was altered; manual review needed",
942                    name
943                ));
944            }
945            SchemaDiff::EnumAdded(e) => {
946                let values: Vec<String> = e.values.iter().map(|v| format!("'{}'", v)).collect();
947                statements.push(format!(
948                    "CREATE TYPE {} AS ENUM ({});",
949                    quote_ident(&e.name),
950                    values.join(", ")
951                ));
952            }
953            SchemaDiff::EnumDropped(name) => {
954                statements.push(format!(
955                    "DROP TYPE IF EXISTS {} CASCADE;",
956                    quote_ident(name)
957                ));
958            }
959            SchemaDiff::ConstraintAdded(c) => {
960                statements.push(format!(
961                    "ALTER TABLE {} ADD CONSTRAINT {} {};",
962                    quote_ident(&c.table_name),
963                    quote_ident(&c.name),
964                    c.definition
965                ));
966            }
967            SchemaDiff::ConstraintDropped { table, name } => {
968                statements.push(format!(
969                    "ALTER TABLE {} DROP CONSTRAINT IF EXISTS {};",
970                    quote_ident(table),
971                    quote_ident(name)
972                ));
973            }
974            SchemaDiff::TriggerAdded(t) => {
975                statements.push(format!(
976                    "-- Trigger {} on {} needs manual creation",
977                    t.name, t.table_name
978                ));
979            }
980            SchemaDiff::TriggerDropped { table, name } => {
981                statements.push(format!(
982                    "DROP TRIGGER IF EXISTS {} ON {};",
983                    quote_ident(name),
984                    quote_ident(table)
985                ));
986            }
987            SchemaDiff::ExtensionAdded(name) => {
988                statements.push(format!(
989                    "CREATE EXTENSION IF NOT EXISTS {};",
990                    quote_ident(name)
991                ));
992            }
993            SchemaDiff::ExtensionDropped(name) => {
994                statements.push(format!("DROP EXTENSION IF EXISTS {};", quote_ident(name)));
995            }
996        }
997    }
998
999    statements.join("\n\n")
1000}
1001
1002/// Generate full DDL to recreate a schema from a snapshot.
1003pub fn to_ddl(snapshot: &SchemaSnapshot) -> String {
1004    let mut statements = Vec::new();
1005
1006    // Extensions first
1007    for ext in &snapshot.extensions {
1008        statements.push(format!(
1009            "CREATE EXTENSION IF NOT EXISTS {};",
1010            quote_ident(ext)
1011        ));
1012    }
1013
1014    // Enums before tables (types must exist for columns)
1015    for e in &snapshot.enums {
1016        let values: Vec<String> = e.values.iter().map(|v| format!("'{}'", v)).collect();
1017        statements.push(format!(
1018            "CREATE TYPE {} AS ENUM ({});",
1019            quote_ident(&e.name),
1020            values.join(", ")
1021        ));
1022    }
1023
1024    // Sequences
1025    for s in &snapshot.sequences {
1026        statements.push(format!("CREATE SEQUENCE {};", quote_ident(&s.name)));
1027    }
1028
1029    // Tables
1030    for t in &snapshot.tables {
1031        let cols: Vec<String> = t
1032            .columns
1033            .iter()
1034            .map(|c| {
1035                let mut col = format!("    {} {}", quote_ident(&c.name), c.data_type);
1036                if !c.is_nullable {
1037                    col.push_str(" NOT NULL");
1038                }
1039                if let Some(ref default) = c.default {
1040                    col.push_str(&format!(" DEFAULT {}", default));
1041                }
1042                col
1043            })
1044            .collect();
1045        statements.push(format!(
1046            "CREATE TABLE {} (\n{}\n);",
1047            quote_ident(&t.name),
1048            cols.join(",\n")
1049        ));
1050    }
1051
1052    // Constraints
1053    for c in &snapshot.constraints {
1054        statements.push(format!(
1055            "ALTER TABLE {} ADD CONSTRAINT {} {};",
1056            quote_ident(&c.table_name),
1057            quote_ident(&c.name),
1058            c.definition
1059        ));
1060    }
1061
1062    // Indexes
1063    for idx in &snapshot.indexes {
1064        statements.push(format!("{};", idx.definition));
1065    }
1066
1067    // Views
1068    for v in &snapshot.views {
1069        let keyword = if v.is_materialized {
1070            "MATERIALIZED VIEW"
1071        } else {
1072            "VIEW"
1073        };
1074        statements.push(format!(
1075            "CREATE {} {} AS {};",
1076            keyword,
1077            quote_ident(&v.name),
1078            v.definition.trim_end_matches(';').trim()
1079        ));
1080    }
1081
1082    // Functions
1083    for func in &snapshot.functions {
1084        statements.push(format!("{};", func.definition.trim_end_matches(';')));
1085    }
1086
1087    // Triggers
1088    for t in &snapshot.triggers {
1089        statements.push(format!(
1090            "-- Trigger {} on {}: {}",
1091            t.name, t.table_name, t.definition
1092        ));
1093    }
1094
1095    statements.join("\n\n")
1096}