Skip to main content

surql_parser/upstream/fmt/
escape.rs

1use std::fmt::{Display, Write};
2use surrealdb_types::{SqlFormat, ToSql};
/// A `fmt::Write` adapter that backslash-escapes text before forwarding it to
/// an inner writer.
///
/// It escapes NUL, CR, tab, LF, backspace (`\x08`), form feed (`\x0C`) and
/// backslash with fixed sequences. The configurable `escape_char` (typically
/// the surrounding quote character) is prefixed with a backslash.
pub struct EscapeWriter<W> {
	escape_char: char,
	writer: W,
}
impl<'a> EscapeWriter<&'a mut String> {
	/// Appends the escaped `Display` output of `display` to `into`, treating
	/// `escape` as the character that must additionally be backslash-escaped.
	fn escape<D: Display + ?Sized>(into: &'a mut String, escape: char, display: &D) {
		let mut adapter = Self {
			writer: into,
			escape_char: escape,
		};
		adapter.write(display);
	}
	/// Formats `display` through the escaping adapter.
	fn write<D: Display + ?Sized>(&mut self, display: &D) {
		// The sink is a `String`, so the `fmt::Result` is intentionally discarded.
		let _ = write!(self, "{display}");
	}
}
impl<W: Write> Write for EscapeWriter<W> {
	fn write_str(&mut self, s: &str) -> std::fmt::Result {
		s.chars().try_for_each(|ch| self.write_char(ch))
	}
	fn write_char(&mut self, c: char) -> std::fmt::Result {
		// The fixed mappings are checked before `escape_char`, so they win if
		// `escape_char` happens to be one of these characters.
		let replacement = match c {
			'\0' => "\\0",
			'\r' => "\\r",
			'\t' => "\\t",
			'\n' => "\\n",
			// Backspace is emitted as a `\u{8}` escape rather than a short form.
			'\x08' => "\\u{8}",
			'\x0C' => "\\f",
			'\\' => "\\\\",
			quote if quote == self.escape_char => {
				self.writer.write_char('\\')?;
				return self.writer.write_char(quote);
			}
			plain => return self.writer.write_char(plain),
		};
		self.writer.write_str(replacement)
	}
}
58pub struct QuoteStr<'a>(pub &'a str);
59impl ToSql for QuoteStr<'_> {
60	fn fmt_sql(&self, f: &mut String, _: SqlFormat) {
61		let s = self.0;
62		let quote = if s.contains('\'') { '\"' } else { '\'' };
63		f.push(quote);
64		EscapeWriter::escape(f, quote, self.0);
65		f.push(quote);
66	}
67}
68/// Escapes identifiers which might be used in the same place as a keyword.
69pub struct EscapeIdent<T>(pub T);
70impl<T: AsRef<str>> ToSql for EscapeIdent<T> {
71	fn fmt_sql(&self, f: &mut String, fmt: SqlFormat) {
72		let s = self.0.as_ref();
73		if crate::upstream::syn::could_be_reserved_keyword(s) {
74			f.push('`');
75			EscapeWriter::escape(f, '`', self.0.as_ref());
76			f.push('`');
77		} else {
78			EscapeKwFreeIdent(s).fmt_sql(f, fmt);
79		}
80	}
81}
82pub struct EscapeKwIdent<'a>(pub &'a str, pub &'a [&'static str]);
83impl ToSql for EscapeKwIdent<'_> {
84	fn fmt_sql(&self, f: &mut String, fmt: SqlFormat) {
85		if self.1.iter().any(|x| x.eq_ignore_ascii_case(self.0)) {
86			f.push('`');
87			EscapeWriter::escape(f, '`', self.0);
88			f.push('`');
89		} else {
90			EscapeKwFreeIdent(self.0).fmt_sql(f, fmt);
91		}
92	}
93}
94/// Escapes identifiers which can never be used in the same place as a keyword.
95///
96/// Examples of this is a Param as '$' is in front of the identifier so it
97/// cannot be an
98pub struct EscapeKwFreeIdent<'a>(pub &'a str);
99impl ToSql for EscapeKwFreeIdent<'_> {
100	fn fmt_sql(&self, f: &mut String, _: SqlFormat) {
101		let s = self.0;
102		if s.is_empty()
103			|| s.starts_with(|x: char| x.is_ascii_digit())
104			|| s.contains(|x: char| !x.is_ascii_alphanumeric() && x != '_')
105			|| s == "NaN"
106			|| s == "Infinity"
107		{
108			f.push('`');
109			EscapeWriter::escape(f, '`', self.0);
110			f.push('`');
111		} else {
112			f.push_str(s);
113		}
114	}
115}
116pub struct EscapeObjectKey<'a>(pub &'a str);
117impl ToSql for EscapeObjectKey<'_> {
118	fn fmt_sql(&self, f: &mut String, _: SqlFormat) {
119		let s = self.0;
120		if s.is_empty()
121			|| s.starts_with(|x: char| x.is_ascii_digit())
122			|| s.contains(|x: char| !x.is_ascii_alphanumeric() && x != '_')
123			|| s == "NaN"
124			|| s == "Infinity"
125		{
126			f.push('"');
127			EscapeWriter::escape(f, '"', self.0);
128			f.push('"');
129		} else {
130			f.push_str(s);
131		}
132	}
133}
134pub struct EscapeRidKey<'a>(pub &'a str);
135impl ToSql for EscapeRidKey<'_> {
136	fn fmt_sql(&self, f: &mut String, _: SqlFormat) {
137		let s = self.0;
138		if s.is_empty()
139			|| s.contains(|x: char| !x.is_ascii_alphanumeric() && x != '_')
140			|| !s.contains(|x: char| !x.is_ascii_digit() && x != '_')
141			|| s == "Infinity"
142			|| s == "NaN"
143		{
144			f.push('`');
145			EscapeWriter::escape(f, '`', self.0);
146			f.push('`');
147		} else {
148			f.push_str(s)
149		}
150	}
151}