wafrift_encoding/encoding/keyword/
comment.rs1use std::fmt::Write as _;
4
/// SQL keywords that `mysql_versioned_comment` wraps in MySQL
/// version-conditional comments (`/*!<version>KEYWORD*/`).
///
/// Entries are stored upper-case. Matching against them is ASCII
/// case-insensitive, so one spelling per keyword is enough.
const SQL_KEYWORDS: &[&str] = &[
    // Statement verbs.
    "SELECT", "UNION", "INSERT", "UPDATE", "DELETE", "DROP",
    // Clause keywords.
    "WHERE", "FROM", "ORDER", "GROUP", "HAVING",
];
9
/// Replaces every ASCII space (`' '`) in `payload` with an empty SQL block
/// comment (`/**/`).
///
/// Only the space character is rewritten. Tabs, newlines and any other
/// characters are copied through unchanged.
pub fn sql_comment_insert(payload: &str) -> String {
    let mut encoded = String::with_capacity(payload.len());
    for ch in payload.chars() {
        match ch {
            ' ' => encoded.push_str("/**/"),
            other => encoded.push(other),
        }
    }
    encoded
}
16
17pub fn mysql_versioned_comment(payload: &str, version: u32) -> String {
19 let mut result = String::with_capacity(payload.len() * 2);
20 let chars: Vec<char> = payload.chars().collect();
21 let lower_chars: Vec<char> = chars.iter().map(char::to_ascii_lowercase).collect();
22
23 let mut kw_data: Vec<(usize, Vec<char>)> = SQL_KEYWORDS
24 .iter()
25 .map(|kw| {
26 (
27 kw.chars().count(),
28 kw.chars().map(|c| c.to_ascii_lowercase()).collect(),
29 )
30 })
31 .collect();
32 kw_data.sort_by_key(|t| std::cmp::Reverse(t.0));
33
34 let mut i = 0;
35 while i < chars.len() {
36 let mut matched = false;
37 for (kw_len, kw_lower) in &kw_data {
38 if i + kw_len <= chars.len() && lower_chars[i..i + kw_len] == kw_lower[..] {
39 let _ = write!(&mut result, "/*!{version}");
40 for j in 0..*kw_len {
41 result.push(chars[i + j]);
42 }
43 result.push_str("*/");
44 i += kw_len;
45 matched = true;
46 break;
47 }
48 }
49 if !matched {
50 result.push(chars[i]);
51 i += 1;
52 }
53 }
54 result
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sql_comment_insert_replaces_spaces() {
        let encoded = sql_comment_insert("SELECT * FROM");
        let expected = "SELECT/**/*/**/FROM";
        assert_eq!(encoded, expected);
    }

    #[test]
    fn mysql_versioned_comment_wraps_keywords() {
        let wrapped = mysql_versioned_comment("SELECT * FROM users", 50_000);
        for needle in ["/*!50000SELECT*/", "/*!50000FROM*/"] {
            assert!(wrapped.contains(needle));
        }
    }
}