skp_validator_rules/comparison/
allowed_values.rs1use skp_validator_core::{Rule, ValidationContext, ValidationErrors, ValidationError, ValidationResult};
4use std::fmt::Display;
5
/// Validation rule that accepts a value only when it appears in a fixed list.
///
/// `values` holds the accepted values. `message`, when set, replaces the
/// generated `"Must be one of: ..."` error text.
#[derive(Clone, Debug)]
pub struct AllowedValuesRule<T> {
    /// The accepted values.
    pub values: Vec<T>,
    /// Optional replacement for the generated error message.
    pub message: Option<String>,
}

impl<T> AllowedValuesRule<T> {
    /// Creates a rule that accepts exactly `values`, with no custom message.
    pub fn new(values: Vec<T>) -> Self {
        Self {
            message: None,
            values,
        }
    }

    /// Builder-style setter: returns the rule with `msg` as its custom error message.
    pub fn message(self, msg: impl Into<String>) -> Self {
        Self {
            message: Some(msg.into()),
            ..self
        }
    }
}

impl<T: Display> AllowedValuesRule<T> {
    /// Returns the custom message if one was set. Otherwise returns
    /// `"Must be one of: "` followed by the values' `Display` forms joined by `", "`.
    fn get_message(&self) -> String {
        if let Some(custom) = &self.message {
            return custom.clone();
        }
        let listed = self
            .values
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(", ");
        format!("Must be one of: {}", listed)
    }
}

impl AllowedValuesRule<String> {
    /// Builds a `String`-valued rule by copying each string slice in `values`.
    pub fn from_strs(values: &[&str]) -> Self {
        Self::new(values.iter().copied().map(String::from).collect())
    }
}
62
63impl Rule<str> for AllowedValuesRule<String> {
64 fn validate(&self, value: &str, _ctx: &ValidationContext) -> ValidationResult<()> {
65 if value.is_empty() {
67 return Ok(());
68 }
69
70 if self.values.iter().any(|v| v == value) {
71 Ok(())
72 } else {
73 Err(ValidationErrors::from_iter([
74 ValidationError::root("allowed_values", self.get_message())
75 .with_param("allowed", self.values.clone())
76 ]))
77 }
78 }
79
80 fn name(&self) -> &'static str {
81 "allowed_values"
82 }
83
84 fn default_message(&self) -> String {
85 self.get_message()
86 }
87}
88
89impl Rule<String> for AllowedValuesRule<String> {
90 fn validate(&self, value: &String, ctx: &ValidationContext) -> ValidationResult<()> {
91 <Self as Rule<str>>::validate(self, value.as_str(), ctx)
92 }
93
94 fn name(&self) -> &'static str {
95 "allowed_values"
96 }
97
98 fn default_message(&self) -> String {
99 self.get_message()
100 }
101}
102
103impl Rule<str> for AllowedValuesRule<&str> {
105 fn validate(&self, value: &str, _ctx: &ValidationContext) -> ValidationResult<()> {
106 if value.is_empty() {
107 return Ok(());
108 }
109
110 if self.values.contains(&value) {
111 Ok(())
112 } else {
113 let values_str: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
114 Err(ValidationErrors::from_iter([
115 ValidationError::root("allowed_values",
116 self.message.clone().unwrap_or_else(|| format!("Must be one of: {}", values_str.join(", "))))
117 .with_param("allowed", values_str)
118 ]))
119 }
120 }
121
122 fn name(&self) -> &'static str {
123 "allowed_values"
124 }
125
126 fn default_message(&self) -> String {
127 let values_str: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
128 format!("Must be one of: {}", values_str.join(", "))
129 }
130}
131
/// Implements `Rule<T>` for `AllowedValuesRule<T>` for each listed type `T`.
///
/// Membership is decided by `PartialEq` through `slice::contains`. For
/// floating-point types this means `NaN` never matches, and `-0.0` matches `0.0`.
/// Unlike the string impls, there is no empty-value shortcut: every value is
/// checked against the list.
macro_rules! impl_allowed_values_numeric {
    ($($t:ty),+) => {
        $(
            impl Rule<$t> for AllowedValuesRule<$t> {
                fn validate(&self, value: &$t, _ctx: &ValidationContext) -> ValidationResult<()> {
                    if self.values.contains(value) {
                        return Ok(());
                    }

                    // Attach the allowed list under `allowed`, as the string impls do,
                    // using the same `Vec<String>` shape they pass to `with_param`.
                    let allowed: Vec<String> =
                        self.values.iter().map(|v| v.to_string()).collect();
                    Err(ValidationErrors::from_iter([
                        ValidationError::root("allowed_values", self.get_message())
                            .with_param("allowed", allowed),
                    ]))
                }

                fn name(&self) -> &'static str {
                    "allowed_values"
                }

                /// Custom message if set, otherwise the generated list of allowed values.
                fn default_message(&self) -> String {
                    self.get_message()
                }
            }
        )+
    };
}
158
159impl_allowed_values_numeric!(i32, i64, u32, u64, f64);
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh default context shared by every test.
    fn default_ctx() -> ValidationContext {
        ValidationContext::default()
    }

    #[test]
    fn test_allowed_strings() {
        let rule = AllowedValuesRule::from_strs(&["active", "pending", "disabled"]);
        let ctx = default_ctx();

        for accepted in ["active", "pending"] {
            assert!(
                rule.validate(accepted, &ctx).is_ok(),
                "{} should be allowed",
                accepted
            );
        }
        assert!(rule.validate("unknown", &ctx).is_err());
    }

    #[test]
    fn test_allowed_numbers() {
        let rule = AllowedValuesRule::new(vec![1, 2, 3, 5, 8, 13]);
        let ctx = default_ctx();

        assert!(rule.validate(&5, &ctx).is_ok());
        assert!(rule.validate(&4, &ctx).is_err());
    }

    #[test]
    fn test_empty_is_valid() {
        let rule = AllowedValuesRule::from_strs(&["a", "b"]);

        assert!(rule.validate("", &default_ctx()).is_ok());
    }

    #[test]
    fn test_custom_message() {
        let rule = AllowedValuesRule::from_strs(&["a", "b"]).message("Invalid status");

        let errors = rule
            .validate("c", &default_ctx())
            .expect_err("\"c\" is not in the allowed list");
        assert!(errors.to_string().contains("Invalid status"));
    }
}