surql_parser/
formatting.rs1use serde::Deserialize;
10
11#[derive(Debug, Clone, Deserialize)]
13#[serde(default)]
14pub struct FormatConfig {
15 pub uppercase_keywords: bool,
16 pub indent_style: IndentStyle,
17 pub indent_width: u32,
18 pub newline_after_semicolon: bool,
19 pub newline_before_where: bool,
20 pub newline_before_set: bool,
21 pub newline_before_from: bool,
22 pub trailing_semicolon: bool,
23 pub collapse_blank_lines: bool,
24 pub max_blank_lines: u32,
25}
26
27#[derive(Debug, Clone, Deserialize, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum IndentStyle {
30 Tab,
31 Space,
32}
33
34impl Default for FormatConfig {
35 fn default() -> Self {
36 Self {
37 uppercase_keywords: true,
38 indent_style: IndentStyle::Tab,
39 indent_width: 4,
40 newline_after_semicolon: false,
41 newline_before_where: false,
42 newline_before_set: false,
43 newline_before_from: false,
44 trailing_semicolon: false,
45 collapse_blank_lines: false,
46 max_blank_lines: 2,
47 }
48 }
49}
50
51impl FormatConfig {
52 #[cfg(feature = "cli")]
53 pub fn load_from_dir(dir: &std::path::Path) -> Self {
54 let config_path = dir.join(".surqlformat.toml");
55 if let Ok(content) = std::fs::read_to_string(&config_path) {
56 match toml::from_str::<FormatConfig>(&content) {
57 Ok(config) => return config,
58 Err(e) => {
59 eprintln!(
60 "Invalid .surqlformat.toml at {}: {e}",
61 config_path.display()
62 );
63 }
64 }
65 }
66 Self::default()
67 }
68
69 fn indent_str(&self) -> String {
70 match self.indent_style {
71 IndentStyle::Tab => "\t".to_string(),
72 IndentStyle::Space => " ".repeat(self.indent_width as usize),
73 }
74 }
75}
76
77pub fn format_source(source: &str, config: &FormatConfig) -> Option<String> {
81 configurable_format(source, config)
82}
83
84fn configurable_format(source: &str, config: &FormatConfig) -> Option<String> {
86 use crate::upstream::syn::lexer::Lexer;
87 use crate::upstream::syn::token::{Delim, TokenKind};
88
89 let bytes = source.as_bytes();
90 if bytes.is_empty() || bytes.len() > u32::MAX as usize {
91 return None;
92 }
93
94 let tokens: Vec<_> = match std::panic::catch_unwind(|| Lexer::new(bytes).collect()) {
95 Ok(t) => t,
96 Err(e) => {
97 tracing::error!("Lexer panicked during formatting: {e:?}");
98 return None;
99 }
100 };
101
102 let mut result = String::with_capacity(source.len());
103 let mut last_end: usize = 0;
104 let mut depth: u32 = 0;
105 let indent = config.indent_str();
106
107 for token in &tokens {
108 let start = token.span.offset as usize;
109 let end = start + token.span.len as usize;
110 if end > source.len() {
111 continue;
112 }
113
114 let gap = &source[last_end..start];
115 let original = &source[start..end];
116
117 let formatted_token =
118 if config.uppercase_keywords && matches!(token.kind, TokenKind::Keyword(_)) {
119 let upper = original.to_uppercase();
120 if is_formattable_keyword(&upper) {
121 upper
122 } else {
123 original.to_string()
124 }
125 } else {
126 original.to_string()
127 };
128
129 if token.kind == TokenKind::OpenDelim(Delim::Brace) {
130 result.push_str(gap);
131 result.push_str(&formatted_token);
132 depth += 1;
133 last_end = end;
134 continue;
135 }
136
137 if token.kind == TokenKind::CloseDelim(Delim::Brace) {
138 depth = depth.saturating_sub(1);
139 result.push_str(gap);
140 result.push_str(&formatted_token);
141 last_end = end;
142 continue;
143 }
144
145 if token.kind == TokenKind::SemiColon && config.newline_after_semicolon {
146 result.push_str(gap);
147 result.push_str(&formatted_token);
148
149 let rest = &source[end..];
150 let next_non_ws = rest.find(|c: char| c != ' ' && c != '\t');
151 let already_newline = match next_non_ws {
152 Some(pos) => rest.as_bytes()[pos] == b'\n',
153 None => true,
154 };
155
156 if !already_newline {
157 result.push('\n');
158 for _ in 0..depth {
159 result.push_str(&indent);
160 }
161 }
162
163 last_end = end;
164 continue;
165 }
166
167 if matches!(token.kind, TokenKind::Keyword(_)) {
168 let upper = formatted_token.to_uppercase();
169 let should_newline = (config.newline_before_where && upper == "WHERE")
170 || (config.newline_before_set && upper == "SET")
171 || (config.newline_before_from && upper == "FROM");
172
173 if should_newline {
174 let preceding = &source[..start];
175 let last_newline = preceding.rfind('\n');
176 let line_before = match last_newline {
177 Some(pos) => &preceding[pos + 1..],
178 None => preceding,
179 };
180 let already_on_new_line = line_before.trim().is_empty();
181
182 if !already_on_new_line {
183 result.push('\n');
184 for _ in 0..depth.saturating_add(1) {
185 result.push_str(&indent);
186 }
187 result.push_str(&formatted_token);
188 last_end = end;
189 continue;
190 }
191 }
192 }
193
194 if config.collapse_blank_lines && gap.contains('\n') {
195 let newline_count = gap.matches('\n').count();
196 if newline_count > (config.max_blank_lines as usize + 1) {
197 let mut collapsed = String::new();
198 let mut seen_newlines = 0u32;
199 for ch in gap.chars() {
200 if ch == '\n' {
201 seen_newlines += 1;
202 if seen_newlines <= config.max_blank_lines + 1 {
203 collapsed.push(ch);
204 }
205 } else if seen_newlines <= config.max_blank_lines + 1 {
206 collapsed.push(ch);
207 }
208 }
209 result.push_str(&collapsed);
210 result.push_str(&formatted_token);
211 last_end = end;
212 continue;
213 }
214 }
215
216 result.push_str(gap);
217 result.push_str(&formatted_token);
218 last_end = end;
219 }
220
221 if last_end < source.len() {
222 let trailing = &source[last_end..];
223
224 if config.trailing_semicolon {
225 let trimmed = trailing.trim_end();
226 if !trimmed.is_empty() && !trimmed.ends_with(';') {
227 result.push_str(trimmed);
228 result.push(';');
229 let ws_start = trimmed.len();
230 if ws_start < trailing.len() {
231 result.push_str(&trailing[ws_start..]);
232 }
233 } else {
234 result.push_str(trailing);
235 }
236 } else {
237 result.push_str(trailing);
238 }
239 } else if config.trailing_semicolon && !result.is_empty() && !result.trim_end().ends_with(';') {
240 result.push(';');
241 }
242
243 if result == source { None } else { Some(result) }
244}
245
/// Reports whether `upper`, an already upper-cased keyword, is on the allow-list of
/// keywords the formatter rewrites to upper case.
///
/// The comparison is exact and case-sensitive: callers must upper-case first.
fn is_formattable_keyword(upper: &str) -> bool {
    const FORMATTABLE: &[&str] = &[
        // Statements and clauses.
        "SELECT", "FROM", "WHERE", "ORDER", "BY", "GROUP", "LIMIT", "OFFSET", "FETCH",
        "CREATE", "UPDATE", "DELETE", "INSERT", "INTO", "SET", "UPSERT", "MERGE", "CONTENT",
        "RETURN", "RELATE", "ONLY", "VALUE", "VALUES",
        // Schema definition.
        "DEFINE", "REMOVE", "TABLE", "FIELD", "INDEX", "EVENT", "FUNCTION", "PARAM",
        "NAMESPACE", "DATABASE", "ANALYZER", "ACCESS", "SCHEMAFULL", "SCHEMALESS", "TYPE",
        "DEFAULT", "READONLY", "FLEXIBLE", "UNIQUE", "FIELDS", "COMMENT", "PERMISSIONS",
        "FULL", "NONE", "OVERWRITE", "EXISTS", "ASSERT", "ENFORCED", "DROP", "CHANGEFEED",
        "INCLUDE", "ORIGINAL", "TOKENIZERS", "FILTERS", "UNIQ", "SEARCH",
        // Control flow and transactions.
        "ON", "WHEN", "THEN", "ELSE", "END", "IF", "FOR", "LET", "BEGIN", "COMMIT", "CANCEL",
        "TRANSACTION", "THROW", "BREAK", "CONTINUE", "SLEEP",
        // Session, live queries and statement modifiers.
        "USE", "NS", "DB", "LIVE", "DIFF", "KILL", "SHOW", "INFO", "PARALLEL", "TIMEOUT",
        "EXPLAIN", "SPLIT", "AT", "WITH", "NOINDEX",
        // Operators and ordering.
        "AND", "OR", "NOT", "IN", "AS", "IS", "LIKE", "CONTAINS", "CONTAINSALL",
        "CONTAINSANY", "CONTAINSNONE", "INSIDE", "OUTSIDE", "INTERSECTS", "ALLINSIDE",
        "ANYINSIDE", "NONEINSIDE", "ASC", "DESC", "COLLATE", "NUMERIC",
    ];
    FORMATTABLE.iter().any(|keyword| *keyword == upper)
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_uppercase_keywords_with_default_config() {
        let source = "select * from user where age > 18;";
        let formatted = format_source(source, &FormatConfig::default()).unwrap();
        assert_eq!(formatted, "SELECT * FROM user WHERE age > 18;");
    }

    #[test]
    fn should_preserve_comments() {
        let source = "-- this is a comment\nselect * from user;";
        let formatted = format_source(source, &FormatConfig::default()).unwrap();
        assert!(formatted.contains("-- this is a comment"));
        assert!(formatted.contains("SELECT"));
    }

    #[test]
    fn should_preserve_strings() {
        let source = "select * from user where name = 'select from where';";
        let formatted = format_source(source, &FormatConfig::default()).unwrap();
        assert!(formatted.contains("'select from where'"));
    }

    #[test]
    fn should_return_none_when_already_formatted() {
        let source = "SELECT * FROM user WHERE age > 18;";
        let result = format_source(source, &FormatConfig::default());
        assert!(result.is_none(), "already formatted should return None");
    }

    #[test]
    fn should_return_none_for_empty_source() {
        assert!(format_source("", &FormatConfig::default()).is_none());
    }

    #[test]
    fn should_add_newline_after_semicolon() {
        let config = FormatConfig {
            newline_after_semicolon: true,
            ..Default::default()
        };
        let source = "DEFINE TABLE user; DEFINE TABLE post;";
        let formatted = format_source(source, &config).unwrap();
        assert!(
            formatted.contains("user;\n"),
            "should have newline after first semicolon: {formatted}"
        );
    }

    #[test]
    fn should_not_double_newline_after_semicolon() {
        let config = FormatConfig {
            newline_after_semicolon: true,
            ..Default::default()
        };
        let source = "DEFINE TABLE user;\nDEFINE TABLE post;";
        let result = format_source(source, &config);
        assert!(result.is_none(), "already has newlines, should return None");
    }

    #[test]
    fn should_add_newline_before_where() {
        let config = FormatConfig {
            newline_before_where: true,
            ..Default::default()
        };
        let source = "SELECT * FROM user WHERE age > 18;";
        let formatted = format_source(source, &config).unwrap();
        assert_eq!(
            formatted, "SELECT * FROM user\n\tWHERE age > 18;",
            "should have newline+indent before WHERE: {formatted}"
        );
    }

    #[test]
    fn should_add_newline_before_set() {
        let config = FormatConfig {
            newline_before_set: true,
            ..Default::default()
        };
        let source = "UPDATE user SET name = 'Alice';";
        let formatted = format_source(source, &config).unwrap();
        assert_eq!(
            formatted, "UPDATE user\n\tSET name = 'Alice';",
            "should have newline+indent before SET: {formatted}"
        );
    }

    #[test]
    fn should_collapse_blank_lines() {
        let config = FormatConfig {
            collapse_blank_lines: true,
            max_blank_lines: 1,
            ..Default::default()
        };
        let source = "SELECT * FROM user;\n\n\n\n\nSELECT * FROM post;";
        let formatted = format_source(source, &config).unwrap();
        assert_eq!(
            formatted, "SELECT * FROM user;\n\nSELECT * FROM post;",
            "should collapse to exactly one blank line: {formatted}"
        );
    }

    #[test]
    fn should_keep_blank_lines_within_limit() {
        let config = FormatConfig {
            collapse_blank_lines: true,
            max_blank_lines: 1,
            ..Default::default()
        };
        let source = "SELECT * FROM user;\n\nSELECT * FROM post;";
        assert!(format_source(source, &config).is_none());
    }

    #[test]
    fn should_add_trailing_semicolon() {
        let config = FormatConfig {
            trailing_semicolon: true,
            uppercase_keywords: false,
            ..Default::default()
        };
        let source = "SELECT * FROM user";
        let formatted = format_source(source, &config).unwrap();
        assert!(
            formatted.ends_with(';'),
            "should have trailing semicolon: {formatted}"
        );
    }

    #[test]
    fn should_not_add_trailing_semicolon_when_present() {
        let config = FormatConfig {
            trailing_semicolon: true,
            ..Default::default()
        };
        assert!(format_source("SELECT * FROM user;", &config).is_none());
    }

    #[test]
    fn should_apply_custom_config() {
        let config = FormatConfig {
            uppercase_keywords: true,
            indent_style: IndentStyle::Space,
            indent_width: 2,
            newline_after_semicolon: true,
            newline_before_where: true,
            ..Default::default()
        };
        let source = "select * from user where age > 18;\n";
        let formatted = format_source(source, &config).unwrap();
        assert_eq!(
            formatted, "SELECT * FROM user\n  WHERE age > 18;\n",
            "should uppercase and indent WHERE with two spaces: {formatted}"
        );
    }
}