1use crate::pipeline::Pipeline;
7use async_trait::async_trait;
8use log::{debug, warn};
9use serde_json::Value;
10use spider_util::{error::PipelineError, item::ScrapedItem};
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::Arc;
14
/// Trait-object type of a custom validator closure stored by `ValidationPipeline`.
///
/// A validator receives the item itself and a JSON value (in this file, the
/// result of `item.to_json_value()`), and returns `Err(message)` to reject the
/// item or `Ok(())` to accept it. The `Send + Sync + 'static` bounds are part of
/// the object type so the closure can be held behind an `Arc`.
type ValidatorFn<I> = dyn Fn(&I, &Value) -> Result<(), String> + Send + Sync + 'static;
16
/// The JSON value kinds a field can be required to have via
/// `ValidationRule::Type`.
///
/// Each variant is checked with the matching `serde_json::Value::is_*`
/// predicate (see `ValidationPipeline::validate_type`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonType {
    /// Matches values for which `Value::is_null` is true.
    Null,
    /// Matches values for which `Value::is_boolean` is true.
    Bool,
    /// Matches values for which `Value::is_number` is true.
    Number,
    /// Matches values for which `Value::is_string` is true.
    String,
    /// Matches values for which `Value::is_array` is true.
    Array,
    /// Matches values for which `Value::is_object` is true.
    Object,
}
27
/// A declarative check applied to one named field of an item's JSON object.
///
/// Every rule except `Required` is skipped when the field is absent, so combine
/// a rule with `Required` when the field must also be present.
#[derive(Debug, Clone)]
pub enum ValidationRule {
    /// The field key must exist in the object. Only presence is checked, so a
    /// field that is present with a `null` value satisfies this rule.
    Required,
    /// If present, the field must be a string that is not empty after
    /// trimming surrounding whitespace.
    NonEmptyString,
    /// If present, the field must have the given JSON type.
    Type(JsonType),
    /// If present, the field must be a string or an array whose length is at
    /// least the given value; any other JSON type fails the rule.
    MinLen(usize),
    /// If present, the field must be a string or an array whose length is at
    /// most the given value; any other JSON type fails the rule.
    MaxLen(usize),
    /// If present, the field must be readable as `f64` and be greater than or
    /// equal to the given bound.
    MinNumber(f64),
    /// If present, the field must be readable as `f64` and be less than or
    /// equal to the given bound.
    MaxNumber(f64),
}
39
/// A pipeline stage that drops items failing field rules or custom validators.
///
/// Built with `new()` followed by chained `with_rule` / `with_validator` calls.
pub struct ValidationPipeline<I: ScrapedItem> {
    /// Rules to check, keyed by the JSON field name they apply to. Rules for
    /// one field are kept in the order they were added.
    rules: HashMap<String, Vec<ValidationRule>>,
    /// Custom validator closures, run in insertion order after the field rules
    /// have passed.
    validators: Vec<Arc<ValidatorFn<I>>>,
    /// Zero-sized marker tying the struct to the item type `I`.
    _phantom: PhantomData<I>,
}
46
47impl<I: ScrapedItem> ValidationPipeline<I> {
48 pub fn new() -> Self {
50 Self {
51 rules: HashMap::new(),
52 validators: Vec::new(),
53 _phantom: PhantomData,
54 }
55 }
56
57 pub fn with_rule(mut self, field: impl Into<String>, rule: ValidationRule) -> Self {
59 self.rules.entry(field.into()).or_default().push(rule);
60 self
61 }
62
63 pub fn with_validator<F>(mut self, validator: F) -> Self
65 where
66 F: Fn(&I, &Value) -> Result<(), String> + Send + Sync + 'static,
67 {
68 self.validators.push(Arc::new(validator));
69 self
70 }
71
72 fn validate_type(value: &Value, expected: &JsonType) -> bool {
73 match expected {
74 JsonType::Null => value.is_null(),
75 JsonType::Bool => value.is_boolean(),
76 JsonType::Number => value.is_number(),
77 JsonType::String => value.is_string(),
78 JsonType::Array => value.is_array(),
79 JsonType::Object => value.is_object(),
80 }
81 }
82
83 fn validate_item(&self, json: &Value) -> Result<(), String> {
84 let map = json
85 .as_object()
86 .ok_or_else(|| "Item must be a JSON object for validation.".to_string())?;
87
88 for (field, rules) in &self.rules {
89 let value = map.get(field);
90 for rule in rules {
91 match rule {
92 ValidationRule::Required => {
93 if value.is_none() {
94 return Err(format!("Missing required field '{}'.", field));
95 }
96 }
97 ValidationRule::NonEmptyString => {
98 if let Some(v) = value {
99 match v.as_str() {
100 Some(s) if !s.trim().is_empty() => {}
101 Some(_) => {
102 return Err(format!(
103 "Field '{}' must be a non-empty string.",
104 field
105 ));
106 }
107 None => {
108 return Err(format!("Field '{}' must be a string.", field));
109 }
110 }
111 }
112 }
113 ValidationRule::Type(expected) => {
114 if let Some(v) = value
115 && !Self::validate_type(v, expected)
116 {
117 return Err(format!(
118 "Field '{}' has invalid type. Expected {:?}.",
119 field, expected
120 ));
121 }
122 }
123 ValidationRule::MinLen(min) => {
124 if let Some(v) = value {
125 if let Some(s) = v.as_str() {
126 if s.len() < *min {
127 return Err(format!(
128 "Field '{}' length {} is less than {}.",
129 field,
130 s.len(),
131 min
132 ));
133 }
134 } else if let Some(arr) = v.as_array() {
135 if arr.len() < *min {
136 return Err(format!(
137 "Field '{}' array length {} is less than {}.",
138 field,
139 arr.len(),
140 min
141 ));
142 }
143 } else {
144 return Err(format!(
145 "Field '{}' must be string or array for MinLen.",
146 field
147 ));
148 }
149 }
150 }
151 ValidationRule::MaxLen(max) => {
152 if let Some(v) = value {
153 if let Some(s) = v.as_str() {
154 if s.len() > *max {
155 return Err(format!(
156 "Field '{}' length {} is greater than {}.",
157 field,
158 s.len(),
159 max
160 ));
161 }
162 } else if let Some(arr) = v.as_array() {
163 if arr.len() > *max {
164 return Err(format!(
165 "Field '{}' array length {} is greater than {}.",
166 field,
167 arr.len(),
168 max
169 ));
170 }
171 } else {
172 return Err(format!(
173 "Field '{}' must be string or array for MaxLen.",
174 field
175 ));
176 }
177 }
178 }
179 ValidationRule::MinNumber(min) => {
180 if let Some(v) = value {
181 let num = v.as_f64().ok_or_else(|| {
182 format!("Field '{}' must be numeric for MinNumber.", field)
183 })?;
184 if num < *min {
185 return Err(format!(
186 "Field '{}' number {} is less than {}.",
187 field, num, min
188 ));
189 }
190 }
191 }
192 ValidationRule::MaxNumber(max) => {
193 if let Some(v) = value {
194 let num = v.as_f64().ok_or_else(|| {
195 format!("Field '{}' must be numeric for MaxNumber.", field)
196 })?;
197 if num > *max {
198 return Err(format!(
199 "Field '{}' number {} is greater than {}.",
200 field, num, max
201 ));
202 }
203 }
204 }
205 }
206 }
207 }
208
209 Ok(())
210 }
211}
212
/// `Default` produces the same empty pipeline as `ValidationPipeline::new`.
impl<I: ScrapedItem> Default for ValidationPipeline<I> {
    /// Delegates to `new()`.
    fn default() -> Self {
        Self::new()
    }
}
218
219#[async_trait]
220impl<I: ScrapedItem> Pipeline<I> for ValidationPipeline<I> {
221 fn name(&self) -> &str {
222 "ValidationPipeline"
223 }
224
225 async fn process_item(&self, item: I) -> Result<Option<I>, PipelineError> {
226 debug!("ValidationPipeline processing item.");
227 let json = item.to_json_value();
228
229 if let Err(err) = self.validate_item(&json) {
230 warn!("Validation failed, dropping item: {}", err);
231 return Ok(None);
232 }
233
234 for validator in &self.validators {
235 if let Err(err) = validator(&item, &json) {
236 warn!("Custom validation failed, dropping item: {}", err);
237 return Ok(None);
238 }
239 }
240
241 Ok(Some(item))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use spider_util::item::ScrapedItem;
    use std::any::Any;

    /// Minimal scraped item used to drive the pipeline in tests.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestItem {
        title: String,
        price: f64,
    }

    impl ScrapedItem for TestItem {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn box_clone(&self) -> Box<dyn ScrapedItem + Send + Sync> {
            Box::new(self.clone())
        }

        fn to_json_value(&self) -> Value {
            serde_json::to_value(self).expect("serialize test item")
        }
    }

    /// Builds a `TestItem` from borrowed parts.
    fn sample(title: &str, price: f64) -> TestItem {
        TestItem {
            title: title.to_string(),
            price,
        }
    }

    /// Feeds `item` through `pipeline`, failing the test if the pipeline errors.
    async fn run(pipeline: &ValidationPipeline<TestItem>, item: TestItem) -> Option<TestItem> {
        pipeline
            .process_item(item)
            .await
            .expect("pipeline should not fail")
    }

    #[tokio::test]
    async fn passes_valid_item() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::Required)
            .with_rule("title", ValidationRule::NonEmptyString)
            .with_rule("price", ValidationRule::MinNumber(1.0))
            .with_rule("price", ValidationRule::MaxNumber(100.0));

        let out = run(&pipeline, sample("Book", 20.0)).await;
        assert!(out.is_some());
    }

    #[tokio::test]
    async fn drops_missing_required_field() {
        let pipeline =
            ValidationPipeline::<TestItem>::new().with_rule("missing", ValidationRule::Required);

        let out = run(&pipeline, sample("Book", 20.0)).await;
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn drops_on_custom_validator_error() {
        let pipeline = ValidationPipeline::<TestItem>::new().with_validator(|_item, json| {
            match json.get("title").and_then(Value::as_str) {
                Some("Book") => Ok(()),
                _ => Err("title mismatch".to_string()),
            }
        });

        let out = run(&pipeline, sample("Other", 20.0)).await;
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn drops_on_invalid_type_rule() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::Type(JsonType::Number));

        let out = run(&pipeline, sample("Book", 20.0)).await;
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn handles_multiple_rules() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::MinLen(2))
            .with_rule("title", ValidationRule::MaxLen(10))
            .with_validator(|_, _| Ok(()));

        let out = run(&pipeline, sample("ok", 5.0)).await;
        assert_eq!(
            out.expect("item should pass").to_json_value(),
            json!({"title":"ok","price":5.0})
        );
    }
}