Skip to main content

upflow/nodes/
start_node.rs

1use crate::models::context::NodeContext;
2use crate::models::error::WorkflowError;
3use crate::nodes::NodeExecutor;
4use async_trait::async_trait;
5use regex::Regex;
6use serde::Deserialize;
7use serde_json::Value;
8use std::thread;
9
/// Executor for the workflow's start node.
///
/// Validates the workflow payload against the input parameters declared in the node's
/// configuration. It is a stateless unit struct, so it is cheap to copy and default-construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct StartNode;
13
14/// 开始节点配置数据结构
15#[derive(Debug, Deserialize)]
16struct StartNodeData {
17    /// 输入参数定义列表
18    #[serde(default)]
19    input: Vec<InputDef>,
20}
21
22/// 输入参数定义
23#[derive(Debug, Deserialize)]
24struct InputDef {
25    /// 参数名称
26    name: String,
27    /// 参数类型
28    #[serde(rename = "type")]
29    typ: InputType,
30    /// 验证规则列表
31    #[serde(default)]
32    rules: Vec<RuleConfig>,
33}
34
35/// 支持的输入参数类型
36#[derive(Debug, Deserialize, PartialEq, Clone, Copy)]
37#[allow(non_camel_case_types)]
38enum InputType {
39    // 基本类型
40    STRING,
41    INTEGER,
42    LONG,
43    DECIMAL,
44    BOOLEAN,
45    OBJECT,
46    // 文件类型
47    FILE_IMAGE,
48    FILE_VIDEO,
49    FILE_AUDIO,
50    FILE_DOCUMENT,
51    // 数组类型
52    ARRAY,
53    ARRAY_STRING,
54    ARRAY_INTEGER,
55    ARRAY_LONG,
56    ARRAY_DECIMAL,
57    ARRAY_BOOLEAN,
58    ARRAY_OBJECT,
59    ARRAY_FILE_IMAGE,
60    ARRAY_FILE_VIDEO,
61    ARRAY_FILE_AUDIO,
62    ARRAY_FILE_DOCUMENT,
63}
64
65/// 验证规则配置
66#[derive(Debug, Deserialize)]
67struct RuleConfig {
68    /// 规则类型 (required, length, min, max, enum, pattern, email, size)
69    #[serde(rename = "type")]
70    rule_type: String,
71    /// 验证失败时的错误消息
72    message: Option<String>,
73
74    // 验证参数
75    /// 最小值 (用于数值)
76    min: Option<f64>,
77    /// 最大值 (用于数值)
78    max: Option<f64>,
79    /// 数组大小 (用于数组)
80    size: Option<usize>,
81    /// 字符串长度 (用于字符串)
82    length: Option<usize>,
83    /// 正则表达式模式 (用于字符串)
84    pattern: Option<String>,
85    /// 枚举值列表
86    #[serde(rename = "enum")]
87    enum_values: Option<Vec<Value>>,
88}
89
90#[async_trait]
91impl NodeExecutor for StartNode {
92    /// 执行开始节点逻辑
93    /// 1. 解析节点配置中的输入定义
94    /// 2. 验证工作流上下文中的 payload 数据是否符合定义
95    async fn execute(&self, ctx: NodeContext) -> Result<Value, WorkflowError> {
96        println!(
97            "StartNode [{}] 线程号: {:?}",
98            ctx.node.id,
99            thread::current().id()
100        );
101
102        // 1. 解析配置
103        let config: StartNodeData = serde_json::from_value(ctx.node.data.as_ref().clone())
104            .map_err(|e| WorkflowError::ParseError(format!("Invalid StartNode config: {}", e)))?;
105
106        // 2. 验证 payload
107        let payload = &ctx.flow_context.payload;
108
109        // 如果定义了输入参数,payload 必须是一个对象
110        if !config.input.is_empty() {
111            if !payload.is_object() {
112                return Err(WorkflowError::ValidationError(
113                    "StartNode payload must be an object".to_string(),
114                ));
115            }
116
117            // 逐个验证输入参数
118            for input_def in &config.input {
119                let val = payload.get(&input_def.name);
120                validate_input(val, input_def)?;
121            }
122        }
123
124        Ok(payload.clone())
125    }
126}
127
128/// 验证单个输入参数
129fn validate_input(val: Option<&Value>, def: &InputDef) -> Result<(), WorkflowError> {
130    // 检查必填项
131    let is_required = def.rules.iter().any(|r| r.rule_type == "required");
132
133    if val.is_none() || val.unwrap().is_null() {
134        if is_required {
135            let msg = def
136                .rules
137                .iter()
138                .find(|r| r.rule_type == "required")
139                .and_then(|r| r.message.clone())
140                .unwrap_or_else(|| format!("Field '{}' is required", def.name));
141            return Err(WorkflowError::ValidationError(msg));
142        }
143        return Ok(());
144    }
145
146    let val = val.unwrap();
147
148    // 类型检查
149    if !check_type(val, def.typ) {
150        return Err(WorkflowError::ValidationError(format!(
151            "Field '{}' expected type {:?}",
152            def.name, def.typ
153        )));
154    }
155
156    // 规则检查
157    for rule in &def.rules {
158        match rule.rule_type.as_str() {
159            "required" => {} // 已经在前面处理过
160            "length" => {
161                // 检查字符串长度
162                if let Some(s) = val.as_str() {
163                    if let Some(len) = rule.length {
164                        if s.len() != len {
165                            return Err(WorkflowError::ValidationError(
166                                rule.message.clone().unwrap_or(format!(
167                                    "Field '{}' length must be {}",
168                                    def.name, len
169                                )),
170                            ));
171                        }
172                    }
173                }
174            }
175            "max" => {
176                // 检查最大值 (数值大小 / 字符串长度 / 数组长度)
177                if let Some(max) = rule.max {
178                    if let Some(n) = val.as_f64() {
179                        if n > max {
180                            return Err(WorkflowError::ValidationError(
181                                rule.message
182                                    .clone()
183                                    .unwrap_or(format!("Field '{}' must be <= {}", def.name, max)),
184                            ));
185                        }
186                    } else if let Some(s) = val.as_str() {
187                        if s.len() as f64 > max {
188                            return Err(WorkflowError::ValidationError(
189                                rule.message.clone().unwrap_or(format!(
190                                    "Field '{}' length must be <= {}",
191                                    def.name, max
192                                )),
193                            ));
194                        }
195                    } else if let Some(arr) = val.as_array() {
196                        if arr.len() as f64 > max {
197                            return Err(WorkflowError::ValidationError(
198                                rule.message.clone().unwrap_or(format!(
199                                    "Field '{}' size must be <= {}",
200                                    def.name, max
201                                )),
202                            ));
203                        }
204                    }
205                }
206            }
207            "min" => {
208                // 检查最小值 (数值大小 / 字符串长度 / 数组长度)
209                if let Some(min) = rule.min {
210                    if let Some(n) = val.as_f64() {
211                        if n < min {
212                            return Err(WorkflowError::ValidationError(
213                                rule.message
214                                    .clone()
215                                    .unwrap_or(format!("Field '{}' must be >= {}", def.name, min)),
216                            ));
217                        }
218                    } else if let Some(s) = val.as_str() {
219                        if (s.len() as f64) < min {
220                            return Err(WorkflowError::ValidationError(
221                                rule.message.clone().unwrap_or(format!(
222                                    "Field '{}' length must be >= {}",
223                                    def.name, min
224                                )),
225                            ));
226                        }
227                    } else if let Some(arr) = val.as_array() {
228                        if (arr.len() as f64) < min {
229                            return Err(WorkflowError::ValidationError(
230                                rule.message.clone().unwrap_or(format!(
231                                    "Field '{}' size must be >= {}",
232                                    def.name, min
233                                )),
234                            ));
235                        }
236                    }
237                }
238            }
239            "enum" => {
240                // 检查枚举值
241                if let Some(ref options) = rule.enum_values {
242                    if !options.contains(val) {
243                        return Err(WorkflowError::ValidationError(
244                            rule.message.clone().unwrap_or(format!(
245                                "Field '{}' must be one of {:?}",
246                                def.name, options
247                            )),
248                        ));
249                    }
250                }
251            }
252            "pattern" => {
253                // 正则表达式匹配
254                if let Some(ref pat) = rule.pattern {
255                    if let Some(s) = val.as_str() {
256                        let re = Regex::new(pat).map_err(|e| {
257                            WorkflowError::RuntimeError(format!("Invalid regex: {}", e))
258                        })?;
259                        if !re.is_match(s) {
260                            return Err(WorkflowError::ValidationError(
261                                rule.message
262                                    .clone()
263                                    .unwrap_or(format!("Field '{}' format invalid", def.name)),
264                            ));
265                        }
266                    }
267                }
268            }
269            "email" => {
270                // 邮箱格式检查
271                if let Some(s) = val.as_str() {
272                    let email_regex =
273                        Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
274                    if !email_regex.is_match(s) {
275                        return Err(WorkflowError::ValidationError(
276                            rule.message
277                                .clone()
278                                .unwrap_or(format!("Field '{}' must be a valid email", def.name)),
279                        ));
280                    }
281                }
282            }
283            "size" => {
284                // 数组固定大小检查
285                if let Some(size) = rule.size {
286                    if let Some(arr) = val.as_array() {
287                        if arr.len() != size {
288                            return Err(WorkflowError::ValidationError(
289                                rule.message.clone().unwrap_or(format!(
290                                    "Field '{}' array size must be {}",
291                                    def.name, size
292                                )),
293                            ));
294                        }
295                    }
296                }
297            }
298            _ => {}
299        }
300    }
301
302    Ok(())
303}
304
305/// 检查值类型是否匹配
306fn check_type(val: &Value, typ: InputType) -> bool {
307    match typ {
308        InputType::STRING => val.is_string(),
309        InputType::INTEGER => val.is_i64(),
310        InputType::LONG => val.is_i64(),
311        InputType::DECIMAL => val.is_f64(),
312        InputType::BOOLEAN => val.is_boolean(),
313        InputType::OBJECT => val.is_object(),
314        // 文件类型暂作为对象或字符串处理
315        InputType::FILE_IMAGE
316        | InputType::FILE_VIDEO
317        | InputType::FILE_AUDIO
318        | InputType::FILE_DOCUMENT => val.is_object() || val.is_string(),
319        InputType::ARRAY => val.is_array(),
320        // 数组元素类型检查
321        InputType::ARRAY_STRING => val
322            .as_array()
323            .map_or(false, |arr| arr.iter().all(|v| v.is_string())),
324        InputType::ARRAY_INTEGER => val
325            .as_array()
326            .map_or(false, |arr| arr.iter().all(|v| v.is_i64())),
327        InputType::ARRAY_LONG => val
328            .as_array()
329            .map_or(false, |arr| arr.iter().all(|v| v.is_i64())),
330        InputType::ARRAY_DECIMAL => val
331            .as_array()
332            .map_or(false, |arr| arr.iter().all(|v| v.is_f64())),
333        InputType::ARRAY_BOOLEAN => val
334            .as_array()
335            .map_or(false, |arr| arr.iter().all(|v| v.is_boolean())),
336        InputType::ARRAY_OBJECT => val
337            .as_array()
338            .map_or(false, |arr| arr.iter().all(|v| v.is_object())),
339        InputType::ARRAY_FILE_IMAGE
340        | InputType::ARRAY_FILE_VIDEO
341        | InputType::ARRAY_FILE_AUDIO
342        | InputType::ARRAY_FILE_DOCUMENT => val.as_array().map_or(false, |arr| {
343            arr.iter().all(|v| v.is_object() || v.is_string())
344        }),
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::context::{FlowContext, NodeContext};
    use crate::models::event_bus::EventBus;
    use crate::models::workflow::Node;
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a context for a start node with the given config and workflow payload.
    async fn create_ctx(node_data: Value, payload: Value) -> NodeContext {
        let data = Arc::new(node_data);
        let node = Node {
            id: "start".to_string(),
            parent_id: None,
            node_type: "start".to_string(),
            data: Arc::clone(&data),
            retry_policy: None,
        };
        NodeContext {
            instance_id: "test-instance".to_string(),
            node,
            flow_context: Arc::new(FlowContext::new().with_payload(payload)),
            event_bus: EventBus::new(10),
            resolved_data: data,
            next_nodes: Arc::new(Vec::new()),
        }
    }

    /// Executes a `StartNode` with the given config and payload.
    async fn run(node_data: Value, payload: Value) -> Result<Value, WorkflowError> {
        let executor = StartNode;
        let ctx = create_ctx(node_data, payload).await;
        executor.execute(ctx).await
    }

    /// Extracts a `ValidationError` message, panicking on success or on any other error kind.
    fn validation_message(result: Result<Value, WorkflowError>) -> String {
        match result {
            Err(WorkflowError::ValidationError(msg)) => msg,
            Err(_) => panic!("Expected ValidationError"),
            Ok(value) => panic!("expected an error, got Ok({:?})", value),
        }
    }

    #[tokio::test]
    async fn test_start_node_validation_success() {
        let node_data = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "required" },
                        { "type": "min", "min": 18.0 }
                    ]
                },
                {
                    "name": "email",
                    "type": "STRING",
                    "rules": [
                        { "type": "email" }
                    ]
                }
            ]
        });
        let payload = json!({
            "age": 20,
            "email": "test@example.com"
        });

        assert!(run(node_data, payload).await.is_ok());
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_required() {
        let node_data = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "required" }
                    ]
                }
            ]
        });

        let msg = validation_message(run(node_data, json!({})).await);
        assert!(msg.contains("required"));
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_type() {
        let node_data = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": []
                }
            ]
        });
        // A string where an integer is declared.
        let payload = json!({ "age": "20" });

        let msg = validation_message(run(node_data, payload).await);
        assert!(msg.contains("expected type"));
    }

    #[tokio::test]
    async fn test_start_node_validation_fail_rule() {
        let node_data = json!({
            "input": [
                {
                    "name": "age",
                    "type": "INTEGER",
                    "rules": [
                        { "type": "min", "min": 18.0, "message": "Too young" }
                    ]
                }
            ]
        });
        let payload = json!({ "age": 10 });

        let msg = validation_message(run(node_data, payload).await);
        assert_eq!(msg, "Too young");
    }
}