1use super::context::Context;
4use crate::error::Result;
5use crate::types::RelPtr;
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9
/// Direction of data flow through a [`Port`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PortDirection {
    /// The port receives data into the node.
    Input,
    /// The port emits data from the node.
    Output,
}

/// A named, schema-typed connection point on a node.
///
/// Build one with [`Port::input`], [`Port::output`], [`Port::error`] or
/// [`Port::named`], then refine it with [`Port::optional`] and
/// [`Port::with_description`].
#[derive(Debug, Clone)]
pub struct Port {
    /// Port name used to identify it (e.g. `"in"`, `"out"`, `"error"`).
    pub name: String,
    /// Whether data flows into or out of the node through this port.
    pub direction: PortDirection,
    /// Schema identifier for the port's data (e.g. `"Any"`, `"Error@v1"`).
    pub schema: String,
    /// Whether the port is required. The constructors default this to `true`
    /// for inputs and `false` for outputs.
    pub required: bool,
    /// Human-readable description; may be empty.
    pub description: String,
}

impl Port {
    /// Creates the default input port: named `"in"`, required, with `schema`.
    pub fn input(schema: impl Into<String>) -> Self {
        Self {
            name: "in".to_string(),
            direction: PortDirection::Input,
            schema: schema.into(),
            required: true,
            description: "Default input".to_string(),
        }
    }

    /// Creates the default output port: named `"out"`, not required, with `schema`.
    pub fn output(schema: impl Into<String>) -> Self {
        Self {
            name: "out".to_string(),
            direction: PortDirection::Output,
            schema: schema.into(),
            required: false,
            description: "Default output".to_string(),
        }
    }

    /// Creates the standard error output port: named `"error"`, schema
    /// `"Error@v1"`, not required.
    pub fn error() -> Self {
        Self {
            name: "error".to_string(),
            direction: PortDirection::Output,
            schema: "Error@v1".to_string(),
            required: false,
            description: "Error output".to_string(),
        }
    }

    /// Creates a port with an explicit name, direction and schema.
    ///
    /// `required` is `true` for [`PortDirection::Input`] and `false` for
    /// [`PortDirection::Output`]; the description is left empty.
    pub fn named(
        name: impl Into<String>,
        direction: PortDirection,
        schema: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            direction,
            schema: schema.into(),
            required: direction == PortDirection::Input,
            description: String::new(),
        }
    }

    /// Returns this port marked as not required.
    ///
    /// Consumes `self`; the result must be used or the change is lost.
    #[must_use]
    pub fn optional(mut self) -> Self {
        self.required = false;
        self
    }

    /// Returns this port with its description replaced by `desc`.
    ///
    /// Consumes `self`; the result must be used or the change is lost.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = desc.into();
        self
    }
}
95
96#[derive(Debug, Clone)]
98pub struct NodeInfo {
99 pub name: String,
101 pub namespace: String,
103 pub short_name: String,
105 pub description: String,
107 pub version: String,
109 pub inputs: Vec<Port>,
111 pub outputs: Vec<Port>,
113 pub effectful: bool,
115 pub deterministic: bool,
117}
118
119impl NodeInfo {
120 pub fn new(namespace: impl Into<String>, name: impl Into<String>) -> Self {
122 let namespace = namespace.into();
123 let short_name = name.into();
124 let full_name = format!("{}::{}", namespace, short_name);
125
126 Self {
127 name: full_name,
128 namespace,
129 short_name,
130 description: String::new(),
131 version: "1.0.0".to_string(),
132 inputs: vec![Port::input("Any")],
133 outputs: vec![Port::output("Any"), Port::error()],
134 effectful: false,
135 deterministic: true,
136 }
137 }
138
139 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
141 self.description = desc.into();
142 self
143 }
144
145 pub fn with_version(mut self, version: impl Into<String>) -> Self {
147 self.version = version.into();
148 self
149 }
150
151 pub fn with_inputs(mut self, inputs: Vec<Port>) -> Self {
153 self.inputs = inputs;
154 self
155 }
156
157 pub fn with_outputs(mut self, outputs: Vec<Port>) -> Self {
159 self.outputs = outputs;
160 self
161 }
162
163 pub fn effectful(mut self) -> Self {
165 self.effectful = true;
166 self
167 }
168
169 pub fn non_deterministic(mut self) -> Self {
171 self.deterministic = false;
172 self
173 }
174
175 pub fn get_input(&self, name: &str) -> Option<&Port> {
177 self.inputs.iter().find(|p| p.name == name)
178 }
179
180 pub fn get_output(&self, name: &str) -> Option<&Port> {
182 self.outputs.iter().find(|p| p.name == name)
183 }
184}
185
186#[derive(Debug)]
188pub struct NodeOutput {
189 pub port: String,
191 pub data: RelPtr<()>,
193 pub schema_hash: u64,
195 pub error_message: Option<String>,
197}
198
199impl NodeOutput {
200 pub fn new<T>(port: impl Into<String>, data: RelPtr<T>) -> Self {
202 Self {
203 port: port.into(),
204 data: RelPtr::new(data.offset(), data.size()),
205 schema_hash: 0,
206 error_message: None,
207 }
208 }
209
210 pub fn out<T>(data: RelPtr<T>) -> Self {
212 Self::new("out", data)
213 }
214
215 pub fn error<T>(data: RelPtr<T>) -> Self {
217 Self::new("error", data)
218 }
219
220 pub fn error_with_message(message: impl Into<String>) -> Self {
225 Self {
226 port: "error".to_string(),
227 data: RelPtr::null(),
228 schema_hash: 0,
229 error_message: Some(message.into()),
230 }
231 }
232
233 pub fn on_true<T>(data: RelPtr<T>) -> Self {
235 Self::new("true", data)
236 }
237
238 pub fn on_false<T>(data: RelPtr<T>) -> Self {
240 Self::new("false", data)
241 }
242
243 pub fn with_schema_hash(mut self, hash: u64) -> Self {
245 self.schema_hash = hash;
246 self
247 }
248
249 pub fn has_error_message(&self) -> bool {
251 self.error_message.is_some()
252 }
253
254 pub fn get_error_message(&self) -> Option<&str> {
256 self.error_message.as_deref()
257 }
258
259 pub fn arena_location(&self) -> (crate::types::ArenaOffset, u32) {
264 (self.data.offset(), self.data.size())
265 }
266}
267
/// Pinned, boxed, `Send` future returned by [`Node::execute`].
///
/// It may borrow for the lifetime `'a` (in `Node::execute`, the node itself via
/// `&'a self`) and resolves to the produced [`NodeOutput`] or an error.
pub type NodeFuture<'a> = Pin<Box<dyn Future<Output = Result<NodeOutput>> + Send + 'a>>;
270
/// A unit of work that describes itself via [`NodeInfo`] and executes asynchronously.
///
/// Implementors must be `Send + Sync`.
pub trait Node: Send + Sync {
    /// Returns the metadata (name, version, ports, flags) describing this node.
    fn info(&self) -> NodeInfo;

