Skip to main content

sql_orm_sqlserver/
quoting.rs

1use sql_orm_core::OrmError;
2use sql_orm_query::{ColumnRef, TableRef};
3
4pub fn quote_identifier(identifier: &str) -> Result<String, OrmError> {
5    validate_identifier(identifier)?;
6
7    let escaped = identifier.replace(']', "]]");
8    Ok(format!("[{escaped}]"))
9}
10
11pub fn quote_qualified_identifier(schema: &str, identifier: &str) -> Result<String, OrmError> {
12    Ok(format!(
13        "{}.{}",
14        quote_identifier(schema)?,
15        quote_identifier(identifier)?,
16    ))
17}
18
19pub fn quote_table_ref(table: &TableRef) -> Result<String, OrmError> {
20    quote_qualified_identifier(table.schema, table.table)
21}
22
23pub fn quote_table_source(table: &TableRef) -> Result<String, OrmError> {
24    let source = quote_table_ref(table)?;
25
26    match table.alias {
27        Some(alias) => Ok(format!("{source} AS {}", quote_identifier(alias)?)),
28        None => Ok(source),
29    }
30}
31
32pub fn quote_table_reference(table: &TableRef) -> Result<String, OrmError> {
33    match table.alias {
34        Some(alias) => quote_identifier(alias),
35        None => quote_table_ref(table),
36    }
37}
38
39pub fn quote_column_ref(column: &ColumnRef) -> Result<String, OrmError> {
40    Ok(format!(
41        "{}.{}",
42        quote_table_reference(&column.table)?,
43        quote_identifier(column.column_name)?,
44    ))
45}
46
47fn validate_identifier(identifier: &str) -> Result<(), OrmError> {
48    if identifier.is_empty() {
49        return Err(OrmError::new("SQL Server identifier cannot be empty"));
50    }
51
52    if identifier.contains('.') {
53        return Err(OrmError::new(
54            "SQL Server identifier cannot contain '.'; quote each part separately",
55        ));
56    }
57
58    if identifier.chars().any(|ch| ch.is_control()) {
59        return Err(OrmError::new(
60            "SQL Server identifier cannot contain control characters",
61        ));
62    }
63
64    Ok(())
65}
66
#[cfg(test)]
mod tests {
    use super::{
        quote_column_ref, quote_identifier, quote_qualified_identifier, quote_table_ref,
        quote_table_reference, quote_table_source,
    };
    use sql_orm_query::{ColumnRef, TableRef};

    /// Quotes `identifier`, expects it to be rejected, and checks the exact
    /// error message.
    fn assert_identifier_rejected(identifier: &str, expected_message: &str) {
        let error = quote_identifier(identifier).unwrap_err();
        assert_eq!(error.message(), expected_message);
    }

    #[test]
    fn quotes_simple_identifier_with_brackets() {
        let quoted = quote_identifier("customers").unwrap();
        assert_eq!(quoted, "[customers]");
    }

    #[test]
    fn escapes_closing_brackets_inside_identifier() {
        let quoted = quote_identifier("report]archive").unwrap();
        assert_eq!(quoted, "[report]]archive]");
    }

    #[test]
    fn rejects_empty_identifier() {
        assert_identifier_rejected("", "SQL Server identifier cannot be empty");
    }

    #[test]
    fn rejects_control_characters() {
        assert_identifier_rejected(
            "line\nbreak",
            "SQL Server identifier cannot contain control characters",
        );
    }

    #[test]
    fn rejects_multipart_identifier_in_single_segment_api() {
        assert_identifier_rejected(
            "dbo.customers",
            "SQL Server identifier cannot contain '.'; quote each part separately",
        );
    }

    #[test]
    fn quotes_schema_qualified_identifier() {
        let quoted = quote_qualified_identifier("sales", "customers").unwrap();
        assert_eq!(quoted, "[sales].[customers]");
    }

    #[test]
    fn quotes_table_and_column_refs_from_ast() {
        let table = TableRef::new("sales", "customers");
        let column = ColumnRef::new(table, "email", "email");

        let quoted_table = quote_table_ref(&table).unwrap();
        assert_eq!(quoted_table, "[sales].[customers]");

        let quoted_column = quote_column_ref(&column).unwrap();
        assert_eq!(quoted_column, "[sales].[customers].[email]");
    }

    #[test]
    fn quotes_aliased_table_sources_and_column_refs_from_ast() {
        let table = TableRef::with_alias("sales", "customers", "c");
        let column = ColumnRef::new(table, "email", "email");

        let source = quote_table_source(&table).unwrap();
        assert_eq!(source, "[sales].[customers] AS [c]");

        let reference = quote_table_reference(&table).unwrap();
        assert_eq!(reference, "[c]");

        let quoted_column = quote_column_ref(&column).unwrap();
        assert_eq!(quoted_column, "[c].[email]");
    }
}