1use crate::AliasedEntityColumn;
2use sql_orm_core::{Entity, EntityColumn, SqlTypeMapping, SqlValue};
3use sql_orm_query::{Expr, Predicate};
4
/// Escape character used for every `LIKE` predicate built in this module.
///
/// It is passed as the `ESCAPE` character to `Predicate::like_escaped`, and
/// `escape_like_literal` prefixes LIKE metacharacters (and itself) with it so
/// user-supplied text is matched literally.
const LIKE_ESCAPE_CHAR: char = '\\';
6
/// Predicate-building methods for entity columns.
///
/// Every method consumes the column and returns a [`Predicate`] that either
/// compares it with a bound value, checks it for NULL, or matches it against an
/// escaped `LIKE` pattern. `E` is the entity type the column belongs to.
///
/// NOTE(review): if `value` converts to a SQL NULL, `eq`/`ne` presumably render
/// `= NULL` / `<> NULL`, which never match under SQL three-valued logic — use
/// `is_null` / `is_not_null` for NULL checks; confirm against the renderer.
// `is_null`/`is_not_null` take `self` by value like the other builder methods;
// clippy's `wrong_self_convention` expects `is_*` methods to borrow `self`.
#[allow(clippy::wrong_self_convention)]
pub trait EntityColumnPredicateExt<E: Entity> {
    /// Builds an equality comparison (`Predicate::eq`) between the column and
    /// `value`, converted via `SqlTypeMapping::to_sql_value`.
    fn eq<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds an inequality comparison (`Predicate::ne`) between the column
    /// and `value`.
    fn ne<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds a greater-than comparison (`Predicate::gt`): column > `value`.
    fn gt<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds a greater-than-or-equal comparison (`Predicate::gte`).
    fn gte<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds a less-than comparison (`Predicate::lt`): column < `value`.
    fn lt<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds a less-than-or-equal comparison (`Predicate::lte`).
    fn lte<V>(self, value: V) -> Predicate
    where
        V: SqlTypeMapping;

    /// Builds a NULL check on the column (`Predicate::is_null`).
    fn is_null(self) -> Predicate;

    /// Builds a NOT NULL check on the column (`Predicate::is_not_null`).
    fn is_not_null(self) -> Predicate;

    /// Matches when the column contains `value` as a literal substring.
    ///
    /// LIKE metacharacters in `value` (`%`, `_`, `[`, `]` and the escape
    /// character) are escaped, and the resulting pattern is `%value%`.
    fn contains(self, value: impl Into<String>) -> Predicate;

    /// Matches when the column starts with the literal `value` (pattern
    /// `value%`, with LIKE metacharacters in `value` escaped).
    fn starts_with(self, value: impl Into<String>) -> Predicate;

