sea_query/audit/
select.rs

use super::*;
use crate::{
    CaseStatement, CommonTableExpression, Condition, ConditionExpression, ConditionHolder,
    ConditionHolderContents, Cycle, Expr, FunctionCall, JoinOn, LogicalChainOper, Search,
    SelectStatement, SubQueryStatement, TableRef, WithClause, WithQuery,
};
use std::collections::HashSet;
8
9impl AuditTrait for SelectStatement {
10    fn audit(&self) -> Result<QueryAccessAudit, Error> {
11        Ok(wrap_result(self.audit_impl()?))
12    }
13}
14
15impl AuditTrait for WithQuery {
16    fn audit(&self) -> Result<QueryAccessAudit, Error> {
17        Ok(wrap_result(self.audit_impl()?))
18    }
19}
20
21impl AuditTrait for WithClause {
22    fn audit(&self) -> Result<QueryAccessAudit, Error> {
23        Ok(wrap_result(self.audit_impl()?))
24    }
25}
26
27impl SelectStatement {
28    fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
29        let mut walker = Walker::default();
30        walker.recurse_audit_select(self)?;
31        Ok(walker.access)
32    }
33}
34
35impl WithQuery {
36    fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
37        let mut walker = Walker::default();
38        walker.recurse_audit_with(self)?;
39        Ok(walker.access)
40    }
41}
42
43impl WithClause {
44    fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
45        let mut walker = Walker::default();
46        walker.recurse_audit_with_clause(self)?;
47        walker.recurse_audit_with_clause_cleanup(self);
48        Ok(walker.access)
49    }
50}
51
52#[derive(Default)]
53pub(super) struct Walker {
54    pub(super) access: Vec<QueryAccessRequest>,
55}
56
57impl Walker {
58    fn recurse_audit_select(&mut self, select: &SelectStatement) -> Result<(), Error> {
59        for select in &select.selects {
60            self.recurse_audit_expr(&select.expr)?;
61        }
62        for table_ref in &select.from {
63            self.recurse_audit_table(table_ref)?;
64        }
65        for join in &select.join {
66            self.recurse_audit_table(&join.table)?;
67        }
68        for (_, select) in &select.unions {
69            self.recurse_audit_select(select)?;
70        }
71        self.recurse_audit_condition_holder(&select.r#where)?;
72        if let Some(with) = &select.with {
73            self.recurse_audit_with_clause(with)?;
74            self.recurse_audit_with_clause_cleanup(with);
75        }
76        Ok(())
77    }
78
79    pub(super) fn recurse_audit_table(&mut self, table_ref: &TableRef) -> Result<(), Error> {
80        match table_ref {
81            TableRef::SubQuery(select, _) => self.recurse_audit_select(select)?,
82            TableRef::FunctionCall(function, _) => self.recurse_audit_function(function)?,
83            TableRef::Table(table_name, _) => {
84                self.access.push(QueryAccessRequest {
85                    access_type: AccessType::Select,
86                    schema_table: table_name.clone(),
87                });
88            }
89            TableRef::ValuesList(_, _) => (),
90        }
91        Ok(())
92    }
93
94    fn recurse_audit_with(&mut self, with: &WithQuery) -> Result<(), Error> {
95        self.recurse_audit_with_clause(&with.with_clause)?;
96        if let Some(subquery) = &with.query {
97            self.recurse_audit_subquery(subquery)?;
98        }
99        self.recurse_audit_with_clause_cleanup(&with.with_clause);
100        Ok(())
101    }
102
103    fn recurse_audit_function(&mut self, function: &FunctionCall) -> Result<(), Error> {
104        for arg in &function.args {
105            self.recurse_audit_expr(arg)?;
106        }
107        Ok(())
108    }
109
110    fn recurse_audit_expr(&mut self, expr: &Expr) -> Result<(), Error> {
111        match expr {
112            Expr::Column(_) => (),
113            Expr::Unary(_, expr) | Expr::AsEnum(_, expr) => self.recurse_audit_expr(expr)?,
114            Expr::FunctionCall(function) => self.recurse_audit_function(function)?,
115            Expr::Binary(left, _, right) => {
116                self.recurse_audit_expr(left)?;
117                self.recurse_audit_expr(right)?;
118            }
119            Expr::SubQuery(_, subquery) => self.recurse_audit_subquery(subquery)?,
120            Expr::Value(_) => (),
121            Expr::Values(_) => (),
122            Expr::Custom(_) => (),
123            Expr::CustomWithExpr(_, exprs) | Expr::Tuple(exprs) => {
124                for expr in exprs {
125                    self.recurse_audit_expr(expr)?;
126                }
127            }
128            Expr::Keyword(_) => (),
129            Expr::Case(case) => self.recurse_audit_case(case)?,
130            Expr::Constant(_) => (),
131            Expr::TypeName(_) => (),
132        }
133        Ok(())
134    }
135
136    fn recurse_audit_subquery(&mut self, subquery: &SubQueryStatement) -> Result<(), Error> {
137        match subquery {
138            SubQueryStatement::SelectStatement(select) => self.recurse_audit_select(select)?,
139            SubQueryStatement::InsertStatement(insert) => {
140                self.access.append(&mut insert.audit()?.requests);
141            }
142            SubQueryStatement::UpdateStatement(update) => {
143                self.access.append(&mut update.audit()?.requests);
144            }
145            SubQueryStatement::DeleteStatement(delete) => {
146                self.access.append(&mut delete.audit()?.requests);
147            }
148            SubQueryStatement::WithStatement(with) => self.recurse_audit_with(with)?,
149        }
150        Ok(())
151    }
152
153    pub(super) fn recurse_audit_with_clause(
154        &mut self,
155        with_clause: &WithClause,
156    ) -> Result<(), Error> {
157        if let Some(search) = &with_clause.search {
158            self.recurse_audit_cte_search(search)?;
159        }
160        if let Some(cycle) = &with_clause.cycle {
161            self.recurse_audit_cte_cycle(cycle)?;
162        }
163        for cte in &with_clause.cte_expressions {
164            self.recurse_audit_cte_expr(cte)?;
165        }
166        Ok(())
167    }
168
169    pub(super) fn recurse_audit_with_clause_cleanup(&mut self, with_clause: &WithClause) {
170        // remove cte alias
171        for cte in &with_clause.cte_expressions {
172            if let Some(table_name) = &cte.table_name {
173                self.remove_item(AccessType::Select, &TableName(None, table_name.clone()));
174            }
175        }
176    }
177
178    fn recurse_audit_cte_search(&mut self, search: &Search) -> Result<(), Error> {
179        if let Some(expr) = &search.expr {
180            self.recurse_audit_expr(&expr.expr)?;
181        }
182        Ok(())
183    }
184
185    fn recurse_audit_cte_cycle(&mut self, cycle: &Cycle) -> Result<(), Error> {
186        if let Some(expr) = &cycle.expr {
187            self.recurse_audit_expr(expr)?;
188        }
189        Ok(())
190    }
191
192    fn recurse_audit_cte_expr(&mut self, cte: &CommonTableExpression) -> Result<(), Error> {
193        if let Some(query) = &cte.query {
194            self.recurse_audit_subquery(query)?;
195        }
196        Ok(())
197    }
198
199    fn recurse_audit_case(&mut self, case: &CaseStatement) -> Result<(), Error> {
200        for when in &case.when {
201            self.recurse_audit_condition(&when.condition)?;
202            self.recurse_audit_expr(&when.result)?;
203        }
204        if let Some(expr) = &case.r#else {
205            self.recurse_audit_expr(expr)?;
206        }
207        Ok(())
208    }
209
210    fn recurse_audit_condition_holder(&mut self, condition: &ConditionHolder) -> Result<(), Error> {
211        match &condition.contents {
212            ConditionHolderContents::Empty => (),
213            ConditionHolderContents::Chain(chain) => {
214                for oper in chain {
215                    match oper {
216                        LogicalChainOper::And(expr) => self.recurse_audit_expr(expr)?,
217                        LogicalChainOper::Or(expr) => self.recurse_audit_expr(expr)?,
218                    }
219                }
220            }
221            ConditionHolderContents::Condition(condition) => {
222                self.recurse_audit_condition(condition)?
223            }
224        }
225        Ok(())
226    }
227
228    fn recurse_audit_condition(&mut self, condition: &Condition) -> Result<(), Error> {
229        for cond_expr in &condition.conditions {
230            match cond_expr {
231                ConditionExpression::Condition(condition) => {
232                    self.recurse_audit_condition(condition)?;
233                }
234                ConditionExpression::Expr(expr) => self.recurse_audit_expr(expr)?,
235            }
236        }
237        Ok(())
238    }
239
240    fn remove_item(&mut self, access_type: AccessType, target: &TableName) {
241        while let Some(pos) = self
242            .access
243            .iter()
244            .position(|item| item.access_type == access_type && &item.schema_table == target)
245        {
246            self.access.remove(pos);
247        }
248    }
249}
250
251fn wrap_result(access: Vec<QueryAccessRequest>) -> QueryAccessAudit {
252    let mut select_set = HashSet::new();
253    let mut insert_set = HashSet::new();
254    let mut update_set = HashSet::new();
255    let mut delete_set = HashSet::new();
256    QueryAccessAudit {
257        requests: access
258            .into_iter()
259            .filter_map(|access| {
260                let set = match access.access_type {
261                    AccessType::Select => &mut select_set,
262                    AccessType::Insert => &mut insert_set,
263                    AccessType::Update => &mut update_set,
264                    AccessType::Delete => &mut delete_set,
265                    _ => todo!(),
266                };
267                if set.contains(&access.schema_table) {
268                    None
269                } else {
270                    set.insert(access.schema_table.clone());
271                    Some(access)
272                }
273            })
274            .collect(),
275    }
276}