safety_postgres/access/
join_tables.rs1use crate::access::errors::{JoinTableError, JoinTableErrorGenerator};
2use crate::access::validators::{validate_alphanumeric_name, validate_string};
3
/// One `INNER JOIN` target: an optionally schema-qualified table plus the
/// column pairs that make up its join condition.
#[derive(Debug, Clone, PartialEq, Eq)]
struct JoinTable {
    /// Schema name; empty when the table is referenced by its bare name.
    schema: String,
    /// Name of the table being joined.
    table_name: String,
    /// Key columns on the joined table, paired by position with `destination_columns`.
    join_columns: Vec<String>,
    /// Key columns on the main table that `join_columns` are matched against.
    destination_columns: Vec<String>,
}

/// An ordered collection of tables to `INNER JOIN` onto a main table.
///
/// Build it with [`JoinTables::new`] (or [`Default`]) and
/// [`JoinTables::add_join_table`]; each registered table contributes one
/// `INNER JOIN ... ON ...` clause, in insertion order.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct JoinTables {
    tables: Vec<JoinTable>,
}
37
38impl JoinTables {
39 pub fn new() -> Self {
41 Self {
42 tables: Vec::new(),
43 }
44 }
45
46 pub fn add_join_table(&mut self, schema: &str, table_name: &str, join_columns: &[&str], destination_columns: &[&str]) -> Result<&mut Self, JoinTableError> {
72 validate_string(table_name, "table_name", &JoinTableErrorGenerator)?;
73 validate_string(schema, "schema", &JoinTableErrorGenerator)?;
74 Self::validate_column_collection_pare(join_columns, destination_columns)?;
75
76 fn convert_vec(input: &[&str]) -> Vec<String> {
77 input.iter().map(|str| str.to_string()).collect()
78 }
79
80 let join_table = JoinTable {
81 schema: schema.to_string(),
82 table_name: table_name.to_string(),
83 join_columns: convert_vec(join_columns),
84 destination_columns: convert_vec(destination_columns),
85 };
86
87 self.tables.push(join_table);
88
89 Ok(self)
90 }
91
92 pub(super) fn generate_statement_text(&self, main_table: &str) -> String {
103 let mut statement_texts:Vec<String> = Vec::new();
104
105 for table in &self.tables {
106 let statement_text = table.generate_statement_text(main_table.to_string());
107 statement_texts.push(statement_text);
108 }
109 statement_texts.join(" ")
110 }
111
112 pub fn get_joined_text(&self) -> String {
132 self.generate_statement_text("main_table_name")
133 }
134
135 pub fn is_tables_empty(&self) -> bool {
141 self.tables.is_empty()
142 }
143
144 fn validate_column_collection_pare(join_columns: &[&str], destination_columns: &[&str]) -> Result<(), JoinTableError> {
165 if !join_columns.iter().all(|column| validate_alphanumeric_name(column, "_")) {
166 return Err(JoinTableError::InputInvalidError("'join_columns' includes invalid name. Please check your input.".to_string()));
167 }
168 if !destination_columns.iter().all(|column| validate_alphanumeric_name(column, "_")) {
169 return Err(JoinTableError::InputInvalidError("'destination_columns' includes invalid name. Please check your input.".to_string()));
170 }
171
172 if join_columns.len() != destination_columns.len() {
173 return Err(JoinTableError::InputInconsistentError("'join_columns' and 'destination_columns' will be join key in SQL so these should have match number of elements.".to_string()));
174 }
175
176 Ok(())
177 }
178}
179
180impl JoinTable {
181 fn generate_statement_text(&self, main_table: String) -> String {
191 let table_with_schema = if self.schema.is_empty() {
192 self.table_name.clone()
193 } else {
194 format!("{}.{}", self.schema, self.table_name)
195 };
196 let mut statement = format!("INNER JOIN {} ON", table_with_schema);
197 for (index, (join_column, destination_column)) in self.join_columns.iter().zip(&self.destination_columns).enumerate() {
198 statement += format!(" {}.{} = {}.{}", main_table, destination_column, table_with_schema, join_column).as_str();
199 if index + 1 < self.join_columns.len() {
200 statement += " AND";
201 }
202 }
203 statement
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_join_table() {
        let mut join_tables = JoinTables::new();
        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();

        assert_eq!(join_tables.tables.len(), 1);
        assert_eq!(join_tables.tables[0].schema, "");
        assert_eq!(join_tables.tables[0].table_name, "users");
        assert_eq!(join_tables.tables[0].join_columns, vec!["id".to_string()]);
        assert_eq!(join_tables.tables[0].destination_columns, vec!["user_id".to_string()]);
    }

    #[test]
    fn test_add_join_table_is_chainable() {
        let mut join_tables = JoinTables::new();
        join_tables
            .add_join_table("", "users", &["id"], &["user_id"])
            .unwrap()
            .add_join_table("schema", "teams", &["id"], &["team_id"])
            .unwrap();

        assert_eq!(join_tables.tables.len(), 2);
        assert_eq!(join_tables.tables[1].schema, "schema");
        assert_eq!(join_tables.tables[1].table_name, "teams");
    }

    #[test]
    fn test_generate_statement_text() {
        let mut join_tables = JoinTables::new();
        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();
        join_tables.add_join_table("schema", "teams", &["id"], &["team_id"]).unwrap();

        let stmt = join_tables.generate_statement_text("main");
        assert_eq!(
            stmt,
            "INNER JOIN users ON main.user_id = users.id INNER JOIN schema.teams ON main.team_id = schema.teams.id"
        );
    }

    #[test]
    fn test_generate_statement_text_multiple_columns() {
        let mut join_tables = JoinTables::new();
        join_tables
            .add_join_table("s", "teams", &["id", "org_id"], &["team_id", "org_id"])
            .unwrap();

        assert_eq!(
            join_tables.generate_statement_text("main"),
            "INNER JOIN s.teams ON main.team_id = s.teams.id AND main.org_id = s.teams.org_id"
        );
    }

    #[test]
    fn test_generate_statement_text_without_tables() {
        assert_eq!(JoinTables::new().generate_statement_text("main"), "");
    }

    #[test]
    fn test_get_joined_text() {
        let mut join_tables = JoinTables::new();
        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();

        assert_eq!(
            join_tables.get_joined_text(),
            "INNER JOIN users ON main_table_name.user_id = users.id"
        );
    }

    #[test]
    fn test_is_tables_empty() {
        let mut join_tables = JoinTables::new();
        assert!(join_tables.is_tables_empty());

        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();
        assert!(!join_tables.is_tables_empty());
    }

    #[test]
    fn test_join_table_generate_statement_text() {
        let join_table = JoinTable {
            schema: "".to_string(),
            table_name: "users".to_string(),
            join_columns: vec!["id".to_string()],
            destination_columns: vec!["user_id".to_string()],
        };

        let stmt = join_table.generate_statement_text("main".to_string());
        assert_eq!(stmt, "INNER JOIN users ON main.user_id = users.id");
    }

    #[test]
    fn test_join_tables_empty_constructor() {
        let join_tables = JoinTables::new();
        assert_eq!(join_tables.tables.len(), 0);
    }

    #[test]
    fn test_invalid_schema_name() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("schema!", "table", &["id"], &["table_id"]) else {
            panic!("an invalid schema name should be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'schema!' has invalid characters. 'schema' allows alphabets, numbers and under bar only.".to_string()));
        // A rejected table must not be stored.
        assert!(join_tables.is_tables_empty());
    }

    #[test]
    fn test_invalid_table_name() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "tabl+e", &["id"], &["table_id"]) else {
            panic!("an invalid table name should be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'tabl+e' has invalid characters. 'table_name' allows alphabets, numbers and under bar only.".to_string()));
        assert!(join_tables.is_tables_empty());
    }

    #[test]
    fn test_invalid_char_contains_columns() {
        let ok_columns = ["id", "team", "data"];
        let ng_columns = ["id", "te;am", "date"];

        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "table", &ng_columns, &ok_columns) else {
            panic!("invalid join column names should be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'join_columns' includes invalid name. Please check your input.".to_string()));

        let Err(e) = join_tables.add_join_table("", "table", &ok_columns, &ng_columns) else {
            panic!("invalid destination column names should be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'destination_columns' includes invalid name. Please check your input.".to_string()));
        assert!(join_tables.is_tables_empty());
    }

    #[test]
    fn test_inconsistent_number_columns() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "table", &["id"], &["user_name", "id"]) else {
            panic!("column lists of different lengths should be rejected")
        };

        assert_eq!(e, JoinTableError::InputInconsistentError("'join_columns' and 'destination_columns' will be join key in SQL so these should have match number of elements.".to_string()));
        assert!(join_tables.is_tables_empty());
    }
}