1#![deny(missing_docs)]
2use async_trait::async_trait;
10use layer0::effect::SignalPayload;
11use layer0::error::OrchError;
12use layer0::id::{OperatorId, WorkflowId};
13use layer0::middleware::{DispatchNext, DispatchStack};
14use layer0::operator::{Operator, OperatorInput, OperatorOutput};
15use layer0::orchestrator::{Orchestrator, QueryPayload};
16use serde_json::json;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20
21pub struct LocalOrch {
27 agents: HashMap<String, Arc<dyn Operator>>,
28 workflow_signals: RwLock<HashMap<String, Vec<SignalPayload>>>,
30 middleware: Option<DispatchStack>,
32}
33
34impl LocalOrch {
35 pub fn new() -> Self {
37 Self {
38 agents: HashMap::new(),
39 workflow_signals: RwLock::new(HashMap::new()),
40 middleware: None,
41 }
42 }
43
44 pub fn register(&mut self, id: OperatorId, op: Arc<dyn Operator>) {
46 self.agents.insert(id.to_string(), op);
47 }
48
49 pub async fn signal_count(&self, target: &WorkflowId) -> usize {
51 let workflows = self.workflow_signals.read().await;
52 workflows.get(target.as_str()).map(|v| v.len()).unwrap_or(0)
53 }
54
55 pub fn with_middleware(mut self, stack: DispatchStack) -> Self {
57 self.middleware = Some(stack);
58 self
59 }
60}
61
62impl Default for LocalOrch {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68struct OperatorDispatch<'a> {
70 agents: &'a HashMap<String, Arc<dyn Operator>>,
71}
72
73#[async_trait]
74impl DispatchNext for OperatorDispatch<'_> {
75 async fn dispatch(
76 &self,
77 operator: &OperatorId,
78 input: OperatorInput,
79 ) -> Result<OperatorOutput, OrchError> {
80 let op = self
81 .agents
82 .get(operator.as_str())
83 .ok_or_else(|| OrchError::OperatorNotFound(operator.to_string()))?;
84 op.execute(input).await.map_err(OrchError::OperatorError)
85 }
86}
87
88#[async_trait]
89impl Orchestrator for LocalOrch {
90 #[tracing::instrument(skip_all, fields(operator_id = %operator))]
91 async fn dispatch(
92 &self,
93 operator: &OperatorId,
94 input: OperatorInput,
95 ) -> Result<OperatorOutput, OrchError> {
96 let terminal = OperatorDispatch {
97 agents: &self.agents,
98 };
99
100 if let Some(ref stack) = self.middleware {
101 stack.dispatch_with(operator, input, &terminal).await
102 } else {
103 terminal.dispatch(operator, input).await
104 }
105 }
106
107 #[tracing::instrument(skip_all, fields(count = tasks.len()))]
108 async fn dispatch_many(
109 &self,
110 tasks: Vec<(OperatorId, OperatorInput)>,
111 ) -> Vec<Result<OperatorOutput, OrchError>> {
112 let mut handles = Vec::with_capacity(tasks.len());
113
114 for (operator_id, input) in tasks {
115 match self.agents.get(operator_id.as_str()) {
116 Some(op) => {
117 let op = Arc::clone(op);
118 handles.push(tokio::spawn(async move {
119 op.execute(input).await.map_err(OrchError::OperatorError)
120 }));
121 }
122 None => {
123 let name = operator_id.to_string();
124 handles.push(tokio::spawn(async move {
125 Err(OrchError::OperatorNotFound(name))
126 }));
127 }
128 }
129 }
130
131 let mut results = Vec::with_capacity(handles.len());
132 for handle in handles {
133 match handle.await {
134 Ok(result) => results.push(result),
135 Err(e) => results.push(Err(OrchError::DispatchFailed(e.to_string()))),
136 }
137 }
138
139 results
140 }
141
142 async fn signal(&self, target: &WorkflowId, signal: SignalPayload) -> Result<(), OrchError> {
143 let mut workflows = self.workflow_signals.write().await;
144 workflows
145 .entry(target.to_string())
146 .or_default()
147 .push(signal);
148 Ok(())
149 }
150
151 async fn query(
152 &self,
153 target: &WorkflowId,
154 _query: QueryPayload,
155 ) -> Result<serde_json::Value, OrchError> {
156 let workflows = self.workflow_signals.read().await;
157 let count = workflows.get(target.as_str()).map(|v| v.len()).unwrap_or(0);
158 Ok(json!({ "signals": count }))
159 }
160}