skp_validator_rules/comparison/
allowed_values.rs

//! Allowed values validation rule.
2
3use skp_validator_core::{Rule, ValidationContext, ValidationErrors, ValidationError, ValidationResult};
4use std::fmt::Display;
5
/// Validation rule that only accepts values drawn from a fixed set.
///
/// String-typed rules treat an empty input as valid; pair this rule with a
/// `required` rule when empty input must be rejected. Numeric rules have no
/// such shortcut and always compare against the configured values.
///
/// # Example
///
/// ```rust
/// use skp_validator_rules::comparison::allowed_values::AllowedValuesRule;
/// use skp_validator_core::{Rule, ValidationContext};
///
/// let rule = AllowedValuesRule::new(vec!["active", "pending", "disabled"]);
/// let ctx = ValidationContext::default();
///
/// assert!(rule.validate("active", &ctx).is_ok());
/// assert!(rule.validate("unknown", &ctx).is_err());
/// ```
#[derive(Debug, Clone)]
pub struct AllowedValuesRule<T> {
    /// Values an input is compared against; a match on any of them passes.
    pub values: Vec<T>,
    /// Replaces the generated "Must be one of: ..." error text when set.
    pub message: Option<String>,
}
27
28impl<T> AllowedValuesRule<T> {
29    /// Create a new allowed values rule.
30    pub fn new(values: Vec<T>) -> Self {
31        Self {
32            values,
33            message: None,
34        }
35    }
36
37    /// Set custom error message.
38    pub fn message(mut self, msg: impl Into<String>) -> Self {
39        self.message = Some(msg.into());
40        self
41    }
42}
43
44impl<T: Display> AllowedValuesRule<T> {
45    fn get_message(&self) -> String {
46        self.message.clone().unwrap_or_else(|| {
47            let values_str: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
48            format!("Must be one of: {}", values_str.join(", "))
49        })
50    }
51}
52
53impl AllowedValuesRule<String> {
54    /// Create from string slices.
55    pub fn from_strs(values: &[&str]) -> Self {
56        Self {
57            values: values.iter().map(|s| s.to_string()).collect(),
58            message: None,
59        }
60    }
61}
62
63impl Rule<str> for AllowedValuesRule<String> {
64    fn validate(&self, value: &str, _ctx: &ValidationContext) -> ValidationResult<()> {
65        // Empty is valid (use required for non-empty)
66        if value.is_empty() {
67            return Ok(());
68        }
69
70        if self.values.iter().any(|v| v == value) {
71            Ok(())
72        } else {
73            Err(ValidationErrors::from_iter([
74                ValidationError::root("allowed_values", self.get_message())
75                    .with_param("allowed", self.values.clone())
76            ]))
77        }
78    }
79
80    fn name(&self) -> &'static str {
81        "allowed_values"
82    }
83
84    fn default_message(&self) -> String {
85        self.get_message()
86    }
87}
88
89impl Rule<String> for AllowedValuesRule<String> {
90    fn validate(&self, value: &String, ctx: &ValidationContext) -> ValidationResult<()> {
91        <Self as Rule<str>>::validate(self, value.as_str(), ctx)
92    }
93
94    fn name(&self) -> &'static str {
95        "allowed_values"
96    }
97
98    fn default_message(&self) -> String {
99        self.get_message()
100    }
101}
102
103// Implement for &str
104impl Rule<str> for AllowedValuesRule<&str> {
105    fn validate(&self, value: &str, _ctx: &ValidationContext) -> ValidationResult<()> {
106        if value.is_empty() {
107            return Ok(());
108        }
109
110        if self.values.contains(&value) {
111            Ok(())
112        } else {
113            let values_str: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
114            Err(ValidationErrors::from_iter([
115                ValidationError::root("allowed_values", 
116                    self.message.clone().unwrap_or_else(|| format!("Must be one of: {}", values_str.join(", "))))
117                    .with_param("allowed", values_str)
118            ]))
119        }
120    }
121
122    fn name(&self) -> &'static str {
123        "allowed_values"
124    }
125
126    fn default_message(&self) -> String {
127        let values_str: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
128        format!("Must be one of: {}", values_str.join(", "))
129    }
130}
131
132// Implement for numeric types
133macro_rules! impl_allowed_values_numeric {
134    ($($t:ty),+) => {
135        $(
136            impl Rule<$t> for AllowedValuesRule<$t> {
137                fn validate(&self, value: &$t, _ctx: &ValidationContext) -> ValidationResult<()> {
138                    if self.values.contains(value) {
139                        Ok(())
140                    } else {
141                        Err(ValidationErrors::from_iter([
142                            ValidationError::root("allowed_values", self.get_message())
143                        ]))
144                    }
145                }
146
147                fn name(&self) -> &'static str {
148                    "allowed_values"
149                }
150
151                fn default_message(&self) -> String {
152                    self.get_message()
153                }
154            }
155        )+
156    };
157}
158
159impl_allowed_values_numeric!(i32, i64, u32, u64, f64);
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allowed_strings() {
        let rule = AllowedValuesRule::from_strs(&["active", "pending", "disabled"]);
        let ctx = ValidationContext::default();

        assert!(rule.validate("active", &ctx).is_ok());
        assert!(rule.validate("pending", &ctx).is_ok());
        assert!(rule.validate("unknown", &ctx).is_err());
    }

    #[test]
    fn test_allowed_owned_string_input() {
        // Exercises the `Rule<String>` impl, which delegates to `Rule<str>`.
        let rule = AllowedValuesRule::from_strs(&["active", "pending"]);
        let ctx = ValidationContext::default();

        assert!(rule.validate(&String::from("active"), &ctx).is_ok());
        assert!(rule.validate(&String::from("archived"), &ctx).is_err());
        assert!(rule.validate(&String::new(), &ctx).is_ok());
    }

    #[test]
    fn test_allowed_str_slices() {
        // Mirrors the doc example: a rule built from `&str` values.
        let rule = AllowedValuesRule::new(vec!["active", "pending", "disabled"]);
        let ctx = ValidationContext::default();

        assert!(rule.validate("disabled", &ctx).is_ok());
        assert!(rule.validate("", &ctx).is_ok());
        assert!(rule.validate("unknown", &ctx).is_err());
    }

    #[test]
    fn test_allowed_numbers() {
        let rule = AllowedValuesRule::new(vec![1, 2, 3, 5, 8, 13]);
        let ctx = ValidationContext::default();

        assert!(rule.validate(&5, &ctx).is_ok());
        assert!(rule.validate(&4, &ctx).is_err());
    }

    #[test]
    fn test_allowed_floats() {
        let rule = AllowedValuesRule::new(vec![0.5_f64, 1.5]);
        let ctx = ValidationContext::default();

        assert!(rule.validate(&1.5, &ctx).is_ok());
        assert!(rule.validate(&2.0, &ctx).is_err());
    }

    #[test]
    fn test_numeric_name_and_default_message() {
        let rule = AllowedValuesRule::new(vec![1_i32, 2, 3]);

        assert_eq!(rule.name(), "allowed_values");
        assert_eq!(rule.default_message(), "Must be one of: 1, 2, 3");
    }

    #[test]
    fn test_empty_is_valid() {
        let rule = AllowedValuesRule::from_strs(&["a", "b"]);
        let ctx = ValidationContext::default();

        assert!(rule.validate("", &ctx).is_ok());
    }

    #[test]
    fn test_custom_message() {
        let rule = AllowedValuesRule::from_strs(&["a", "b"]).message("Invalid status");
        let ctx = ValidationContext::default();

        let result = rule.validate("c", &ctx);
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.to_string().contains("Invalid status"));
    }

    #[test]
    fn test_custom_message_numeric() {
        let rule = AllowedValuesRule::new(vec![1_i64, 2]).message("Pick 1 or 2");
        let ctx = ValidationContext::default();

        let errors = rule.validate(&3, &ctx).unwrap_err();
        assert!(errors.to_string().contains("Pick 1 or 2"));
    }
}