Skip to main content

surql_parser/
formatting.rs

1//! Configurable SurrealQL formatter.
2//!
3//! Preserves comments, strings, and document structure while applying
4//! configurable formatting rules via `FormatConfig`.
5//!
6//! Config is loaded from `.surqlformat.toml` in the workspace root,
7//! or falls back to sensible defaults.
8
9use serde::Deserialize;
10
11/// Formatting configuration loaded from `.surqlformat.toml`.
12#[derive(Debug, Clone, Deserialize)]
13#[serde(default)]
14pub struct FormatConfig {
15	pub uppercase_keywords: bool,
16	pub indent_style: IndentStyle,
17	pub indent_width: u32,
18	pub newline_after_semicolon: bool,
19	pub newline_before_where: bool,
20	pub newline_before_set: bool,
21	pub newline_before_from: bool,
22	pub trailing_semicolon: bool,
23	pub collapse_blank_lines: bool,
24	pub max_blank_lines: u32,
25}
26
27#[derive(Debug, Clone, Deserialize, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum IndentStyle {
30	Tab,
31	Space,
32}
33
34impl Default for FormatConfig {
35	fn default() -> Self {
36		Self {
37			uppercase_keywords: true,
38			indent_style: IndentStyle::Tab,
39			indent_width: 4,
40			newline_after_semicolon: false,
41			newline_before_where: false,
42			newline_before_set: false,
43			newline_before_from: false,
44			trailing_semicolon: false,
45			collapse_blank_lines: false,
46			max_blank_lines: 2,
47		}
48	}
49}
50
51impl FormatConfig {
52	#[cfg(feature = "cli")]
53	pub fn load_from_dir(dir: &std::path::Path) -> Self {
54		let config_path = dir.join(".surqlformat.toml");
55		if let Ok(content) = std::fs::read_to_string(&config_path) {
56			match toml::from_str::<FormatConfig>(&content) {
57				Ok(config) => return config,
58				Err(e) => {
59					eprintln!(
60						"Invalid .surqlformat.toml at {}: {e}",
61						config_path.display()
62					);
63				}
64			}
65		}
66		Self::default()
67	}
68
69	fn indent_str(&self) -> String {
70		match self.indent_style {
71			IndentStyle::Tab => "\t".to_string(),
72			IndentStyle::Space => " ".repeat(self.indent_width as usize),
73		}
74	}
75}
76
77/// Format SurrealQL source text using the given config.
78///
79/// Returns `Some(formatted)` if changes were made, `None` if the source is already formatted.
80pub fn format_source(source: &str, config: &FormatConfig) -> Option<String> {
81	configurable_format(source, config)
82}
83
84/// Configurable lexer-based formatter.
85fn configurable_format(source: &str, config: &FormatConfig) -> Option<String> {
86	use crate::upstream::syn::lexer::Lexer;
87	use crate::upstream::syn::token::{Delim, TokenKind};
88
89	let bytes = source.as_bytes();
90	if bytes.is_empty() || bytes.len() > u32::MAX as usize {
91		return None;
92	}
93
94	let tokens: Vec<_> = match std::panic::catch_unwind(|| Lexer::new(bytes).collect()) {
95		Ok(t) => t,
96		Err(e) => {
97			tracing::error!("Lexer panicked during formatting: {e:?}");
98			return None;
99		}
100	};
101
102	let mut result = String::with_capacity(source.len());
103	let mut last_end: usize = 0;
104	let mut depth: u32 = 0;
105	let indent = config.indent_str();
106
107	for token in &tokens {
108		let start = token.span.offset as usize;
109		let end = start + token.span.len as usize;
110		if end > source.len() {
111			continue;
112		}
113
114		let gap = &source[last_end..start];
115		let original = &source[start..end];
116
117		let formatted_token =
118			if config.uppercase_keywords && matches!(token.kind, TokenKind::Keyword(_)) {
119				let upper = original.to_uppercase();
120				if is_formattable_keyword(&upper) {
121					upper
122				} else {
123					original.to_string()
124				}
125			} else {
126				original.to_string()
127			};
128
129		if token.kind == TokenKind::OpenDelim(Delim::Brace) {
130			result.push_str(gap);
131			result.push_str(&formatted_token);
132			depth += 1;
133			last_end = end;
134			continue;
135		}
136
137		if token.kind == TokenKind::CloseDelim(Delim::Brace) {
138			depth = depth.saturating_sub(1);
139			result.push_str(gap);
140			result.push_str(&formatted_token);
141			last_end = end;
142			continue;
143		}
144
145		if token.kind == TokenKind::SemiColon && config.newline_after_semicolon {
146			result.push_str(gap);
147			result.push_str(&formatted_token);
148
149			let rest = &source[end..];
150			let next_non_ws = rest.find(|c: char| c != ' ' && c != '\t');
151			let already_newline = match next_non_ws {
152				Some(pos) => rest.as_bytes()[pos] == b'\n',
153				None => true,
154			};
155
156			if !already_newline {
157				result.push('\n');
158				for _ in 0..depth {
159					result.push_str(&indent);
160				}
161			}
162
163			last_end = end;
164			continue;
165		}
166
167		if matches!(token.kind, TokenKind::Keyword(_)) {
168			let upper = formatted_token.to_uppercase();
169			let should_newline = (config.newline_before_where && upper == "WHERE")
170				|| (config.newline_before_set && upper == "SET")
171				|| (config.newline_before_from && upper == "FROM");
172
173			if should_newline {
174				let preceding = &source[..start];
175				let last_newline = preceding.rfind('\n');
176				let line_before = match last_newline {
177					Some(pos) => &preceding[pos + 1..],
178					None => preceding,
179				};
180				let already_on_new_line = line_before.trim().is_empty();
181
182				if !already_on_new_line {
183					result.push('\n');
184					for _ in 0..depth.saturating_add(1) {
185						result.push_str(&indent);
186					}
187					result.push_str(&formatted_token);
188					last_end = end;
189					continue;
190				}
191			}
192		}
193
194		if config.collapse_blank_lines && gap.contains('\n') {
195			let newline_count = gap.matches('\n').count();
196			if newline_count > (config.max_blank_lines as usize + 1) {
197				let mut collapsed = String::new();
198				let mut seen_newlines = 0u32;
199				for ch in gap.chars() {
200					if ch == '\n' {
201						seen_newlines += 1;
202						if seen_newlines <= config.max_blank_lines + 1 {
203							collapsed.push(ch);
204						}
205					} else if seen_newlines <= config.max_blank_lines + 1 {
206						collapsed.push(ch);
207					}
208				}
209				result.push_str(&collapsed);
210				result.push_str(&formatted_token);
211				last_end = end;
212				continue;
213			}
214		}
215
216		result.push_str(gap);
217		result.push_str(&formatted_token);
218		last_end = end;
219	}
220
221	if last_end < source.len() {
222		let trailing = &source[last_end..];
223
224		if config.trailing_semicolon {
225			let trimmed = trailing.trim_end();
226			if !trimmed.is_empty() && !trimmed.ends_with(';') {
227				result.push_str(trimmed);
228				result.push(';');
229				let ws_start = trimmed.len();
230				if ws_start < trailing.len() {
231					result.push_str(&trailing[ws_start..]);
232				}
233			} else {
234				result.push_str(trailing);
235			}
236		} else {
237			result.push_str(trailing);
238		}
239	} else if config.trailing_semicolon && !result.is_empty() && !result.trim_end().ends_with(';') {
240		result.push(';');
241	}
242
243	if result == source { None } else { Some(result) }
244}
245
/// Reports whether `upper` is one of the keywords the formatter rewrites.
///
/// The comparison is exact and case-sensitive, so callers must pass the
/// already upper-cased spelling.
fn is_formattable_keyword(upper: &str) -> bool {
	const FORMATTABLE_KEYWORDS: &[&str] = &[
		"SELECT", "FROM", "WHERE", "AND", "OR", "NOT", "IN", "ORDER", "BY", "GROUP",
		"LIMIT", "OFFSET", "FETCH", "CREATE", "UPDATE", "DELETE", "INSERT", "INTO",
		"SET", "UPSERT", "MERGE", "CONTENT", "RETURN", "DEFINE", "REMOVE", "TABLE",
		"FIELD", "INDEX", "EVENT", "FUNCTION", "PARAM", "NAMESPACE", "DATABASE",
		"ANALYZER", "ACCESS", "SCHEMAFULL", "SCHEMALESS", "TYPE", "DEFAULT",
		"READONLY", "FLEXIBLE", "UNIQUE", "FIELDS", "ON", "WHEN", "THEN", "ELSE",
		"END", "IF", "FOR", "LET", "BEGIN", "COMMIT", "CANCEL", "TRANSACTION", "USE",
		"NS", "DB", "AS", "IS", "LIKE", "CONTAINS", "CONTAINSALL", "CONTAINSANY",
		"CONTAINSNONE", "INSIDE", "OUTSIDE", "INTERSECTS", "ALLINSIDE", "ANYINSIDE",
		"NONEINSIDE", "ASC", "DESC", "COLLATE", "NUMERIC", "COMMENT", "PERMISSIONS",
		"FULL", "NONE", "RELATE", "ONLY", "VALUE", "VALUES", "OVERWRITE", "EXISTS",
		"ASSERT", "ENFORCED", "DROP", "CHANGEFEED", "INCLUDE", "ORIGINAL", "LIVE",
		"DIFF", "KILL", "SHOW", "INFO", "SLEEP", "THROW", "BREAK", "CONTINUE",
		"PARALLEL", "TIMEOUT", "EXPLAIN", "SPLIT", "AT", "TOKENIZERS", "FILTERS",
		"WITH", "NOINDEX", "UNIQ", "SEARCH",
	];

	FORMATTABLE_KEYWORDS.iter().any(|keyword| *keyword == upper)
}
331
/// Behavioural tests for `format_source`; each test enables only the options
/// it exercises on top of `FormatConfig::default()`.
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn should_uppercase_keywords_with_default_config() {
		let source = "select * from user where age > 18;";
		let formatted = format_source(source, &FormatConfig::default()).unwrap();
		assert_eq!(formatted, "SELECT * FROM user WHERE age > 18;");
	}

	#[test]
	fn should_preserve_comments() {
		let source = "-- this is a comment\nselect * from user;";
		let formatted = format_source(source, &FormatConfig::default()).unwrap();
		assert!(formatted.contains("-- this is a comment"));
		assert!(formatted.contains("SELECT"));
	}

	#[test]
	fn should_preserve_strings() {
		let source = "select * from user where name = 'select from where';";
		let formatted = format_source(source, &FormatConfig::default()).unwrap();
		assert!(formatted.contains("'select from where'"));
	}

	#[test]
	fn should_return_none_when_already_formatted() {
		let source = "SELECT * FROM user WHERE age > 18;";
		let result = format_source(source, &FormatConfig::default());
		assert!(result.is_none(), "already formatted should return None");
	}

	#[test]
	fn should_add_newline_after_semicolon() {
		let config = FormatConfig {
			newline_after_semicolon: true,
			..Default::default()
		};
		let source = "DEFINE TABLE user; DEFINE TABLE post;";
		let formatted = format_source(source, &config).unwrap();
		assert!(
			formatted.contains("user;\n"),
			"should have newline after first semicolon: {formatted}"
		);
	}

	#[test]
	fn should_not_double_newline_after_semicolon() {
		let config = FormatConfig {
			newline_after_semicolon: true,
			..Default::default()
		};
		let source = "DEFINE TABLE user;\nDEFINE TABLE post;";
		let result = format_source(source, &config);
		assert!(result.is_none(), "already has newlines, should return None");
	}

	#[test]
	fn should_add_newline_before_where() {
		let config = FormatConfig {
			newline_before_where: true,
			..Default::default()
		};
		let source = "SELECT * FROM user WHERE age > 18;";
		let formatted = format_source(source, &config).unwrap();
		assert_eq!(
			formatted, "SELECT * FROM user\n\tWHERE age > 18;",
			"should have newline+indent before WHERE: {formatted}"
		);
	}

	#[test]
	fn should_collapse_blank_lines() {
		let config = FormatConfig {
			collapse_blank_lines: true,
			max_blank_lines: 1,
			..Default::default()
		};
		let source = "SELECT * FROM user;\n\n\n\n\nSELECT * FROM post;";
		let formatted = format_source(source, &config).unwrap();
		// Exactly one blank line must remain between the two statements.
		assert_eq!(
			formatted, "SELECT * FROM user;\n\nSELECT * FROM post;",
			"should collapse to max 1 blank line: {formatted}"
		);
	}

	#[test]
	fn should_add_trailing_semicolon() {
		let config = FormatConfig {
			trailing_semicolon: true,
			uppercase_keywords: false,
			..Default::default()
		};
		let source = "SELECT * FROM user";
		let formatted = format_source(source, &config).unwrap();
		assert_eq!(
			formatted, "SELECT * FROM user;",
			"should have trailing semicolon: {formatted}"
		);
	}

	#[test]
	fn should_apply_custom_config() {
		let config = FormatConfig {
			uppercase_keywords: true,
			indent_style: IndentStyle::Space,
			indent_width: 2,
			newline_after_semicolon: true,
			newline_before_where: true,
			..Default::default()
		};
		// Run the formatter so the combined options are really exercised,
		// rather than only reading back the fields that were just set.
		let source = "select * from user where age > 18; select * from post;";
		let formatted = format_source(source, &config).unwrap();
		assert!(
			formatted.contains("SELECT * FROM user\n  WHERE age > 18;"),
			"WHERE should be on its own line, indented by two spaces: {formatted}"
		);
		assert!(
			formatted.contains("18;\n"),
			"should have newline after first semicolon: {formatted}"
		);
		assert!(
			formatted.ends_with("SELECT * FROM post;"),
			"second statement should be uppercased: {formatted}"
		);
	}

	#[test]
	fn should_add_newline_before_set() {
		let config = FormatConfig {
			newline_before_set: true,
			..Default::default()
		};
		let source = "UPDATE user SET name = 'Alice';";
		let formatted = format_source(source, &config).unwrap();
		assert_eq!(
			formatted, "UPDATE user\n\tSET name = 'Alice';",
			"should have newline+indent before SET: {formatted}"
		);
	}
}
466}