1use crate::models::context::NodeContext;
2use crate::models::error::WorkflowError;
3use crate::nodes::NodeExecutor;
4use async_trait::async_trait;
5use regex::Regex;
6use serde::Deserialize;
7use serde_json::Value;
8use std::thread;
9
/// Executor for a workflow's start node.
///
/// It validates the flow payload against the inputs declared in the node's
/// configuration. When validation passes, it returns the payload unchanged as
/// the node's output.
pub struct StartNode;
13
14#[derive(Debug, Deserialize)]
16struct StartNodeData {
17 #[serde(default)]
19 input: Vec<InputDef>,
20}
21
22#[derive(Debug, Deserialize)]
24struct InputDef {
25 name: String,
27 #[serde(rename = "type")]
29 typ: InputType,
30 #[serde(default)]
32 rules: Vec<RuleConfig>,
33}
34
35#[derive(Debug, Deserialize, PartialEq, Clone, Copy)]
37#[allow(non_camel_case_types)]
38enum InputType {
39 STRING,
41 INTEGER,
42 LONG,
43 DECIMAL,
44 BOOLEAN,
45 OBJECT,
46 FILE_IMAGE,
48 FILE_VIDEO,
49 FILE_AUDIO,
50 FILE_DOCUMENT,
51 ARRAY,
53 ARRAY_STRING,
54 ARRAY_INTEGER,
55 ARRAY_LONG,
56 ARRAY_DECIMAL,
57 ARRAY_BOOLEAN,
58 ARRAY_OBJECT,
59 ARRAY_FILE_IMAGE,
60 ARRAY_FILE_VIDEO,
61 ARRAY_FILE_AUDIO,
62 ARRAY_FILE_DOCUMENT,
63}
64
65#[derive(Debug, Deserialize)]
67struct RuleConfig {
68 #[serde(rename = "type")]
70 rule_type: String,
71 message: Option<String>,
73
74 min: Option<f64>,
77 max: Option<f64>,
79 size: Option<usize>,
81 length: Option<usize>,
83 pattern: Option<String>,
85 #[serde(rename = "enum")]
87 enum_values: Option<Vec<Value>>,
88}
89
90#[async_trait]
91impl NodeExecutor for StartNode {
92 async fn execute(&self, ctx: NodeContext) -> Result<Value, WorkflowError> {
96 println!(
97 "StartNode [{}] 线程号: {:?}",
98 ctx.node.id,
99 thread::current().id()
100 );
101
102 let config: StartNodeData = serde_json::from_value(ctx.node.data.as_ref().clone())
104 .map_err(|e| WorkflowError::ParseError(format!("Invalid StartNode config: {}", e)))?;
105
106 let payload = &ctx.flow_context.payload;
108
109 if !config.input.is_empty() {
111 if !payload.is_object() {
112 return Err(WorkflowError::ValidationError(
113 "StartNode payload must be an object".to_string(),
114 ));
115 }
116
117 for input_def in &config.input {
119 let val = payload.get(&input_def.name);
120 validate_input(val, input_def)?;
121 }
122 }
123
124 Ok(payload.clone())
125 }
126}
127
128fn validate_input(val: Option<&Value>, def: &InputDef) -> Result<(), WorkflowError> {
130 let is_required = def.rules.iter().any(|r| r.rule_type == "required");
132
133 if val.is_none() || val.unwrap().is_null() {
134 if is_required {
135 let msg = def
136 .rules
137 .iter()
138 .find(|r| r.rule_type == "required")
139 .and_then(|r| r.message.clone())
140 .unwrap_or_else(|| format!("Field '{}' is required", def.name));
141 return Err(WorkflowError::ValidationError(msg));
142 }
143 return Ok(());
144 }
145
146 let val = val.unwrap();
147
148 if !check_type(val, def.typ) {
150 return Err(WorkflowError::ValidationError(format!(
151 "Field '{}' expected type {:?}",
152 def.name, def.typ
153 )));
154 }
155
156 for rule in &def.rules {
158 match rule.rule_type.as_str() {
159 "required" => {} "length" => {
161 if let Some(s) = val.as_str() {
163 if let Some(len) = rule.length {
164 if s.len() != len {
165 return Err(WorkflowError::ValidationError(
166 rule.message.clone().unwrap_or(format!(
167 "Field '{}' length must be {}",
168 def.name, len
169 )),
170 ));
171 }
172 }
173 }
174 }
175 "max" => {
176 if let Some(max) = rule.max {
178 if let Some(n) = val.as_f64() {
179 if n > max {
180 return Err(WorkflowError::ValidationError(
181 rule.message
182 .clone()
183 .unwrap_or(format!("Field '{}' must be <= {}", def.name, max)),
184 ));
185 }
186 } else if let Some(s) = val.as_str() {
187 if s.len() as f64 > max {
188 return Err(WorkflowError::ValidationError(
189 rule.message.clone().unwrap_or(format!(
190 "Field '{}' length must be <= {}",
191 def.name, max
192 )),
193 ));
194 }
195 } else if let Some(arr) = val.as_array() {
196 if arr.len() as f64 > max {
197 return Err(WorkflowError::ValidationError(
198 rule.message.clone().unwrap_or(format!(
199 "Field '{}' size must be <= {}",
200 def.name, max
201 )),
202 ));
203 }
204 }
205 }
206 }
207 "min" => {
208 if let Some(min) = rule.min {
210 if let Some(n) = val.as_f64() {
211 if n < min {
212 return Err(WorkflowError::ValidationError(
213 rule.message
214 .clone()
215 .unwrap_or(format!("Field '{}' must be >= {}", def.name, min)),
216 ));
217 }
218 } else if let Some(s) = val.as_str() {
219 if (s.len() as f64) < min {
220 return Err(WorkflowError::ValidationError(
221 rule.message.clone().unwrap_or(format!(
222 "Field '{}' length must be >= {}",
223 def.name, min
224 )),
225 ));
226 }
227 } else if let Some(arr) = val.as_array() {
228 if (arr.len() as f64) < min {
229 return Err(WorkflowError::ValidationError(
230 rule.message.clone().unwrap_or(format!(
231 "Field '{}' size must be >= {}",
232 def.name, min
233 )),
234 ));
235 }
236 }
237 }
238 }
239 "enum" => {
240 if let Some(ref options) = rule.enum_values {
242 if !options.contains(val) {
243 return Err(WorkflowError::ValidationError(
244 rule.message.clone().unwrap_or(format!(
245 "Field '{}' must be one of {:?}",
246 def.name, options
247 )),
248 ));
249 }
250 }
251 }
252 "pattern" => {
253 if let Some(ref pat) = rule.pattern {
255 if let Some(s) = val.as_str() {
256 let re = Regex::new(pat).map_err(|e| {
257 WorkflowError::RuntimeError(format!("Invalid regex: {}", e))
258 })?;
259 if !re.is_match(s) {
260 return Err(WorkflowError::ValidationError(
261 rule.message
262 .clone()
263 .unwrap_or(format!("Field '{}' format invalid", def.name)),
264 ));
265 }
266 }
267 }
268 }
269 "email" => {
270 if let Some(s) = val.as_str() {
272 let email_regex =
273 Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
274 if !email_regex.is_match(s) {
275 return Err(WorkflowError::ValidationError(
276 rule.message
277 .clone()
278 .unwrap_or(format!("Field '{}' must be a valid email", def.name)),
279 ));
280 }
281 }
282 }
283 "size" => {
284 if let Some(size) = rule.size {
286 if let Some(arr) = val.as_array() {
287 if arr.len() != size {
288 return Err(WorkflowError::ValidationError(
289 rule.message.clone().unwrap_or(format!(
290 "Field '{}' array size must be {}",
291 def.name, size
292 )),
293 ));
294 }
295 }
296 }
297 }
298 _ => {}
299 }
300 }
301
302 Ok(())
303}
304
305fn check_type(val: &Value, typ: InputType) -> bool {
307 match typ {
308 InputType::STRING => val.is_string(),
309 InputType::INTEGER => val.is_i64(),
310 InputType::LONG => val.is_i64(),
311 InputType::DECIMAL => val.is_f64(),
312 InputType::BOOLEAN => val.is_boolean(),
313 InputType::OBJECT => val.is_object(),
314 InputType::FILE_IMAGE
316 | InputType::FILE_VIDEO
317 | InputType::FILE_AUDIO
318 | InputType::FILE_DOCUMENT => val.is_object() || val.is_string(),
319 InputType::ARRAY => val.is_array(),
320 InputType::ARRAY_STRING => val
322 .as_array()
323 .map_or(false, |arr| arr.iter().all(|v| v.is_string())),
324 InputType::ARRAY_INTEGER => val
325 .as_array()
326 .map_or(false, |arr| arr.iter().all(|v| v.is_i64())),
327 InputType::ARRAY_LONG => val
328 .as_array()
329 .map_or(false, |arr| arr.iter().all(|v| v.is_i64())),
330 InputType::ARRAY_DECIMAL => val
331 .as_array()
332 .map_or(false, |arr| arr.iter().all(|v| v.is_f64())),
333 InputType::ARRAY_BOOLEAN => val
334 .as_array()
335 .map_or(false, |arr| arr.iter().all(|v| v.is_boolean())),
336 InputType::ARRAY_OBJECT => val
337 .as_array()
338 .map_or(false, |arr| arr.iter().all(|v| v.is_object())),
339 InputType::ARRAY_FILE_IMAGE
340 | InputType::ARRAY_FILE_VIDEO
341 | InputType::ARRAY_FILE_AUDIO
342 | InputType::ARRAY_FILE_DOCUMENT => val.as_array().map_or(false, |arr| {
343 arr.iter().all(|v| v.is_object() || v.is_string())
344 }),
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::context::{FlowContext, NodeContext};
    use crate::models::event_bus::EventBus;
    use crate::models::workflow::Node;
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a context for a node with id `start`. Both `data` and
    /// `resolved_data` are set to `node_data`, and the flow payload is `payload`.
    async fn create_ctx(node_data: Value, payload: Value) -> NodeContext {
        let flow_context = Arc::new(FlowContext::new().with_payload(payload));
        let shared_data = Arc::new(node_data);
        NodeContext {
            instance_id: "test-instance".to_string(),
            node: Node {
                id: "start".to_string(),
                parent_id: None,
                node_type: "start".to_string(),
                data: Arc::clone(&shared_data),
                retry_policy: None,
            },
            flow_context,
            event_bus: EventBus::new(10),
            resolved_data: shared_data,
            next_nodes: Arc::new(Vec::new()),
        }
    }

    /// Executes a `StartNode` with the given config and payload.
    async fn run(node_data: Value, payload: Value) -> Result<Value, WorkflowError> {
        let ctx = create_ctx(node_data, payload).await;
        let executor = StartNode;
        executor.execute(ctx).await
    }

    /// Asserts that `result` is a `ValidationError` and returns its message.
    fn validation_message(result: Result<Value, WorkflowError>) -> String {
        assert!(result.is_err());
        match result.unwrap_err() {
            WorkflowError::ValidationError(msg) => msg,
            _ => panic!("Expected ValidationError"),
        }
    }

    #[tokio::test]
    async fn test_start_node_validation_success() {
        let config = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "required" },
                        { "type": "min", "min": 18.0 }
                    ]
                },
                {
                    "name": "email",
                    "type": "STRING",
                    "rules": [
                        { "type": "email" }
                    ]
                }
            ]
        });
        let payload = json!({
            "age": 20,
            "email": "test@example.com"
        });

        assert!(run(config, payload).await.is_ok());
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_required() {
        let config = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "required" }
                    ]
                }
            ]
        });

        let msg = validation_message(run(config, json!({})).await);
        assert!(msg.contains("required"));
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_type() {
        let config = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": []
                }
            ]
        });

        // A numeric string is not an INTEGER.
        let msg = validation_message(run(config, json!({ "age": "20" })).await);
        assert!(msg.contains("expected type"));
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_rule() {
        let config = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "min", "min": 18.0, "message": "Too young" }
                    ]
                }
            ]
        });

        // A custom rule message replaces the default one.
        let msg = validation_message(run(config, json!({ "age": 10 })).await);
        assert_eq!(msg, "Too young");
    }
}