Skip to main content

wscall_server/
validation.rs

1#![allow(dead_code)]
2
3use serde_json::{Map, Value, json};
4use validator::{ValidationError, ValidationErrors, ValidationErrorsKind};
5
6pub use validator::Validate;
7
8pub fn required<T>(value: &Option<T>) -> Result<(), ValidationError> {
9    if value.is_none() {
10        return Err(simple_error("required", "value is required"));
11    }
12    Ok(())
13}
14
15pub fn assert_true(value: &bool) -> Result<(), ValidationError> {
16    if !*value {
17        return Err(simple_error("assert_true", "value must be true"));
18    }
19    Ok(())
20}
21
22pub fn assert_false(value: &bool) -> Result<(), ValidationError> {
23    if *value {
24        return Err(simple_error("assert_false", "value must be false"));
25    }
26    Ok(())
27}
28
29pub fn not_empty(value: &str) -> Result<(), ValidationError> {
30    if value.is_empty() {
31        return Err(simple_error("not_empty", "value cannot be empty"));
32    }
33    Ok(())
34}
35
36pub fn not_blank(value: &str) -> Result<(), ValidationError> {
37    if value.trim().is_empty() {
38        return Err(simple_error("not_blank", "value cannot be blank"));
39    }
40    Ok(())
41}
42
43pub fn no_whitespace(value: &str) -> Result<(), ValidationError> {
44    if value.chars().any(char::is_whitespace) {
45        return Err(simple_error(
46            "no_whitespace",
47            "value cannot contain whitespace",
48        ));
49    }
50    Ok(())
51}
52
53pub fn alphabetic(value: &str) -> Result<(), ValidationError> {
54    if !value.chars().all(char::is_alphabetic) {
55        return Err(simple_error(
56            "alphabetic",
57            "value must contain only letters",
58        ));
59    }
60    Ok(())
61}
62
63pub fn alphanumeric(value: &str) -> Result<(), ValidationError> {
64    if !value.chars().all(char::is_alphanumeric) {
65        return Err(simple_error(
66            "alphanumeric",
67            "value must contain only letters or digits",
68        ));
69    }
70    Ok(())
71}
72
73pub fn ascii_alphanumeric(value: &str) -> Result<(), ValidationError> {
74    if !value.chars().all(|ch| ch.is_ascii_alphanumeric()) {
75        return Err(simple_error(
76            "ascii_alphanumeric",
77            "value must contain only ASCII letters or digits",
78        ));
79    }
80    Ok(())
81}
82
83pub fn numeric_text(value: &str) -> Result<(), ValidationError> {
84    if !value.chars().all(|ch| ch.is_ascii_digit()) {
85        return Err(simple_error(
86            "numeric_text",
87            "value must contain only digits",
88        ));
89    }
90    Ok(())
91}
92
93pub fn lowercase(value: &str) -> Result<(), ValidationError> {
94    if value
95        .chars()
96        .any(|ch| ch.is_alphabetic() && !ch.is_lowercase())
97    {
98        return Err(simple_error("lowercase", "value must be lowercase"));
99    }
100    Ok(())
101}
102
103pub fn uppercase(value: &str) -> Result<(), ValidationError> {
104    if value
105        .chars()
106        .any(|ch| ch.is_alphabetic() && !ch.is_uppercase())
107    {
108        return Err(simple_error("uppercase", "value must be uppercase"));
109    }
110    Ok(())
111}
112
113pub fn non_empty_vec<T>(value: &[T]) -> Result<(), ValidationError> {
114    if value.is_empty() {
115        return Err(simple_error("not_empty", "collection cannot be empty"));
116    }
117    Ok(())
118}
119
120pub fn non_empty_map<K, V, S>(
121    value: &std::collections::HashMap<K, V, S>,
122) -> Result<(), ValidationError> {
123    if value.is_empty() {
124        return Err(simple_error("not_empty", "map cannot be empty"));
125    }
126    Ok(())
127}
128
129pub fn positive_i32(value: i32) -> Result<(), ValidationError> {
130    if value <= 0 {
131        return Err(simple_error("positive", "value must be greater than 0"));
132    }
133    Ok(())
134}
135
136pub fn non_negative_i32(value: i32) -> Result<(), ValidationError> {
137    if value < 0 {
138        return Err(simple_error(
139            "non_negative",
140            "value must be greater than or equal to 0",
141        ));
142    }
143    Ok(())
144}
145
146pub fn positive_i64(value: i64) -> Result<(), ValidationError> {
147    if value <= 0 {
148        return Err(simple_error("positive", "value must be greater than 0"));
149    }
150    Ok(())
151}
152
153pub fn non_negative_i64(value: i64) -> Result<(), ValidationError> {
154    if value < 0 {
155        return Err(simple_error(
156            "non_negative",
157            "value must be greater than or equal to 0",
158        ));
159    }
160    Ok(())
161}
162
163pub fn positive_f64(value: f64) -> Result<(), ValidationError> {
164    if value <= 0.0 {
165        return Err(simple_error("positive", "value must be greater than 0"));
166    }
167    Ok(())
168}
169
170pub fn non_negative_f64(value: f64) -> Result<(), ValidationError> {
171    if value < 0.0 {
172        return Err(simple_error(
173            "non_negative",
174            "value must be greater than or equal to 0",
175        ));
176    }
177    Ok(())
178}
179
180pub fn percentage(value: f64) -> Result<(), ValidationError> {
181    if !(0.0..=100.0).contains(&value) {
182        return Err(simple_error(
183            "percentage",
184            "value must be between 0 and 100",
185        ));
186    }
187    Ok(())
188}
189
190pub fn errors_to_details(errors: &ValidationErrors) -> Value {
191    let mut fields = Map::new();
192
193    for (field, kind) in errors.errors() {
194        fields.insert(field.to_string(), error_kind_to_value(kind));
195    }
196
197    Value::Object(fields)
198}
199
200fn error_kind_to_value(kind: &ValidationErrorsKind) -> Value {
201    match kind {
202        ValidationErrorsKind::Field(errors) => Value::Array(
203            errors
204                .iter()
205                .map(|error| {
206                    json!({
207                        "code": error.code,
208                        "message": error.message.as_ref().map(|message| message.to_string()),
209                        "params": error.params,
210                    })
211                })
212                .collect(),
213        ),
214        ValidationErrorsKind::List(items) => {
215            let mut list = Map::new();
216            for (index, nested) in items {
217                list.insert(index.to_string(), errors_to_details(nested));
218            }
219            Value::Object(list)
220        }
221        ValidationErrorsKind::Struct(nested) => errors_to_details(nested),
222    }
223}
224
225fn simple_error(code: &'static str, message: &'static str) -> ValidationError {
226    let mut error = ValidationError::new(code);
227    error.message = Some(message.into());
228    error
229}
230
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that accepts the
/// value when it matches the regular expression `$pattern`.
///
/// The regex is compiled on first use and cached in a function-local
/// `OnceLock`. On mismatch the error carries `$code` and a message that
/// includes the pattern.
///
/// # Panics
/// The generated function panics on first call if `$pattern` is not a valid
/// regular expression.
#[macro_export]
macro_rules! wscall_regex_validator {
    ($name:ident, $pattern:literal, $code:literal) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            static REGEX: ::std::sync::OnceLock<::regex::Regex> = ::std::sync::OnceLock::new();
            let compiled = REGEX.get_or_init(|| {
                ::regex::Regex::new($pattern)
                    .expect("invalid regex pattern in wscall_regex_validator!")
            });

            if compiled.is_match(value) {
                return Ok(());
            }

            let text = format!("value does not match pattern {}", $pattern);
            let mut failure = ::validator::ValidationError::new($code);
            failure.message = Some(::std::borrow::Cow::Owned(text));
            Err(failure)
        }
    };
}
254
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that requires
/// the value to contain at least `$min` characters.
///
/// Length is measured in Unicode scalar values (`chars().count()`), not
/// bytes. On failure the error has code `min_length` and params `min` and
/// `actual`.
#[macro_export]
macro_rules! wscall_min_length_validator {
    ($name:ident, $min:expr) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            // Bind the caller's expression once so it is evaluated a single
            // time and the comparison, message and param all see one value.
            let min: usize = $min;
            let len = value.chars().count();
            if len >= min {
                return Ok(());
            }

            let mut error = ::validator::ValidationError::new("min_length");
            error.message = Some(::std::borrow::Cow::Owned(format!(
                "length must be at least {}",
                min
            )));
            error.add_param(::std::borrow::Cow::Borrowed("min"), &min);
            error.add_param(::std::borrow::Cow::Borrowed("actual"), &len);
            Err(error)
        }
    };
}
275
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that requires
/// the value to contain at most `$max` characters.
///
/// Length is measured in Unicode scalar values (`chars().count()`), not
/// bytes. On failure the error has code `max_length` and params `max` and
/// `actual`.
#[macro_export]
macro_rules! wscall_max_length_validator {
    ($name:ident, $max:expr) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            // Bind the caller's expression once so it is evaluated a single
            // time and the comparison, message and param all see one value.
            let max: usize = $max;
            let len = value.chars().count();
            if len <= max {
                return Ok(());
            }

            let mut error = ::validator::ValidationError::new("max_length");
            error.message = Some(::std::borrow::Cow::Owned(format!(
                "length must be at most {}",
                max
            )));
            error.add_param(::std::borrow::Cow::Borrowed("max"), &max);
            error.add_param(::std::borrow::Cow::Borrowed("actual"), &len);
            Err(error)
        }
    };
}
296
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that requires
/// the value's character count to lie in the inclusive range `$min..=$max`.
///
/// Length is measured in Unicode scalar values (`chars().count()`), not
/// bytes. On failure the error has code `length_range` and params `min`,
/// `max` and `actual`.
#[macro_export]
macro_rules! wscall_length_range_validator {
    ($name:ident, $min:expr, $max:expr) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            // Bind the caller's expressions once so each is evaluated a single
            // time and the check, message and params all see the same bounds.
            let (min, max): (usize, usize) = ($min, $max);
            let len = value.chars().count();
            if (min..=max).contains(&len) {
                return Ok(());
            }

            let mut error = ::validator::ValidationError::new("length_range");
            error.message = Some(::std::borrow::Cow::Owned(format!(
                "length must be between {} and {}",
                min, max
            )));
            error.add_param(::std::borrow::Cow::Borrowed("min"), &min);
            error.add_param(::std::borrow::Cow::Borrowed("max"), &max);
            error.add_param(::std::borrow::Cow::Borrowed("actual"), &len);
            Err(error)
        }
    };
}
318
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that requires
/// the value to contain the literal `$needle`.
///
/// On failure the error carries `$code` and a message naming the needle.
#[macro_export]
macro_rules! wscall_contains_validator {
    ($name:ident, $needle:literal, $code:literal) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            if !value.contains($needle) {
                let text = format!("value must contain {}", $needle);
                let mut failure = ::validator::ValidationError::new($code);
                failure.message = Some(::std::borrow::Cow::Owned(text));
                return Err(failure);
            }
            Ok(())
        }
    };
}
336
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that rejects any
/// value containing the literal `$needle`.
///
/// On failure the error carries `$code` and a message naming the needle.
#[macro_export]
macro_rules! wscall_not_contains_validator {
    ($name:ident, $needle:literal, $code:literal) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            if !value.contains($needle) {
                return Ok(());
            }

            let text = format!("value cannot contain {}", $needle);
            let mut failure = ::validator::ValidationError::new($code);
            failure.message = Some(::std::borrow::Cow::Owned(text));
            Err(failure)
        }
    };
}
354
/// Generates `fn $name(&str) -> Result<(), ValidationError>` that accepts the
/// value only when it equals one of the listed string expressions.
///
/// Comparison is exact (case-sensitive, no trimming). On failure the error
/// carries `$code` and a message listing the allowed values.
#[macro_export]
macro_rules! wscall_one_of_validator {
    ($name:ident, [$($value:expr),+ $(,)?], $code:literal) => {
        fn $name(value: &str) -> Result<(), ::validator::ValidationError> {
            const ALLOWED: &[&str] = &[$($value),+];

            let is_allowed = ALLOWED.iter().any(|candidate| *candidate == value);
            if is_allowed {
                return Ok(());
            }

            let text = format!("value must be one of {:?}", ALLOWED);
            let mut failure = ::validator::ValidationError::new($code);
            failure.message = Some(::std::borrow::Cow::Owned(text));
            Err(failure)
        }
    };
}
373
/// Generates `fn $name($ty) -> Result<(), ValidationError>` that requires the
/// value to be greater than or equal to `$min`.
///
/// Values that cannot be ordered against `$min` (for example a floating-point
/// `NaN`) are rejected, matching `wscall_numeric_range_validator!`. On failure
/// the error has code `min` and a `min` param.
#[macro_export]
macro_rules! wscall_numeric_min_validator {
    ($name:ident, $ty:ty, $min:expr) => {
        // The negated comparison is deliberate; see the comment in the body.
        #[allow(clippy::neg_cmp_op_on_partial_ord)]
        fn $name(value: $ty) -> Result<(), ::validator::ValidationError> {
            // `!(value >= min)` instead of `value < min`: for float types every
            // comparison with NaN is false, so the latter would accept NaN.
            if !(value >= $min) {
                let mut error = ::validator::ValidationError::new("min");
                error.message = Some(::std::borrow::Cow::Owned(format!(
                    "value must be greater than or equal to {}",
                    $min
                )));
                error.add_param(::std::borrow::Cow::Borrowed("min"), &$min);
                Err(error)
            } else {
                Ok(())
            }
        }
    };
}
392
/// Generates `fn $name($ty) -> Result<(), ValidationError>` that requires the
/// value to be less than or equal to `$max`.
///
/// Values that cannot be ordered against `$max` (for example a floating-point
/// `NaN`) are rejected, matching `wscall_numeric_range_validator!`. On failure
/// the error has code `max` and a `max` param.
#[macro_export]
macro_rules! wscall_numeric_max_validator {
    ($name:ident, $ty:ty, $max:expr) => {
        // The negated comparison is deliberate; see the comment in the body.
        #[allow(clippy::neg_cmp_op_on_partial_ord)]
        fn $name(value: $ty) -> Result<(), ::validator::ValidationError> {
            // `!(value <= max)` instead of `value > max`: for float types every
            // comparison with NaN is false, so the latter would accept NaN.
            if !(value <= $max) {
                let mut error = ::validator::ValidationError::new("max");
                error.message = Some(::std::borrow::Cow::Owned(format!(
                    "value must be less than or equal to {}",
                    $max
                )));
                error.add_param(::std::borrow::Cow::Borrowed("max"), &$max);
                Err(error)
            } else {
                Ok(())
            }
        }
    };
}
411
/// Generates `fn $name($ty) -> Result<(), ValidationError>` that requires the
/// value to lie in the inclusive range `$min..=$max`.
///
/// On failure the error has code `range` and params `min` and `max`.
#[macro_export]
macro_rules! wscall_numeric_range_validator {
    ($name:ident, $ty:ty, $min:expr, $max:expr) => {
        fn $name(value: $ty) -> Result<(), ::validator::ValidationError> {
            if ($min..=$max).contains(&value) {
                return Ok(());
            }

            let text = format!("value must be between {} and {}", $min, $max);
            let mut failure = ::validator::ValidationError::new("range");
            failure.message = Some(::std::borrow::Cow::Owned(text));
            failure.add_param(::std::borrow::Cow::Borrowed("min"), &$min);
            failure.add_param(::std::borrow::Cow::Borrowed("max"), &$max);
            Err(failure)
        }
    };
}
431
#[cfg(test)]
mod tests {
    use super::*;

