Skip to main content

sql_orm_sqlserver/
quoting.rs

1use sql_orm_core::OrmError;
2use sql_orm_query::{ColumnRef, TableRef};
3
4pub fn quote_identifier(identifier: &str) -> Result<String, OrmError> {
5    validate_identifier(identifier)?;
6
7    let escaped = identifier.replace(']', "]]");
8    Ok(format!("[{escaped}]"))
9}
10
11pub fn quote_qualified_identifier(schema: &str, identifier: &str) -> Result<String, OrmError> {
12    Ok(format!(
13        "{}.{}",
14        quote_identifier(schema)?,
15        quote_identifier(identifier)?,
16    ))
17}
18
19pub fn quote_table_ref(table: &TableRef) -> Result<String, OrmError> {
20    quote_qualified_identifier(table.schema, table.table)
21}
22
23pub fn quote_table_source(table: &TableRef) -> Result<String, OrmError> {
24    let source = quote_table_ref(table)?;
25
26    match table.alias {
27        Some(alias) => Ok(format!("{source} AS {}", quote_identifier(alias)?)),
28        None => Ok(source),
29    }
30}
31
32pub fn quote_table_reference(table: &TableRef) -> Result<String, OrmError> {
33    match table.alias {
34        Some(alias) => quote_identifier(alias),
35        None => quote_table_ref(table),
36    }
37}
38
39pub fn quote_column_ref(column: &ColumnRef) -> Result<String, OrmError> {
40    Ok(format!(
41        "{}.{}",
42        quote_table_reference(&column.table)?,
43        quote_identifier(column.column_name)?,
44    ))
45}
46
47fn validate_identifier(identifier: &str) -> Result<(), OrmError> {
48    if identifier.is_empty() {
49        return Err(OrmError::compile("SQL Server identifier cannot be empty"));
50    }
51
52    if identifier.contains('.') {
53        return Err(OrmError::compile(
54            "SQL Server identifier cannot contain '.'; quote each part separately",
55        ));
56    }
57
58    if identifier.chars().any(|ch| ch.is_control()) {
59        return Err(OrmError::compile(
60            "SQL Server identifier cannot contain control characters",
61        ));
62    }
63
64    Ok(())
65}
66
#[cfg(test)]
mod tests {
    use super::{
        quote_column_ref, quote_identifier, quote_qualified_identifier, quote_table_ref,
        quote_table_reference, quote_table_source,
    };
    use sql_orm_core::OrmErrorKind;
    use sql_orm_query::{ColumnRef, TableRef};

    #[test]
    fn quotes_simple_identifier_with_brackets() {
        assert_eq!(quote_identifier("customers").unwrap(), "[customers]");
    }

    #[test]
    fn escapes_closing_brackets_inside_identifier() {
        assert_eq!(
            quote_identifier("report]archive").unwrap(),
            "[report]]archive]"
        );
    }

    #[test]
    fn rejects_empty_identifier() {
        let error = quote_identifier("").unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert_eq!(error.message(), "SQL Server identifier cannot be empty");
    }

    #[test]
    fn rejects_control_characters() {
        let error = quote_identifier("line\nbreak").unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert_eq!(
            error.message(),
            "SQL Server identifier cannot contain control characters"
        );
    }

    #[test]
    fn rejects_multipart_identifier_in_single_segment_api() {
        let error = quote_identifier("dbo.customers").unwrap_err();

        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert_eq!(
            error.message(),
            "SQL Server identifier cannot contain '.'; quote each part separately"
        );
    }

    #[test]
    fn quotes_schema_qualified_identifier() {
        assert_eq!(
            quote_qualified_identifier("sales", "customers").unwrap(),
            "[sales].[customers]"
        );
    }

    #[test]
    fn rejects_invalid_part_of_qualified_identifier() {
        let error = quote_qualified_identifier("", "customers").unwrap_err();
        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert_eq!(error.message(), "SQL Server identifier cannot be empty");

        let error = quote_qualified_identifier("sales", "dbo.customers").unwrap_err();
        assert_eq!(error.kind(), OrmErrorKind::Compile);
        assert_eq!(
            error.message(),
            "SQL Server identifier cannot contain '.'; quote each part separately"
        );
    }

    #[test]
    fn quotes_table_and_column_refs_from_ast() {
        let table = TableRef::new("sales", "customers");
        let column = ColumnRef::new(table, "email", "email");

        assert_eq!(quote_table_ref(&table).unwrap(), "[sales].[customers]");
        assert_eq!(
            quote_column_ref(&column).unwrap(),
            "[sales].[customers].[email]"
        );
    }

    #[test]
    fn unaliased_table_source_and_reference_use_qualified_name() {
        let table = TableRef::new("sales", "customers");

        assert_eq!(quote_table_source(&table).unwrap(), "[sales].[customers]");
        assert_eq!(
            quote_table_reference(&table).unwrap(),
            "[sales].[customers]"
        );
    }

    #[test]
    fn quotes_aliased_table_sources_and_column_refs_from_ast() {
        let table = TableRef::with_alias("sales", "customers", "c");
        let column = ColumnRef::new(table, "email", "email");

        assert_eq!(
            quote_table_source(&table).unwrap(),
            "[sales].[customers] AS [c]"
        );
        assert_eq!(quote_table_reference(&table).unwrap(), "[c]");
        assert_eq!(quote_column_ref(&column).unwrap(), "[c].[email]");
    }

    #[test]
    fn escapes_closing_brackets_inside_alias() {
        let table = TableRef::with_alias("sales", "customers", "c]1");

        assert_eq!(
            quote_table_source(&table).unwrap(),
            "[sales].[customers] AS [c]]1]"
        );
        assert_eq!(quote_table_reference(&table).unwrap(), "[c]]1]");
    }
}