sql_comment_parser/
lib.rs

/// Locates comments inside a SQL string so they can be extracted or stripped.
///
/// Two comment forms are recognised:
///
/// * block comments: `/* ... */` (an unterminated one runs to the end of input)
/// * line comments: `-- ...` up to, but not including, the next `\r` or `\n`
///
/// Comment markers that appear inside a single-quoted, double-quoted or
/// backtick-quoted run are ignored. Quoted runs end at the next occurrence of
/// the same quote character; backslash escapes are not interpreted.
#[derive(Debug, Clone)]
pub struct SqlCommentParser<'a> {
    /// The SQL text being scanned.
    sql: &'a str,
    /// Byte offset into `sql` of the next byte to examine.
    pos: usize,
}

/// Half-open byte range `[start_index, end_index)` of one comment in the SQL text.
#[derive(Debug, Clone, Copy)]
struct Comment {
    start_index: usize,
    end_index: usize,
}

impl Comment {
    /// Builds a comment range from its first byte offset and its one-past-the-end offset.
    fn new(start: usize, end: usize) -> Self {
        Self {
            start_index: start,
            end_index: end,
        }
    }
}

impl<'a> SqlCommentParser<'a> {
    /// Creates a parser positioned at the beginning of `sql`.
    pub fn new(sql: &'a str) -> Self {
        Self { sql, pos: 0 }
    }

    /// Returns the text of every comment in the SQL, concatenated in order of
    /// appearance (comment delimiters included, nothing inserted between them).
    ///
    /// The scan always restarts from the beginning of the SQL, so the result
    /// does not depend on earlier calls made on the same parser.
    pub fn get_comment_sql(&mut self) -> String {
        let sql = self.sql;
        self.pos = 0;

        let mut comments = String::new();
        while let Some(comment) = self.next_comment() {
            comments.push_str(&sql[comment.start_index..comment.end_index]);
        }
        comments
    }

    /// Returns the SQL with every comment removed.
    ///
    /// Only the comment bytes themselves are dropped: surrounding whitespace
    /// and the line break that terminates a `--` comment are kept.
    ///
    /// The scan always restarts from the beginning of the SQL, so the result
    /// does not depend on earlier calls made on the same parser.
    pub fn remove_comment_sql(&mut self) -> String {
        let sql = self.sql;
        self.pos = 0;

        let mut kept = String::with_capacity(sql.len());
        let mut copied_up_to = 0;
        while let Some(comment) = self.next_comment() {
            kept.push_str(&sql[copied_up_to..comment.start_index]);
            copied_up_to = comment.end_index;
        }
        kept.push_str(&sql[copied_up_to..]);
        kept
    }

    /// Advances the cursor to just past the next comment and returns its range,
    /// or returns `None` once the rest of the input contains no comment.
    ///
    /// Scanning is byte-based. Every delimiter of interest is ASCII, and ASCII
    /// bytes never occur inside a multi-byte UTF-8 sequence, so all returned
    /// offsets fall on character boundaries.
    fn next_comment(&mut self) -> Option<Comment> {
        let sql = self.sql;
        let bytes = sql.as_bytes();

        while self.pos < bytes.len() {
            let at = self.pos;
            let followed_by = |byte: u8| bytes.get(at + 1) == Some(&byte);

            match bytes[at] {
                // Quoted run: skip it wholesale so comment markers inside are ignored.
                b'\'' | b'`' | b'"' => {
                    self.pos = Self::index_after_quoted(bytes, at);
                }
                b'/' if followed_by(b'*') => {
                    // Look for the terminator only after the opening "/*";
                    // searching from the start of the string would match the
                    // terminator of an earlier comment.
                    let body_start = at + 2;
                    let end = sql[body_start..]
                        .find("*/")
                        .map_or(sql.len(), |offset| body_start + offset + 2);
                    self.pos = end;
                    return Some(Comment::new(at, end));
                }
                b'-' if followed_by(b'-') => {
                    let end = Self::index_of_line_end(sql, at + 2);
                    self.pos = end;
                    return Some(Comment::new(at, end));
                }
                // Anything else, including a lone '/' or '-', is ordinary SQL.
                // The cursor must always move forward here or the scan never ends.
                _ => self.pos = at + 1,
            }
        }
        None
    }

    /// Returns the offset just past the quoted run whose opening quote is at
    /// `open`, or `bytes.len()` if the quote is never closed.
    ///
    /// `open` must be a valid index into `bytes`.
    fn index_after_quoted(bytes: &[u8], open: usize) -> usize {
        let quote = bytes[open];
        let body_start = open + 1;
        bytes[body_start..]
            .iter()
            .position(|&byte| byte == quote)
            .map_or(bytes.len(), |offset| body_start + offset + 1)
    }

    /// Returns the offset of the first `\r` or `\n` at or after byte offset `i`,
    /// or `sql.len()` if the rest of the input has no line break.
    fn index_of_line_end(sql: &str, i: usize) -> usize {
        sql.as_bytes()
            .iter()
            .skip(i)
            .position(|&byte| byte == b'\r' || byte == b'\n')
            .map_or(sql.len(), |offset| i + offset)
    }
}