    /// Executes the node with the given context and inputs.
    ///
    /// `inputs` maps names to type-erased input pointers.
    /// NOTE(review): presumably keyed by input port name (see `NodeInfo::inputs`) — confirm with the caller.
    /// The returned [`NodeFuture`] may borrow `self` for `'a` and resolves to
    /// the node's [`NodeOutput`] or an error.
    fn execute<'a>(&'a self, ctx: Context, inputs: HashMap<String, RelPtr<()>>) -> NodeFuture<'a>;

    /// Shutdown hook. The default implementation does nothing.
    /// NOTE(review): when the runtime invokes this is not visible here.
    fn shutdown(&self) {}

    /// Schema hash for this node's output.
    ///
    /// The default returns `0`, the same value `NodeOutput::new` assigns to
    /// `schema_hash` before `with_schema_hash` is applied.
    fn output_schema_hash(&self) -> u64 {
        0
    }
}
331
/// Builds boxed [`Node`] instances from YAML configuration values.
///
/// Implementors must be `Send + Sync`.
pub trait NodeFactory: Send + Sync {
    /// Identifier for the kind of node this factory builds.
    /// NOTE(review): presumably matched against a node type name in config — confirm with the registry.
    fn node_type(&self) -> &str;

    /// Creates a node from `config`, returning an error if construction fails.
    fn create(&self, config: &serde_yaml::Value) -> Result<Box<dyn Node>>;
}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn port_creation() {
        let input = Port::input("OrderInput@v1");
        assert_eq!(input.name, "in");
        assert_eq!(input.direction, PortDirection::Input);
        assert_eq!(input.schema, "OrderInput@v1");
        assert!(input.required);

        let output = Port::output("OrderOutput@v1");
        assert_eq!(output.name, "out");
        assert_eq!(output.direction, PortDirection::Output);
        assert!(!output.required);
    }

    #[test]
    fn optional_clears_required_flag() {
        // Inputs start out required, so this is the case where `optional()`
        // actually changes something (outputs are already non-required).
        let input = Port::input("OrderInput@v1").optional();
        assert!(!input.required);
    }

    #[test]
    fn named_port_required_follows_direction() {
        let input = Port::named("a", PortDirection::Input, "Any");
        assert!(input.required);
        assert_eq!(input.description, "");

        let output = Port::named("b", PortDirection::Output, "Any");
        assert!(!output.required);
    }

    #[test]
    fn error_port_defaults() {
        let err = Port::error();
        assert_eq!(err.name, "error");
        assert_eq!(err.direction, PortDirection::Output);
        assert_eq!(err.schema, "Error@v1");
        assert!(!err.required);
    }

    #[test]
    fn node_info_defaults() {
        let info = NodeInfo::new("std", "noop");
        assert_eq!(info.name, "std::noop");
        assert_eq!(info.version, "1.0.0");
        assert!(info.description.is_empty());
        assert!(!info.effectful);
        assert!(info.deterministic);
        assert!(info.get_input("in").is_some());
        assert!(info.get_output("out").is_some());
        assert!(info.get_output("error").is_some());
        assert!(info.get_input("missing").is_none());
    }

    #[test]
    fn node_info_creation() {
        let info = NodeInfo::new("std", "switch")
            .with_description("Conditional branching")
            .with_inputs(vec![Port::input("Any")])
            .with_outputs(vec![
                Port::named("true", PortDirection::Output, "Any"),
                Port::named("false", PortDirection::Output, "Any"),
                Port::error(),
            ]);

        assert_eq!(info.name, "std::switch");
        assert_eq!(info.namespace, "std");
        assert_eq!(info.short_name, "switch");
        assert_eq!(info.description, "Conditional branching");
        assert_eq!(info.inputs.len(), 1);
        assert_eq!(info.outputs.len(), 3);
        assert!(info.get_output("true").is_some());
        assert!(info.get_output("false").is_some());
        assert!(info.get_output("out").is_none());
    }

    #[test]
    fn node_info_flag_builders() {
        let info = NodeInfo::new("io", "write")
            .with_version("2.0.0")
            .effectful()
            .non_deterministic();

        assert_eq!(info.version, "2.0.0");
        assert!(info.effectful);
        assert!(!info.deterministic);
    }
}