    /// Matches when the column ends with the literal `value` (pattern
    /// `%value`, with LIKE metacharacters in `value` escaped).
    fn ends_with(self, value: impl Into<String>) -> Predicate;
}
44
45impl<E: Entity> EntityColumnPredicateExt<E> for EntityColumn<E> {
46 fn eq<V>(self, value: V) -> Predicate
47 where
48 V: SqlTypeMapping,
49 {
50 Predicate::eq(Expr::from(self), Expr::value(value.to_sql_value()))
51 }
52
53 fn ne<V>(self, value: V) -> Predicate
54 where
55 V: SqlTypeMapping,
56 {
57 Predicate::ne(Expr::from(self), Expr::value(value.to_sql_value()))
58 }
59
60 fn gt<V>(self, value: V) -> Predicate
61 where
62 V: SqlTypeMapping,
63 {
64 Predicate::gt(Expr::from(self), Expr::value(value.to_sql_value()))
65 }
66
67 fn gte<V>(self, value: V) -> Predicate
68 where
69 V: SqlTypeMapping,
70 {
71 Predicate::gte(Expr::from(self), Expr::value(value.to_sql_value()))
72 }
73
74 fn lt<V>(self, value: V) -> Predicate
75 where
76 V: SqlTypeMapping,
77 {
78 Predicate::lt(Expr::from(self), Expr::value(value.to_sql_value()))
79 }
80
81 fn lte<V>(self, value: V) -> Predicate
82 where
83 V: SqlTypeMapping,
84 {
85 Predicate::lte(Expr::from(self), Expr::value(value.to_sql_value()))
86 }
87
88 fn is_null(self) -> Predicate {
89 Predicate::is_null(Expr::from(self))
90 }
91
92 fn is_not_null(self) -> Predicate {
93 Predicate::is_not_null(Expr::from(self))
94 }
95
96 fn contains(self, value: impl Into<String>) -> Predicate {
97 Predicate::like_escaped(
98 Expr::from(self),
99 Expr::value(SqlValue::String(format!(
100 "%{}%",
101 escape_like_literal(value.into())
102 ))),
103 LIKE_ESCAPE_CHAR,
104 )
105 }
106
107 fn starts_with(self, value: impl Into<String>) -> Predicate {
108 Predicate::like_escaped(
109 Expr::from(self),
110 Expr::value(SqlValue::String(format!(
111 "{}%",
112 escape_like_literal(value.into())
113 ))),
114 LIKE_ESCAPE_CHAR,
115 )
116 }
117
118 fn ends_with(self, value: impl Into<String>) -> Predicate {
119 Predicate::like_escaped(
120 Expr::from(self),
121 Expr::value(SqlValue::String(format!(
122 "%{}",
123 escape_like_literal(value.into())
124 ))),
125 LIKE_ESCAPE_CHAR,
126 )
127 }
128}
129
130impl<E: Entity> EntityColumnPredicateExt<E> for AliasedEntityColumn<E> {
131 fn eq<V>(self, value: V) -> Predicate
132 where
133 V: SqlTypeMapping,
134 {
135 Predicate::eq(Expr::from(self), Expr::value(value.to_sql_value()))
136 }
137
138 fn ne<V>(self, value: V) -> Predicate
139 where
140 V: SqlTypeMapping,
141 {
142 Predicate::ne(Expr::from(self), Expr::value(value.to_sql_value()))
143 }
144
145 fn gt<V>(self, value: V) -> Predicate
146 where
147 V: SqlTypeMapping,
148 {
149 Predicate::gt(Expr::from(self), Expr::value(value.to_sql_value()))
150 }
151
152 fn gte<V>(self, value: V) -> Predicate
153 where
154 V: SqlTypeMapping,
155 {
156 Predicate::gte(Expr::from(self), Expr::value(value.to_sql_value()))
157 }
158
159 fn lt<V>(self, value: V) -> Predicate
160 where
161 V: SqlTypeMapping,
162 {
163 Predicate::lt(Expr::from(self), Expr::value(value.to_sql_value()))
164 }
165
166 fn lte<V>(self, value: V) -> Predicate
167 where
168 V: SqlTypeMapping,
169 {
170 Predicate::lte(Expr::from(self), Expr::value(value.to_sql_value()))
171 }
172
173 fn is_null(self) -> Predicate {
174 Predicate::is_null(Expr::from(self))
175 }
176
177 fn is_not_null(self) -> Predicate {
178 Predicate::is_not_null(Expr::from(self))
179 }
180
181 fn contains(self, value: impl Into<String>) -> Predicate {
182 Predicate::like_escaped(
183 Expr::from(self),
184 Expr::value(SqlValue::String(format!(
185 "%{}%",
186 escape_like_literal(value.into())
187 ))),
188 LIKE_ESCAPE_CHAR,
189 )
190 }
191
192 fn starts_with(self, value: impl Into<String>) -> Predicate {
193 Predicate::like_escaped(
194 Expr::from(self),
195 Expr::value(SqlValue::String(format!(
196 "{}%",
197 escape_like_literal(value.into())
198 ))),
199 LIKE_ESCAPE_CHAR,
200 )
201 }
202
203 fn ends_with(self, value: impl Into<String>) -> Predicate {
204 Predicate::like_escaped(
205 Expr::from(self),
206 Expr::value(SqlValue::String(format!(
207 "%{}",
208 escape_like_literal(value.into())
209 ))),
210 LIKE_ESCAPE_CHAR,
211 )
212 }
213}
214
215fn escape_like_literal(value: impl AsRef<str>) -> String {
216 let value = value.as_ref();
217 let mut escaped = String::with_capacity(value.len());
218
219 for ch in value.chars() {
220 if matches!(ch, LIKE_ESCAPE_CHAR | '%' | '_' | '[' | ']') {
221 escaped.push(LIKE_ESCAPE_CHAR);
222 }
223 escaped.push(ch);
224 }
225
226 escaped
227}
228
#[cfg(test)]
mod tests {
    //! Unit tests for the column predicate extension methods, exercising both
    //! plain and aliased columns against a minimal hand-written entity.

