Skip to main content

skg_orch_local/
lib.rs

1#![deny(missing_docs)]
//! In-process implementation of layer0's `Orchestrator` trait.
//!
//! Dispatches to registered operators held in a `HashMap<String, Arc<dyn Operator>>`
//! keyed by the string form of each `OperatorId`. Concurrent dispatch
//! (`dispatch_many`) uses `tokio::spawn`. No durability — operators that fail
//! are not retried and state is not persisted. Workflow `signal` semantics and a
//! minimal `query` are implemented via an in-memory, per-workflow signal journal.
8
9use async_trait::async_trait;
10use layer0::effect::SignalPayload;
11use layer0::error::OrchError;
12use layer0::id::{OperatorId, WorkflowId};
13use layer0::middleware::{DispatchNext, DispatchStack};
14use layer0::operator::{Operator, OperatorInput, OperatorOutput};
15use layer0::orchestrator::{Orchestrator, QueryPayload};
16use serde_json::json;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::RwLock;
20
21/// In-process orchestrator that dispatches to registered operators.
22///
23/// Uses `Arc<dyn Operator>` for true concurrent dispatch via `tokio::spawn`.
24/// No durability, but tracks workflow signals in-memory for `signal`/`query`.
25/// Suitable for development, testing, and single-process deployments.
26pub struct LocalOrch {
27    agents: HashMap<String, Arc<dyn Operator>>,
28    // Per-workflow signal journal
29    workflow_signals: RwLock<HashMap<String, Vec<SignalPayload>>>,
30    /// Optional middleware stack for Pre/PostDispatch interception.
31    middleware: Option<DispatchStack>,
32}
33
34impl LocalOrch {
35    /// Create a new empty orchestrator.
36    pub fn new() -> Self {
37        Self {
38            agents: HashMap::new(),
39            workflow_signals: RwLock::new(HashMap::new()),
40            middleware: None,
41        }
42    }
43
44    /// Register an operator with the orchestrator.
45    pub fn register(&mut self, id: OperatorId, op: Arc<dyn Operator>) {
46        self.agents.insert(id.to_string(), op);
47    }
48
49    /// Return the number of recorded signals for a workflow.
50    pub async fn signal_count(&self, target: &WorkflowId) -> usize {
51        let workflows = self.workflow_signals.read().await;
52        workflows.get(target.as_str()).map(|v| v.len()).unwrap_or(0)
53    }
54
55    /// Attach a middleware stack for dispatch interception.
56    pub fn with_middleware(mut self, stack: DispatchStack) -> Self {
57        self.middleware = Some(stack);
58        self
59    }
60}
61
62impl Default for LocalOrch {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68/// Terminal dispatch: looks up the operator and calls `execute()`.
69struct OperatorDispatch<'a> {
70    agents: &'a HashMap<String, Arc<dyn Operator>>,
71}
72
73#[async_trait]
74impl DispatchNext for OperatorDispatch<'_> {
75    async fn dispatch(
76        &self,
77        operator: &OperatorId,
78        input: OperatorInput,
79    ) -> Result<OperatorOutput, OrchError> {
80        let op = self
81            .agents
82            .get(operator.as_str())
83            .ok_or_else(|| OrchError::OperatorNotFound(operator.to_string()))?;
84        op.execute(input).await.map_err(OrchError::OperatorError)
85    }
86}
87
88#[async_trait]
89impl Orchestrator for LocalOrch {
90    #[tracing::instrument(skip_all, fields(operator_id = %operator))]
91    async fn dispatch(
92        &self,
93        operator: &OperatorId,
94        input: OperatorInput,
95    ) -> Result<OperatorOutput, OrchError> {
96        let terminal = OperatorDispatch {
97            agents: &self.agents,
98        };
99
100        if let Some(ref stack) = self.middleware {
101            stack.dispatch_with(operator, input, &terminal).await
102        } else {
103            terminal.dispatch(operator, input).await
104        }
105    }
106
107    #[tracing::instrument(skip_all, fields(count = tasks.len()))]
108    async fn dispatch_many(
109        &self,
110        tasks: Vec<(OperatorId, OperatorInput)>,
111    ) -> Vec<Result<OperatorOutput, OrchError>> {
112        let mut handles = Vec::with_capacity(tasks.len());
113
114        for (operator_id, input) in tasks {
115            match self.agents.get(operator_id.as_str()) {
116                Some(op) => {
117                    let op = Arc::clone(op);
118                    handles.push(tokio::spawn(async move {
119                        op.execute(input).await.map_err(OrchError::OperatorError)
120                    }));
121                }
122                None => {
123                    let name = operator_id.to_string();
124                    handles.push(tokio::spawn(async move {
125                        Err(OrchError::OperatorNotFound(name))
126                    }));
127                }
128            }
129        }
130
131        let mut results = Vec::with_capacity(handles.len());
132        for handle in handles {
133            match handle.await {
134                Ok(result) => results.push(result),
135                Err(e) => results.push(Err(OrchError::DispatchFailed(e.to_string()))),
136            }
137        }
138
139        results
140    }
141
142    async fn signal(&self, target: &WorkflowId, signal: SignalPayload) -> Result<(), OrchError> {
143        let mut workflows = self.workflow_signals.write().await;
144        workflows
145            .entry(target.to_string())
146            .or_default()
147            .push(signal);
148        Ok(())
149    }
150
151    async fn query(
152        &self,
153        target: &WorkflowId,
154        _query: QueryPayload,
155    ) -> Result<serde_json::Value, OrchError> {
156        let workflows = self.workflow_signals.read().await;
157        let count = workflows.get(target.as_str()).map(|v| v.len()).unwrap_or(0);
158        Ok(json!({ "signals": count }))
159    }
160}