// safety_postgres/access/
// join_tables.rs

1use crate::access::errors::{JoinTableError, JoinTableErrorGenerator};
2use crate::access::validators::{validate_alphanumeric_name, validate_string};
3
/// Represents a single join table in a database.
///
/// The two column lists are positional: `join_columns[i]` (a column of this
/// table) is matched against `destination_columns[i]` (a column of the main
/// table) when the `ON` clause is generated.
#[derive(Clone, Debug)]
struct JoinTable {
    /// Schema name; an empty string means the table name is used unqualified.
    schema: String,
    /// Name of the table being joined.
    table_name: String,
    /// Column names belonging to the joined table.
    join_columns: Vec<String>,
    /// Column names belonging to the main (base) table.
    destination_columns: Vec<String>,
}

/// Represents a collection of join tables in a database.
///
/// Tables are kept in insertion order and rendered in that order.
///
/// # Example
/// ```rust
/// use safety_postgres::access::join_tables::JoinTables;
///
/// let mut join_tables = JoinTables::new();
///
/// join_tables.add_join_table(
///     "",
///     "joined_table",
///     &["joined_table_c1"],
///     &["main_table_c1"]).expect("add joined table failed");
///
/// let join_text = join_tables.get_joined_text();
/// let expected_text =
///     "INNER JOIN joined_table ON main_table_name.main_table_c1 = joined_table.joined_table_c1";
///
/// assert_eq!(join_text, expected_text.to_string());
/// ```
#[derive(Clone, Debug, Default)]
pub struct JoinTables {
    /// Registered join tables, in the order they were added.
    tables: Vec<JoinTable>,
}
37
38impl JoinTables {
39    /// Create a new instance of JoinTables.
40    pub fn new() -> Self {
41        Self {
42            tables: Vec::new(),
43        }
44    }
45
46    /// Adds a join table to the instance.
47    ///
48    /// # Arguments
49    ///
50    /// * `schema` - The schema name for the new join table (input "" if there is no schema_name).
51    /// * `table_name` - The table name for the new join table.
52    /// * `join_columns` - The names of the columns in the joined table.
53    /// * `destination_columns` - The names of the columns in the main(base) table.
54    ///
55    /// # Errors
56    ///
57    /// Returns a `JoinTableError` if there is an error adding the join table.
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use safety_postgres::access::join_tables::JoinTables;
63    ///
64    /// let mut join_tables = JoinTables::new();
65    ///
66    /// join_tables.add_join_table("public", "users", &["id"], &["user_id"]).expect("adding join table failed");
67    /// let joined_text = join_tables.get_joined_text();
68    ///
69    /// assert_eq!(joined_text, "INNER JOIN public.users ON main_table_name.user_id = public.users.id");
70    /// ```
71    pub fn add_join_table(&mut self, schema: &str, table_name: &str, join_columns: &[&str], destination_columns: &[&str]) -> Result<&mut Self, JoinTableError> {
72        validate_string(table_name, "table_name", &JoinTableErrorGenerator)?;
73        validate_string(schema, "schema", &JoinTableErrorGenerator)?;
74        Self::validate_column_collection_pare(join_columns, destination_columns)?;
75
76        fn convert_vec(input: &[&str]) -> Vec<String> {
77            input.iter().map(|str| str.to_string()).collect()
78        }
79
80        let join_table = JoinTable {
81            schema: schema.to_string(),
82            table_name: table_name.to_string(),
83            join_columns: convert_vec(join_columns),
84            destination_columns: convert_vec(destination_columns),
85        };
86
87        self.tables.push(join_table);
88
89        Ok(self)
90    }
91
92    /// Generate the statement text for the given main table.
93    ///
94    /// # Arguments
95    ///
96    /// * `main_table` - The name of the main(base) table.
97    ///
98    /// # Returns
99    ///
100    /// The generated statement text as a `String`.
101    ///
102    pub(super) fn generate_statement_text(&self, main_table: &str) -> String {
103        let mut statement_texts:Vec<String> = Vec::new();
104
105        for table in &self.tables {
106            let statement_text = table.generate_statement_text(main_table.to_string());
107            statement_texts.push(statement_text);
108        }
109        statement_texts.join(" ")
110    }
111
112    /// Returns the joined text generated from the given join information.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use safety_postgres::access::join_tables::JoinTables;
118    ///
119    /// let mut obj = JoinTables::new();
120    /// obj.add_join_table("", "category", &["id"], &["cid"])
121    ///     .expect("adding join table failed");
122    /// let joined_text = obj.get_joined_text();
123    /// println!("Joined Text: {}", joined_text);
124    /// // This will display:
125    /// // "Joined Text: INNER JOIN main_table_name ON main_table_name.cid = category.id"
126    /// ```
127    ///
128    /// # Returns
129    ///
130    /// Returns a `String` that represents the joined text generated.
131    pub fn get_joined_text(&self) -> String {
132        self.generate_statement_text("main_table_name")
133    }
134
135    /// Checks if the tables collection is empty.
136    ///
137    /// # Returns
138    ///
139    /// Returns `true` if the tables collection is empty, `false` otherwise.
140    pub fn is_tables_empty(&self) -> bool {
141        self.tables.is_empty()
142    }
143
144    /// Validates the column collections for joining tables.
145    ///
146    /// This function takes two slices of strings representing join and
147    /// destination columns respectively,
148    /// and checks for the validity of the column names according to the following rules:
149    ///
150    /// - All column names must be alphanumeric or contain underscores.
151    /// - There should be a matching number of join and destination columns.
152    ///
153    /// # Arguments
154    ///
155    /// * `join_columns` - A slice of strings representing join columns from joined table.
156    /// * `destination_columns` - A slice of strings representing destination columns from main table.
157    ///
158    /// # Returns
159    ///
160    /// This function returns a `Result` with the following possible outcomes:
161    ///
162    /// * `Ok(())` - If the column collections pass the validation.
163    /// * `Err(JoinTableError)` - If there are any validation errors. The error type provides a detailed message.
164    fn validate_column_collection_pare(join_columns: &[&str], destination_columns: &[&str]) -> Result<(), JoinTableError> {
165        if !join_columns.iter().all(|column| validate_alphanumeric_name(column, "_")) {
166            return Err(JoinTableError::InputInvalidError("'join_columns' includes invalid name. Please check your input.".to_string()));
167        }
168        if !destination_columns.iter().all(|column| validate_alphanumeric_name(column, "_")) {
169            return Err(JoinTableError::InputInvalidError("'destination_columns' includes invalid name. Please check your input.".to_string()));
170        }
171
172        if join_columns.len() != destination_columns.len() {
173            return Err(JoinTableError::InputInconsistentError("'join_columns' and 'destination_columns' will be join key in SQL so these should have match number of elements.".to_string()));
174        }
175
176        Ok(())
177    }
178}
179
180impl JoinTable {
181    /// Generates a SQL statement for joining a main table with a secondary table.
182    ///
183    /// # Arguments
184    ///
185    /// * `main_table` - The name of the main table to join.
186    ///
187    /// # Returns
188    ///
189    /// The generated inner join SQL statement as a `String`.
190    fn generate_statement_text(&self, main_table: String) -> String {
191        let table_with_schema = if self.schema.is_empty() {
192            self.table_name.clone()
193        } else {
194            format!("{}.{}", self.schema, self.table_name)
195        };
196        let mut statement = format!("INNER JOIN {} ON", table_with_schema);
197        for (index, (join_column, destination_column)) in self.join_columns.iter().zip(&self.destination_columns).enumerate() {
198            statement += format!(" {}.{} = {}.{}", main_table, destination_column, table_with_schema, join_column).as_str();
199            if index + 1 < self.join_columns.len() {
200                statement += " AND";
201            }
202        }
203        statement
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies the successful addition of a `JoinTable`, including every stored field.
    #[test]
    fn test_add_join_table() {
        let mut join_tables = JoinTables::new();
        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();

        assert_eq!(join_tables.tables.len(), 1);
        let stored = &join_tables.tables[0];
        assert_eq!(stored.schema, "");
        assert_eq!(stored.table_name, "users");
        assert_eq!(stored.join_columns, vec!["id".to_string()]);
        assert_eq!(stored.destination_columns, vec!["user_id".to_string()]);
    }

    /// Verifies that `add_join_table` can be chained through its `&mut Self` return value.
    #[test]
    fn test_add_join_table_chained() {
        let mut join_tables = JoinTables::new();
        join_tables
            .add_join_table("", "users", &["id"], &["user_id"])
            .unwrap()
            .add_join_table("schema", "teams", &["id"], &["team_id"])
            .unwrap();

        assert_eq!(join_tables.tables.len(), 2);
        assert_eq!(join_tables.tables[1].schema, "schema");
        assert_eq!(join_tables.tables[1].table_name, "teams");
    }

    /// Ensures that a correct SQL INNER JOIN statement is generated.
    ///
    /// Uses an exact comparison so stray leading/trailing text is detected.
    #[test]
    fn test_generate_statement_text() {
        let mut join_tables = JoinTables::new();
        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();
        join_tables.add_join_table("schema", "teams", &["id"], &["team_id"]).unwrap();

        let stmt = join_tables.generate_statement_text("main");
        assert_eq!(
            stmt,
            "INNER JOIN users ON main.user_id = users.id INNER JOIN schema.teams ON main.team_id = schema.teams.id"
        );
    }

    /// Ensures that several join keys on one table are combined with `AND`.
    #[test]
    fn test_generate_statement_text_multiple_columns() {
        let mut join_tables = JoinTables::new();
        join_tables
            .add_join_table("", "users", &["id", "org"], &["user_id", "org_id"])
            .unwrap();

        let stmt = join_tables.generate_statement_text("main");
        assert_eq!(
            stmt,
            "INNER JOIN users ON main.user_id = users.id AND main.org_id = users.org"
        );
    }

    /// Ensures that an empty collection produces an empty statement.
    #[test]
    fn test_generate_statement_text_without_tables() {
        let join_tables = JoinTables::new();

        assert_eq!(join_tables.generate_statement_text("main"), "");
    }

    /// Checks whether the tables collection is empty, before and after adding a table.
    #[test]
    fn test_is_tables_empty() {
        let mut join_tables = JoinTables::new();
        assert!(join_tables.is_tables_empty());

        join_tables.add_join_table("", "users", &["id"], &["user_id"]).unwrap();
        assert!(!join_tables.is_tables_empty());
    }

    /// Validates the proper generation of a SQL query from a `JoinTable`.
    #[test]
    fn test_join_table_generate_statement_text() {
        let join_table = JoinTable {
            schema: "".to_string(),
            table_name: "users".to_string(),
            join_columns: vec!["id".to_string()],
            destination_columns: vec!["user_id".to_string()],
        };

        let stmt = join_table.generate_statement_text("main".to_string());
        assert_eq!(stmt, "INNER JOIN users ON main.user_id = users.id");
    }

    /// Tests that the tables vector is initially empty on new `JoinTables` creation.
    #[test]
    fn test_join_tables_empty_constructor() {
        let join_tables = JoinTables::new();
        assert_eq!(join_tables.tables.len(), 0);
    }

    /// Validates the error returned by `add_join_table` when an invalid schema name is used.
    #[test]
    fn test_invalid_schema_name() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("schema!", "table", &["id"], &["table_id"]) else {
            panic!("an invalid schema name must be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'schema!' has invalid characters. 'schema' allows alphabets, numbers and under bar only.".to_string()));
        // A rejected table must not be stored.
        assert!(join_tables.is_tables_empty());
    }

    /// Checks error handling when invalid characters are used in a table name.
    #[test]
    fn test_invalid_table_name() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "tabl+e", &["id"], &["table_id"]) else {
            panic!("an invalid table name must be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'tabl+e' has invalid characters. 'table_name' allows alphabets, numbers and under bar only.".to_string()));
        assert!(join_tables.is_tables_empty());
    }

    /// Confirms error when either 'join_columns' or 'destination_columns' contains invalid characters.
    #[test]
    fn test_invalid_char_contains_columns() {
        let ok_columns = ["id", "team", "data"];
        let ng_columns = ["id", "te;am", "date"];

        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "table", &ng_columns, &ok_columns) else {
            panic!("invalid join column names must be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'join_columns' includes invalid name. Please check your input.".to_string()));

        let Err(e) = join_tables.add_join_table("", "table", &ok_columns, &ng_columns) else {
            panic!("invalid destination column names must be rejected")
        };
        assert_eq!(e, JoinTableError::InputInvalidError("'destination_columns' includes invalid name. Please check your input.".to_string()));

        assert!(join_tables.is_tables_empty());
    }

    /// Ensures error when 'join_columns' and 'destination_columns' collections' number of elements don't match.
    #[test]
    fn test_inconsistent_number_columns() {
        let mut join_tables = JoinTables::new();
        let Err(e) = join_tables.add_join_table("", "table", &["id"], &["user_name", "id"]) else {
            panic!("column lists of different lengths must be rejected")
        };

        assert_eq!(e, JoinTableError::InputInconsistentError("'join_columns' and 'destination_columns' will be join key in SQL so these should have match number of elements.".to_string()));
        assert!(join_tables.is_tables_empty());
    }
}