    use super::{EntityColumnPredicateExt, LIKE_ESCAPE_CHAR};
    use crate::EntityColumnAliasExt;
    use sql_orm_core::{
        ColumnMetadata, Entity, EntityColumn, EntityMetadata, PrimaryKeyMetadata, SqlServerType,
        SqlValue,
    };
    use sql_orm_query::{ColumnRef, Expr, Predicate, TableRef};

    struct TestEntity;

    static TEST_ENTITY_COLUMNS: [ColumnMetadata; 2] = [
        ColumnMetadata {
            rust_field: "id",
            column_name: "id",
            renamed_from: None,
            sql_type: SqlServerType::BigInt,
            nullable: false,
            primary_key: true,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: false,
            updatable: false,
            max_length: None,
            precision: None,
            scale: None,
        },
        ColumnMetadata {
            rust_field: "name",
            column_name: "name",
            renamed_from: None,
            sql_type: SqlServerType::NVarChar,
            nullable: true,
            primary_key: false,
            identity: None,
            default_sql: None,
            computed_sql: None,
            rowversion: false,
            insertable: true,
            updatable: true,
            max_length: Some(120),
            precision: None,
            scale: None,
        },
    ];

    static TEST_ENTITY_METADATA: EntityMetadata = EntityMetadata {
        rust_name: "TestEntity",
        schema: "dbo",
        table: "test_entities",
        renamed_from: None,
        columns: &TEST_ENTITY_COLUMNS,
        primary_key: PrimaryKeyMetadata {
            name: None,
            columns: &["id"],
        },
        indexes: &[],
        foreign_keys: &[],
        navigations: &[],
    };