    // Validators generated by the macros under test.
    wscall_min_length_validator!(validate_min_len_3, 3);
    wscall_max_length_validator!(validate_max_len_5, 5);
    wscall_length_range_validator!(validate_len_2_to_4, 2, 4);
    wscall_numeric_range_validator!(validate_age_range, i32, 1, 150);
    wscall_one_of_validator!(validate_env, ["dev", "test", "prod"], "invalid_env");

    /// Each table row is `(input, should_pass)` for the validator named in
    /// the loop body.
    #[test]
    fn built_in_string_validators_work() {
        for (input, should_pass) in [("hello", true), ("   ", false)] {
            assert_eq!(not_blank(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [("abc123", true), ("abc-123", false)] {
            assert_eq!(ascii_alphanumeric(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [("123456", true), ("12a456", false)] {
            assert_eq!(numeric_text(input).is_ok(), should_pass);
        }
    }

    /// Each table row is `(input, should_pass)` for the macro-generated
    /// validator named in the loop body.
    #[test]
    fn macro_validators_work() {
        for (input, should_pass) in [("abc", true), ("ab", false)] {
            assert_eq!(validate_min_len_3(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [("abcde", true), ("abcdef", false)] {
            assert_eq!(validate_max_len_5(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [("abc", true), ("a", false)] {
            assert_eq!(validate_len_2_to_4(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [(42, true), (151, false)] {
            assert_eq!(validate_age_range(input).is_ok(), should_pass);
        }
        for (input, should_pass) in [("prod", true), ("stage", false)] {
            assert_eq!(validate_env(input).is_ok(), should_pass);
        }
    }
}