1use crate::error::StreamKitError;
13use crate::node::{NodeFactory, ProcessorNode, ResourceKeyHasher};
14use crate::pins::{InputPin, OutputPin};
15use crate::resource_manager::{Resource, ResourceError, ResourceKey, ResourceManager};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use ts_rs::TS;
21
/// Asynchronous factory for a shared [`Resource`] used by a node type.
///
/// It receives the node's parameters as an owned `Option<serde_json::Value>` and returns a boxed,
/// `Send` future resolving to `Arc<dyn Resource>` or a [`ResourceError`]. The factory itself must
/// be `Send + Sync` so it can be stored in the registry and shared.
///
/// [`NodeRegistry::create_node_async`] passes it to `ResourceManager::get_or_create`, keyed by the
/// node name and the output of the node's `ResourceKeyHasher`.
pub type AsyncResourceFactory = Arc<
    dyn Fn(
            Option<serde_json::Value>,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Arc<dyn Resource>, ResourceError>> + Send>,
        > + Send
        + Sync,
>;
32
/// Serializable description of one registered node type, as returned by
/// [`NodeRegistry::definitions`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, TS)]
#[ts(export)]
pub struct NodeDefinition {
    /// Name under which the node type is registered.
    pub kind: String,
    /// Optional human-readable description; omitted from serialized output when `None` and
    /// treated as `None` when missing during deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema supplied at registration time, carried as a raw JSON value.
    pub param_schema: serde_json::Value,
    /// Input pins: taken from the registered `StaticPins` for static nodes, or read from a
    /// temporary instance built with no params for dynamic nodes.
    pub inputs: Vec<InputPin>,
    /// Output pins, sourced the same way as `inputs`.
    pub outputs: Vec<OutputPin>,
    /// Category labels supplied at registration time.
    pub categories: Vec<String>,
    /// Bidirectional flag supplied at registration time; defaults to `false` when missing
    /// during deserialization.
    #[serde(default)]
    pub bidirectional: bool,
}
51
52#[derive(Clone)]
54pub struct StaticPins {
55 pub inputs: Vec<InputPin>,
56 pub outputs: Vec<OutputPin>,
57}
58
/// Internal registry entry: everything needed to construct and describe one node type.
#[derive(Clone)]
pub(crate) struct NodeInfo {
    /// Synchronous constructor invoked by `create_node` / `create_node_async`.
    pub factory: NodeFactory,
    /// Parameter schema copied into `NodeDefinition::param_schema`.
    pub param_schema: serde_json::Value,
    /// Fixed pin layout; `None` means pins are discovered by calling `factory(None)`
    /// when definitions are requested.
    pub static_pins: Option<StaticPins>,
    /// Category labels copied into `NodeDefinition::categories`.
    pub categories: Vec<String>,
    /// Flag copied into `NodeDefinition::bidirectional`.
    pub bidirectional: bool,
    /// Optional description copied into `NodeDefinition::description`.
    pub description: Option<String>,
    /// Optional async factory for a resource obtained through the registry's
    /// `ResourceManager` before `factory` is called (async path only).
    pub resource_factory: Option<AsyncResourceFactory>,
    /// Hashes node params into the `ResourceKey` used with `resource_factory`; the resource
    /// path only runs when both this and `resource_factory` are `Some`.
    pub resource_key_hasher: Option<ResourceKeyHasher>,
}
74
75#[derive(Clone, Default)]
77pub struct NodeRegistry {
78 info: HashMap<String, NodeInfo>,
79 #[allow(clippy::type_complexity)]
81 resource_manager: Option<Arc<ResourceManager>>,
82}
83
84impl NodeRegistry {
85 pub fn new() -> Self {
87 Self::default()
88 }
89
90 pub fn with_resource_manager(resource_manager: Arc<ResourceManager>) -> Self {
92 Self { info: HashMap::new(), resource_manager: Some(resource_manager) }
93 }
94
95 pub fn set_resource_manager(&mut self, resource_manager: Arc<ResourceManager>) {
97 self.resource_manager = Some(resource_manager);
98 }
99
100 pub fn register_static<F>(
103 &mut self,
104 name: &str,
105 factory: F,
106 param_schema: serde_json::Value,
107 pins: StaticPins,
108 categories: Vec<String>,
109 bidirectional: bool,
110 ) where
111 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
112 + Send
113 + Sync
114 + 'static,
115 {
116 self.info.insert(
117 name.to_string(),
118 NodeInfo {
119 factory: Arc::new(factory),
120 param_schema,
121 static_pins: Some(pins),
122 categories,
123 bidirectional,
124 description: None,
125 resource_factory: None,
126 resource_key_hasher: None,
127 },
128 );
129 }
130
131 #[allow(clippy::too_many_arguments)]
133 pub fn register_static_with_description<F>(
134 &mut self,
135 name: &str,
136 factory: F,
137 param_schema: serde_json::Value,
138 pins: StaticPins,
139 categories: Vec<String>,
140 bidirectional: bool,
141 description: impl Into<String>,
142 ) where
143 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
144 + Send
145 + Sync
146 + 'static,
147 {
148 self.info.insert(
149 name.to_string(),
150 NodeInfo {
151 factory: Arc::new(factory),
152 param_schema,
153 static_pins: Some(pins),
154 categories,
155 bidirectional,
156 description: Some(description.into()),
157 resource_factory: None,
158 resource_key_hasher: None,
159 },
160 );
161 }
162
163 pub fn register_dynamic<F>(
167 &mut self,
168 name: &str,
169 factory: F,
170 param_schema: serde_json::Value,
171 categories: Vec<String>,
172 bidirectional: bool,
173 ) where
174 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
175 + Send
176 + Sync
177 + 'static,
178 {
179 self.info.insert(
180 name.to_string(),
181 NodeInfo {
182 factory: Arc::new(factory),
183 param_schema,
184 static_pins: None,
185 categories,
186 bidirectional,
187 description: None,
188 resource_factory: None,
189 resource_key_hasher: None,
190 },
191 );
192 }
193
194 pub fn register_dynamic_with_description<F>(
196 &mut self,
197 name: &str,
198 factory: F,
199 param_schema: serde_json::Value,
200 categories: Vec<String>,
201 bidirectional: bool,
202 description: impl Into<String>,
203 ) where
204 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
205 + Send
206 + Sync
207 + 'static,
208 {
209 self.info.insert(
210 name.to_string(),
211 NodeInfo {
212 factory: Arc::new(factory),
213 param_schema,
214 static_pins: None,
215 categories,
216 bidirectional,
217 description: Some(description.into()),
218 resource_factory: None,
219 resource_key_hasher: None,
220 },
221 );
222 }
223
224 #[allow(clippy::too_many_arguments)]
238 pub fn register_static_with_resource<F>(
239 &mut self,
240 name: &str,
241 factory: F,
242 resource_factory: AsyncResourceFactory,
243 resource_key_hasher: ResourceKeyHasher,
244 param_schema: serde_json::Value,
245 pins: StaticPins,
246 categories: Vec<String>,
247 bidirectional: bool,
248 ) where
249 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
250 + Send
251 + Sync
252 + 'static,
253 {
254 self.info.insert(
255 name.to_string(),
256 NodeInfo {
257 factory: Arc::new(factory),
258 param_schema,
259 static_pins: Some(pins),
260 categories,
261 bidirectional,
262 description: None,
263 resource_factory: Some(resource_factory),
264 resource_key_hasher: Some(resource_key_hasher),
265 },
266 );
267 }
268
269 #[allow(clippy::too_many_arguments)]
271 pub fn register_dynamic_with_resource<F>(
272 &mut self,
273 name: &str,
274 factory: F,
275 resource_factory: AsyncResourceFactory,
276 resource_key_hasher: ResourceKeyHasher,
277 param_schema: serde_json::Value,
278 categories: Vec<String>,
279 bidirectional: bool,
280 ) where
281 F: Fn(Option<&serde_json::Value>) -> Result<Box<dyn ProcessorNode>, StreamKitError>
282 + Send
283 + Sync
284 + 'static,
285 {
286 self.info.insert(
287 name.to_string(),
288 NodeInfo {
289 factory: Arc::new(factory),
290 param_schema,
291 static_pins: None,
292 categories,
293 bidirectional,
294 description: None,
295 resource_factory: Some(resource_factory),
296 resource_key_hasher: Some(resource_key_hasher),
297 },
298 );
299 }
300
301 pub fn create_node(
313 &self,
314 name: &str,
315 params: Option<&serde_json::Value>,
316 ) -> Result<Box<dyn ProcessorNode>, StreamKitError> {
317 self.info.get(name).map_or_else(
318 || Err(StreamKitError::Runtime(format!("Node type '{name}' not found in registry"))),
319 |info| (info.factory)(params),
320 )
321 }
322
323 pub async fn create_node_async(
333 &self,
334 name: &str,
335 params: Option<&serde_json::Value>,
336 ) -> Result<Box<dyn ProcessorNode>, StreamKitError> {
337 let info = self.info.get(name).ok_or_else(|| {
338 StreamKitError::Runtime(format!("Node type '{name}' not found in registry"))
339 })?;
340
341 if let (Some(resource_factory), Some(resource_key_hasher), Some(resource_manager)) =
343 (&info.resource_factory, &info.resource_key_hasher, &self.resource_manager)
344 {
345 let params_hash = resource_key_hasher(params);
347 let resource_key = ResourceKey::new(name, params_hash);
348
349 let params_owned = params.cloned();
351 let rf = resource_factory.clone();
352 let _resource = resource_manager
353 .get_or_create(resource_key, || (rf)(params_owned))
354 .await
355 .map_err(|e| {
356 StreamKitError::Runtime(format!(
357 "Resource initialization failed for '{name}': {e}"
358 ))
359 })?;
360
361 tracing::debug!("Resource loaded for node '{}', calling factory", name);
362 }
363
364 (info.factory)(params)
366 }
367
368 pub fn definitions(&self) -> Vec<NodeDefinition> {
370 let mut defs = Vec::new();
371 for (kind, info) in &self.info {
372 let (inputs, outputs) = match &info.static_pins {
373 Some(pins) => (pins.inputs.clone(), pins.outputs.clone()),
374 None => {
375 match (info.factory)(None) {
377 Ok(node_instance) => {
378 (node_instance.input_pins(), node_instance.output_pins())
379 },
380 Err(e) => {
381 tracing::error!(kind=%kind, error=%e, "Failed to create temporary node instance for dynamic node definition");
382 continue;
383 },
384 }
385 },
386 };
387
388 defs.push(NodeDefinition {
389 kind: kind.clone(),
390 description: info.description.clone(),
391 param_schema: info.param_schema.clone(),
392 inputs,
393 outputs,
394 categories: info.categories.clone(),
395 bidirectional: info.bidirectional,
396 });
397 }
398 defs
399 }
400
401 pub fn unregister(&mut self, name: &str) -> bool {
404 self.info.remove(name).is_some()
405 }
406
407 pub fn contains(&self, name: &str) -> bool {
409 self.info.contains_key(name)
410 }
411}