1use super::*;
2use crate::{
3 CaseStatement, CommonTableExpression, Condition, ConditionExpression, ConditionHolder,
4 ConditionHolderContents, Cycle, Expr, FunctionCall, LogicalChainOper, Search, SelectStatement,
5 SubQueryStatement, TableRef, WithClause, WithQuery,
6};
7use std::collections::HashSet;
8
9impl AuditTrait for SelectStatement {
10 fn audit(&self) -> Result<QueryAccessAudit, Error> {
11 Ok(wrap_result(self.audit_impl()?))
12 }
13}
14
15impl AuditTrait for WithQuery {
16 fn audit(&self) -> Result<QueryAccessAudit, Error> {
17 Ok(wrap_result(self.audit_impl()?))
18 }
19}
20
21impl AuditTrait for WithClause {
22 fn audit(&self) -> Result<QueryAccessAudit, Error> {
23 Ok(wrap_result(self.audit_impl()?))
24 }
25}
26
27impl SelectStatement {
28 fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
29 let mut walker = Walker::default();
30 walker.recurse_audit_select(self)?;
31 Ok(walker.access)
32 }
33}
34
35impl WithQuery {
36 fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
37 let mut walker = Walker::default();
38 walker.recurse_audit_with(self)?;
39 Ok(walker.access)
40 }
41}
42
43impl WithClause {
44 fn audit_impl(&self) -> Result<Vec<QueryAccessRequest>, Error> {
45 let mut walker = Walker::default();
46 walker.recurse_audit_with_clause(self)?;
47 walker.recurse_audit_with_clause_cleanup(self);
48 Ok(walker.access)
49 }
50}
51
52#[derive(Default)]
53pub(super) struct Walker {
54 pub(super) access: Vec<QueryAccessRequest>,
55}
56
57impl Walker {
58 fn recurse_audit_select(&mut self, select: &SelectStatement) -> Result<(), Error> {
59 for select in &select.selects {
60 self.recurse_audit_expr(&select.expr)?;
61 }
62 for table_ref in &select.from {
63 self.recurse_audit_table(table_ref)?;
64 }
65 for join in &select.join {
66 self.recurse_audit_table(&join.table)?;
67 }
68 for (_, select) in &select.unions {
69 self.recurse_audit_select(select)?;
70 }
71 self.recurse_audit_condition_holder(&select.r#where)?;
72 if let Some(with) = &select.with {
73 self.recurse_audit_with_clause(with)?;
74 self.recurse_audit_with_clause_cleanup(with);
75 }
76 Ok(())
77 }
78
79 pub(super) fn recurse_audit_table(&mut self, table_ref: &TableRef) -> Result<(), Error> {
80 match table_ref {
81 TableRef::SubQuery(select, _) => self.recurse_audit_select(select)?,
82 TableRef::FunctionCall(function, _) => self.recurse_audit_function(function)?,
83 TableRef::Table(table_name, _) => {
84 self.access.push(QueryAccessRequest {
85 access_type: AccessType::Select,
86 schema_table: table_name.clone(),
87 });
88 }
89 TableRef::ValuesList(_, _) => (),
90 }
91 Ok(())
92 }
93
94 fn recurse_audit_with(&mut self, with: &WithQuery) -> Result<(), Error> {
95 self.recurse_audit_with_clause(&with.with_clause)?;
96 if let Some(subquery) = &with.query {
97 self.recurse_audit_subquery(subquery)?;
98 }
99 self.recurse_audit_with_clause_cleanup(&with.with_clause);
100 Ok(())
101 }
102
103 fn recurse_audit_function(&mut self, function: &FunctionCall) -> Result<(), Error> {
104 for arg in &function.args {
105 self.recurse_audit_expr(arg)?;
106 }
107 Ok(())
108 }
109
110 fn recurse_audit_expr(&mut self, expr: &Expr) -> Result<(), Error> {
111 match expr {
112 Expr::Column(_) => (),
113 Expr::Unary(_, expr) | Expr::AsEnum(_, expr) => self.recurse_audit_expr(expr)?,
114 Expr::FunctionCall(function) => self.recurse_audit_function(function)?,
115 Expr::Binary(left, _, right) => {
116 self.recurse_audit_expr(left)?;
117 self.recurse_audit_expr(right)?;
118 }
119 Expr::SubQuery(_, subquery) => self.recurse_audit_subquery(subquery)?,
120 Expr::Value(_) => (),
121 Expr::Values(_) => (),
122 Expr::Custom(_) => (),
123 Expr::CustomWithExpr(_, exprs) | Expr::Tuple(exprs) => {
124 for expr in exprs {
125 self.recurse_audit_expr(expr)?;
126 }
127 }
128 Expr::Keyword(_) => (),
129 Expr::Case(case) => self.recurse_audit_case(case)?,
130 Expr::Constant(_) => (),
131 Expr::TypeName(_) => (),
132 }
133 Ok(())
134 }
135
136 fn recurse_audit_subquery(&mut self, subquery: &SubQueryStatement) -> Result<(), Error> {
137 match subquery {
138 SubQueryStatement::SelectStatement(select) => self.recurse_audit_select(select)?,
139 SubQueryStatement::InsertStatement(insert) => {
140 self.access.append(&mut insert.audit()?.requests);
141 }
142 SubQueryStatement::UpdateStatement(update) => {
143 self.access.append(&mut update.audit()?.requests);
144 }
145 SubQueryStatement::DeleteStatement(delete) => {
146 self.access.append(&mut delete.audit()?.requests);
147 }
148 SubQueryStatement::WithStatement(with) => self.recurse_audit_with(with)?,
149 }
150 Ok(())
151 }
152
153 pub(super) fn recurse_audit_with_clause(
154 &mut self,
155 with_clause: &WithClause,
156 ) -> Result<(), Error> {
157 if let Some(search) = &with_clause.search {
158 self.recurse_audit_cte_search(search)?;
159 }
160 if let Some(cycle) = &with_clause.cycle {
161 self.recurse_audit_cte_cycle(cycle)?;
162 }
163 for cte in &with_clause.cte_expressions {
164 self.recurse_audit_cte_expr(cte)?;
165 }
166 Ok(())
167 }
168
169 pub(super) fn recurse_audit_with_clause_cleanup(&mut self, with_clause: &WithClause) {
170 for cte in &with_clause.cte_expressions {
172 if let Some(table_name) = &cte.table_name {
173 self.remove_item(AccessType::Select, &TableName(None, table_name.clone()));
174 }
175 }
176 }
177
178 fn recurse_audit_cte_search(&mut self, search: &Search) -> Result<(), Error> {
179 if let Some(expr) = &search.expr {
180 self.recurse_audit_expr(&expr.expr)?;
181 }
182 Ok(())
183 }
184
185 fn recurse_audit_cte_cycle(&mut self, cycle: &Cycle) -> Result<(), Error> {
186 if let Some(expr) = &cycle.expr {
187 self.recurse_audit_expr(expr)?;
188 }
189 Ok(())
190 }
191
192 fn recurse_audit_cte_expr(&mut self, cte: &CommonTableExpression) -> Result<(), Error> {
193 if let Some(query) = &cte.query {
194 self.recurse_audit_subquery(query)?;
195 }
196 Ok(())
197 }
198
199 fn recurse_audit_case(&mut self, case: &CaseStatement) -> Result<(), Error> {
200 for when in &case.when {
201 self.recurse_audit_condition(&when.condition)?;
202 self.recurse_audit_expr(&when.result)?;
203 }
204 if let Some(expr) = &case.r#else {
205 self.recurse_audit_expr(expr)?;
206 }
207 Ok(())
208 }
209
210 fn recurse_audit_condition_holder(&mut self, condition: &ConditionHolder) -> Result<(), Error> {
211 match &condition.contents {
212 ConditionHolderContents::Empty => (),
213 ConditionHolderContents::Chain(chain) => {
214 for oper in chain {
215 match oper {
216 LogicalChainOper::And(expr) => self.recurse_audit_expr(expr)?,
217 LogicalChainOper::Or(expr) => self.recurse_audit_expr(expr)?,
218 }
219 }
220 }
221 ConditionHolderContents::Condition(condition) => {
222 self.recurse_audit_condition(condition)?
223 }
224 }
225 Ok(())
226 }
227
228 fn recurse_audit_condition(&mut self, condition: &Condition) -> Result<(), Error> {
229 for cond_expr in &condition.conditions {
230 match cond_expr {
231 ConditionExpression::Condition(condition) => {
232 self.recurse_audit_condition(condition)?;
233 }
234 ConditionExpression::Expr(expr) => self.recurse_audit_expr(expr)?,
235 }
236 }
237 Ok(())
238 }
239
240 fn remove_item(&mut self, access_type: AccessType, target: &TableName) {
241 while let Some(pos) = self
242 .access
243 .iter()
244 .position(|item| item.access_type == access_type && &item.schema_table == target)
245 {
246 self.access.remove(pos);
247 }
248 }
249}
250
251fn wrap_result(access: Vec<QueryAccessRequest>) -> QueryAccessAudit {
252 let mut select_set = HashSet::new();
253 let mut insert_set = HashSet::new();
254 let mut update_set = HashSet::new();
255 let mut delete_set = HashSet::new();
256 QueryAccessAudit {
257 requests: access
258 .into_iter()
259 .filter_map(|access| {
260 let set = match access.access_type {
261 AccessType::Select => &mut select_set,
262 AccessType::Insert => &mut insert_set,
263 AccessType::Update => &mut update_set,
264 AccessType::Delete => &mut delete_set,
265 _ => todo!(),
266 };
267 if set.contains(&access.schema_table) {
268 None
269 } else {
270 set.insert(access.schema_table.clone());
271 Some(access)
272 }
273 })
274 .collect(),
275 }
276}