    impl Entity for TestEntity {
        fn metadata() -> &'static EntityMetadata {
            &TEST_ENTITY_METADATA
        }
    }

    // Lowercase associated consts let tests write `TestEntity::id`, presumably
    // mirroring the column accessors generated for real entities.
    #[allow(non_upper_case_globals)]
    impl TestEntity {
        const id: EntityColumn<TestEntity> = EntityColumn::new("id", "id");
        const name: EntityColumn<TestEntity> = EntityColumn::new("name", "name");
    }

    /// Expected expression for column `name` of `dbo.test_entities` (no alias).
    fn column(name: &'static str) -> Expr {
        Expr::Column(ColumnRef::new(
            TableRef::new("dbo", "test_entities"),
            name,
            name,
        ))
    }

    /// Expected expression for column `name` of `dbo.test_entities` aliased as `alias`.
    fn aliased_column(alias: &'static str, name: &'static str) -> Expr {
        Expr::Column(ColumnRef::new(
            TableRef::with_alias("dbo", "test_entities", alias),
            name,
            name,
        ))
    }

    /// Expected escaped-LIKE predicate of `column` against an already-built `pattern`.
    fn like(column: Expr, pattern: &str) -> Predicate {
        Predicate::like_escaped(
            column,
            Expr::Value(SqlValue::String(pattern.to_string())),
            LIKE_ESCAPE_CHAR,
        )
    }

    #[test]
    fn comparison_methods_build_expected_predicates() {
        let expected_column = column("id");

        assert_eq!(
            TestEntity::id.eq(7_i64),
            Predicate::eq(expected_column.clone(), Expr::Value(SqlValue::I64(7)))
        );
        assert_eq!(
            TestEntity::id.ne(8_i64),
            Predicate::ne(expected_column.clone(), Expr::Value(SqlValue::I64(8)))
        );
        assert_eq!(
            TestEntity::id.gt(9_i64),
            Predicate::gt(expected_column.clone(), Expr::Value(SqlValue::I64(9)))
        );
        assert_eq!(
            TestEntity::id.gte(10_i64),
            Predicate::gte(expected_column.clone(), Expr::Value(SqlValue::I64(10)))
        );
        assert_eq!(
            TestEntity::id.lt(11_i64),
            Predicate::lt(expected_column.clone(), Expr::Value(SqlValue::I64(11)))
        );
        assert_eq!(
            TestEntity::id.lte(12_i64),
            Predicate::lte(expected_column, Expr::Value(SqlValue::I64(12)))
        );
    }

    #[test]
    fn null_predicate_methods_build_expected_predicates() {
        let expected_column = column("name");

        assert_eq!(
            TestEntity::name.is_null(),
            Predicate::is_null(expected_column.clone())
        );
        assert_eq!(
            TestEntity::name.is_not_null(),
            Predicate::is_not_null(expected_column)
        );
    }

    #[test]
    fn string_predicate_methods_build_expected_like_patterns() {
        assert_eq!(
            TestEntity::name.contains("ana"),
            like(column("name"), "%ana%")
        );
        assert_eq!(
            TestEntity::name.starts_with("ana"),
            like(column("name"), "ana%")
        );
        assert_eq!(
            TestEntity::name.ends_with("ana"),
            like(column("name"), "%ana")
        );
    }

    #[test]
    fn string_predicate_methods_handle_empty_input() {
        assert_eq!(TestEntity::name.contains(""), like(column("name"), "%%"));
        assert_eq!(TestEntity::name.starts_with(""), like(column("name"), "%"));
        assert_eq!(TestEntity::name.ends_with(""), like(column("name"), "%"));
    }

    #[test]
    fn string_predicate_methods_escape_like_wildcards_and_ranges() {
        assert_eq!(
            TestEntity::name.contains(r"a%_b[c]\d"),
            like(column("name"), r"%a\%\_b\[c\]\\d%")
        );
        assert_eq!(
            TestEntity::name.starts_with("50%_off"),
            like(column("name"), r"50\%\_off%")
        );
        assert_eq!(
            TestEntity::name.ends_with("[x]\\"),
            like(column("name"), r"%\[x\]\\")
        );
    }

    #[test]
    fn aliased_columns_build_predicates_against_table_alias() {
        let expected_column = aliased_column("t", "name");

        assert_eq!(
            TestEntity::name.aliased("t").contains("ana"),
            like(expected_column.clone(), "%ana%")
        );
        assert_eq!(
            TestEntity::name.aliased("t").is_not_null(),
            Predicate::is_not_null(expected_column)
        );
    }

    #[test]
    fn aliased_columns_build_comparison_and_null_predicates() {
        let expected_id = aliased_column("t", "id");
        let expected_name = aliased_column("t", "name");

        assert_eq!(
            TestEntity::id.aliased("t").eq(7_i64),
            Predicate::eq(expected_id.clone(), Expr::Value(SqlValue::I64(7)))
        );
        assert_eq!(
            TestEntity::id.aliased("t").ne(8_i64),
            Predicate::ne(expected_id.clone(), Expr::Value(SqlValue::I64(8)))
        );
        assert_eq!(
            TestEntity::id.aliased("t").gt(9_i64),
            Predicate::gt(expected_id.clone(), Expr::Value(SqlValue::I64(9)))
        );
        assert_eq!(
            TestEntity::id.aliased("t").gte(10_i64),
            Predicate::gte(expected_id.clone(), Expr::Value(SqlValue::I64(10)))
        );
        assert_eq!(
            TestEntity::id.aliased("t").lt(11_i64),
            Predicate::lt(expected_id.clone(), Expr::Value(SqlValue::I64(11)))
        );
        assert_eq!(
            TestEntity::id.aliased("t").lte(12_i64),
            Predicate::lte(expected_id, Expr::Value(SqlValue::I64(12)))
        );
        assert_eq!(
            TestEntity::name.aliased("t").is_null(),
            Predicate::is_null(expected_name)
        );
    }

    #[test]
    fn aliased_string_predicates_escape_like_wildcards() {
        assert_eq!(
            TestEntity::name.aliased("t").starts_with("a_b"),
            like(aliased_column("t", "name"), r"a\_b%")
        );
        assert_eq!(
            TestEntity::name.aliased("t").ends_with("a%b"),
            like(aliased_column("t", "name"), r"%a\%b")
        );
    }
}