yauth_migration/
collector.rs1use std::collections::{HashMap, HashSet, VecDeque};
4use std::fmt;
5
6use super::types::TableDef;
7
8#[derive(Debug, Clone)]
10pub struct YAuthSchema {
11 pub tables: Vec<TableDef>,
12}
13
14impl YAuthSchema {
15 pub fn table(&self, name: &str) -> Option<&TableDef> {
17 self.tables.iter().find(|t| t.name == name)
18 }
19}
20
/// Errors reported while assembling a [`YAuthSchema`] from plugin table lists.
///
/// Every payload is plain owned data, so errors can be cloned and compared
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// The named table was supplied more than once.
    DuplicateTable(String),
    /// `table` depends on `references`, which is not part of the schema.
    MissingDependency { table: String, references: String },
    /// The listed tables could not be placed in dependency order.
    Cycle(Vec<String>),
    /// A plugin with the given name is not recognised.
    UnknownPlugin(String),
}
29
30impl fmt::Display for SchemaError {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 SchemaError::DuplicateTable(name) => {
34 write!(
35 f,
36 "duplicate table definition: '{name}' -- each table must be defined exactly once"
37 )
38 }
39 SchemaError::MissingDependency { table, references } => {
40 write!(
41 f,
42 "table '{table}' references '{references}' which is not in the schema -- ensure the referenced table's plugin is enabled"
43 )
44 }
45 SchemaError::Cycle(tables) => {
46 write!(
47 f,
48 "cycle detected in FK dependencies among tables: {tables:?}"
49 )
50 }
51 SchemaError::UnknownPlugin(name) => {
52 write!(f, "unknown plugin: '{name}'")
53 }
54 }
55 }
56}
57
58impl std::error::Error for SchemaError {}
59
60pub fn collect_schema(table_lists: Vec<Vec<TableDef>>) -> Result<YAuthSchema, SchemaError> {
65 let mut ordered_names: Vec<String> = Vec::new();
67 let mut tables_by_name: HashMap<String, TableDef> = HashMap::new();
68
69 for tables in &table_lists {
70 for table in tables {
71 if tables_by_name.contains_key(&table.name) {
72 return Err(SchemaError::DuplicateTable(table.name.clone()));
73 }
74 ordered_names.push(table.name.clone());
75 tables_by_name.insert(table.name.clone(), table.clone());
76 }
77 }
78
79 let table_names: HashSet<String> = tables_by_name.keys().cloned().collect();
81
82 let mut in_degree: HashMap<String, usize> = HashMap::new();
84 let mut dependents: HashMap<String, Vec<String>> = HashMap::new();
85
86 for name in &table_names {
87 in_degree.entry(name.clone()).or_insert(0);
88 }
89
90 for (name, table) in &tables_by_name {
91 for dep in table.dependencies() {
92 if !table_names.contains(dep) {
93 return Err(SchemaError::MissingDependency {
94 table: name.clone(),
95 references: dep.to_string(),
96 });
97 }
98 *in_degree.entry(name.clone()).or_insert(0) += 1;
99 dependents
100 .entry(dep.to_string())
101 .or_default()
102 .push(name.clone());
103 }
104 }
105
106 let mut queue: VecDeque<String> = VecDeque::new();
108 for name in &ordered_names {
109 if in_degree[name] == 0 {
110 queue.push_back(name.clone());
111 }
112 }
113
114 let mut sorted: Vec<TableDef> = Vec::new();
115 while let Some(name) = queue.pop_front() {
116 sorted.push(
117 tables_by_name
118 .remove(&name)
119 .expect("invariant: name came from tables_by_name keys"),
120 );
121 if let Some(deps) = dependents.get(&name) {
122 let mut freed: Vec<String> = Vec::new();
124 for dep in deps {
125 let d = in_degree
126 .get_mut(dep)
127 .expect("invariant: all table names have in_degree entries");
128 *d -= 1;
129 if *d == 0 {
130 freed.push(dep.clone());
131 }
132 }
133 freed.sort_by_key(|n| ordered_names.iter().position(|on| on == n));
135 for n in freed {
136 queue.push_back(n);
137 }
138 }
139 }
140
141 if sorted.len() != table_names.len() {
142 let remaining: Vec<String> = tables_by_name.keys().cloned().collect();
143 return Err(SchemaError::Cycle(remaining));
144 }
145
146 Ok(YAuthSchema { tables: sorted })
147}