1use crate::protocol::*;
4use somatize_core::cache::CacheStore;
5use somatize_core::event::Event;
6use somatize_core::filter::Filter;
7use somatize_runtime::{Context, EventBus, FilterLibrary, MemoryCache, execute};
8use std::sync::Arc;
9use std::time::Instant;
10
/// Executes [`SerializedPlan`]s against locally registered filters and reports
/// the outcome as a [`PlanResult`].
pub struct Worker {
    /// Identifier sent in the registration message; also compared against
    /// `RemoteTarget::WorkerId` when matching targets.
    pub id: WorkerId,
    /// Resources and tags advertised at registration; `tags` are compared
    /// against `RemoteTarget::Tag` when matching targets.
    pub capabilities: Capabilities,
    /// Bus handed to each run's `Context`; receivers come from `Worker::subscribe`.
    event_bus: Arc<EventBus>,
    /// Cache store passed to the runtime `execute` call for every plan.
    cache: Arc<dyn CacheStore>,
    /// Filters registered by node id and made available to executed plans.
    filters: FilterLibrary,
}
19
20impl Worker {
21 pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
22 Self {
23 id: id.into(),
24 capabilities,
25 event_bus: Arc::new(EventBus::new(256)),
26 cache: Arc::new(MemoryCache::default()),
27 filters: FilterLibrary::new(),
28 }
29 }
30
31 pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
33 self.cache = cache;
34 self
35 }
36
37 pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
39 self.filters.register(node_id, filter);
40 }
41
42 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
44 self.event_bus.subscribe()
45 }
46
47 pub fn registration_message(&self) -> WorkerToCoordinator {
49 WorkerToCoordinator::Register {
50 worker_id: self.id.clone(),
51 capabilities: self.capabilities.clone(),
52 }
53 }
54
55 pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
57 let start = Instant::now();
58
59 let mut ctx = Context::new(
60 self.event_bus.clone(),
61 format!("worker_run_{}", plan.plan_id),
62 );
63
64 if let Some(input_source) = &plan.input {
66 use crate::protocol::InputSource;
67 let input_value = match input_source {
68 InputSource::Inline { value } => value.clone(),
69 InputSource::Reference { data_ref } => {
70 if let Some(store) = &ctx.data_store {
72 store
73 .get(data_ref)
74 .unwrap_or(somatize_core::value::Value::Empty)
75 } else {
76 tracing::warn!("DataRef input but no DataStore configured on worker");
77 somatize_core::value::Value::Empty
78 }
79 }
80 };
81 ctx.set("input", input_value);
82 }
83
84 match execute(&plan.plan, &mut ctx, &self.filters, self.cache.as_ref()) {
85 Ok(()) => {
86 let output = ctx
88 .execution_order
89 .last()
90 .and_then(|id| ctx.get(id))
91 .cloned()
92 .unwrap_or(somatize_core::value::Value::Empty);
93
94 PlanResult::Success {
95 output,
96 duration_ms: start.elapsed().as_millis() as u64,
97 }
98 }
99 Err(e) => PlanResult::Failed {
100 error: e.to_string(),
101 duration_ms: start.elapsed().as_millis() as u64,
102 },
103 }
104 }
105
106 pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
108 match target {
109 somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
110 somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{Distribution, FilterKind, FilterMeta, StreamMode};
    use somatize_core::value::Value;

    /// Stateless test filter: doubles every element of a tensor input and
    /// passes any other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            // Nothing to learn; the doubler carries no state.
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            let result = if let Value::Tensor { values, shape } = x {
                let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                Value::tensor(doubled, shape.clone())
            } else {
                x.clone()
            };
            Ok(result)
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }
    }

    /// Worker with id "test_worker", 4 CPU cores, and tags "cpu" and "test".
    fn make_worker() -> Worker {
        let capabilities = Capabilities {
            cpu_cores: 4,
            ram_bytes: 8_000_000_000,
            gpus: vec![],
            python_envs: vec![],
            tags: vec!["cpu".into(), "test".into()],
        };
        Worker::new("test_worker", capabilities)
    }

    /// A single `Execute` plan step for `node_id`.
    fn step(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.into(),
        }
    }

    /// Wraps `plan` with empty metadata and an optional inline input.
    fn serialized(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            metadata: serde_json::json!({}),
        }
    }

    #[test]
    fn worker_registration() {
        match make_worker().registration_message() {
            WorkerToCoordinator::Register {
                worker_id,
                capabilities,
            } => {
                assert_eq!(worker_id, "test_worker");
                assert_eq!(capabilities.cpu_cores, 4);
            }
            _ => panic!("wrong message type"),
        }
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));

        let plan = serialized(
            "p_001",
            step("doubler"),
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success {
                output,
                duration_ms,
            } => {
                let (data, _) = output.as_tensor().unwrap();
                assert_eq!(data, &[2.0, 4.0, 6.0]);
                assert!(duration_ms < 1000);
            }
            result => panic!("expected success, got: {result:?}"),
        }
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        let plan = serialized("p_002", step("nonexistent"), None);

        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        assert!(
            worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "test_worker".into()
            ))
        );
        assert!(
            !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "other".into()
            ))
        );
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
        assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        for node in ["d1", "d2"] {
            worker.register_filter(node, Box::new(TestDoubler));
        }

        let plan = serialized(
            "p_003",
            ExecutionPlan::Sequence(vec![step("d1"), step("d2")]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );

        match worker.execute_plan(&plan) {
            PlanResult::Success { output, .. } => {
                let (data, _) = output.as_tensor().unwrap();
                // 5.0 doubled twice.
                assert_eq!(data, &[20.0]);
            }
            _ => panic!("expected success"),
        }
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();

        let plan = serialized(
            "p_004",
            step("doubler"),
            Some(Value::tensor(vec![1.0], vec![1])),
        );

        worker.execute_plan(&plan);

        // Drain everything published so far, stopping at the first failed receive.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(
            events
                .iter()
                .any(|e| matches!(e, Event::NodeStarted { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, Event::NodeCompleted { .. }))
        );
    }
}