Skip to main content

walrus_daemon/service/
manager.rs

1//! Service lifecycle management — spawn, handshake, registry, shutdown.
2
3use crate::service::config::{ServiceConfig, ServiceKind};
4use anyhow::{Context, Result, bail};
5use std::{
6    collections::BTreeMap,
7    path::{Path, PathBuf},
8    sync::Arc,
9};
10use tokio::{
11    net::unix::{OwnedReadHalf, OwnedWriteHalf},
12    process::Child,
13    sync::Mutex,
14    time,
15};
16use wcore::{
17    ToolRegistry,
18    model::Tool,
19    protocol::{
20        PROTOCOL_VERSION,
21        codec::{read_message, write_message},
22        ext::{
23            Capability, ExtConfigure, ExtConfigured, ExtError, ExtHello, ExtReady,
24            ExtRegisterTools, ExtRequest, ExtResponse, ExtToolCall, ExtToolResult, ExtToolSchemas,
25            ToolsList, capability, ext_request, ext_response,
26        },
27    },
28};
29
30/// Handle to a connected extension service.
31pub struct ServiceHandle {
32    pub name: String,
33    pub capabilities: Vec<Capability>,
34    writer: Mutex<OwnedWriteHalf>,
35    reader: Mutex<OwnedReadHalf>,
36    /// Serializes request-response pairs to prevent interleaving.
37    rpc_lock: Mutex<()>,
38}
39
40impl ServiceHandle {
41    /// Send an extension request and read one response.
42    pub async fn request(&self, req: &ExtRequest) -> Result<ExtResponse> {
43        let _guard = self.rpc_lock.lock().await;
44        let mut w = self.writer.lock().await;
45        write_message(&mut *w, req).await.context("ext write")?;
46        drop(w);
47        let mut r = self.reader.lock().await;
48        let resp: ExtResponse = read_message(&mut *r).await.context("ext read")?;
49        Ok(resp)
50    }
51
52    /// Send a fire-and-forget extension request (no response expected).
53    pub async fn send(&self, req: &ExtRequest) -> Result<()> {
54        let _guard = self.rpc_lock.lock().await;
55        let mut w = self.writer.lock().await;
56        write_message(&mut *w, req).await.context("ext write")?;
57        Ok(())
58    }
59}
60
61/// Capability-indexed runtime state built during handshake.
62#[derive(Default)]
63pub struct ServiceRegistry {
64    /// Tool name → owning service handle.
65    pub tools: BTreeMap<String, Arc<ServiceHandle>>,
66    /// Service name → handle (for ServiceQuery routing).
67    pub query: BTreeMap<String, Arc<ServiceHandle>>,
68    /// Tool schemas collected from all extension services.
69    pub tool_schemas: Vec<Tool>,
70}
71
72impl ServiceRegistry {
73    /// Dispatch a tool call to the owning extension service.
74    /// Returns `None` if the tool is not in the registry.
75    pub async fn dispatch_tool(
76        &self,
77        name: &str,
78        args: &str,
79        agent: &str,
80        task_id: Option<u64>,
81    ) -> Option<String> {
82        let handle = self.tools.get(name)?;
83        let req = ExtRequest {
84            msg: Some(ext_request::Msg::ToolCall(ExtToolCall {
85                name: name.to_owned(),
86                args: args.to_owned(),
87                agent: agent.to_owned(),
88                task_id,
89            })),
90        };
91        Some(
92            match time::timeout(std::time::Duration::from_secs(30), handle.request(&req)).await {
93                Ok(Ok(resp)) => match resp.msg {
94                    Some(ext_response::Msg::ToolResult(ExtToolResult { result })) => result,
95                    Some(ext_response::Msg::Error(ExtError { message })) => {
96                        format!("service error: {message}")
97                    }
98                    other => format!("unexpected response: {other:?}"),
99                },
100                Ok(Err(e)) => format!("service unavailable: {name} ({e})"),
101                Err(_) => format!("service timeout: {name}"),
102            },
103        )
104    }
105
106    /// Register tool schemas into the tool registry.
107    pub async fn register_tools(&self, tools: &mut ToolRegistry) {
108        tools.insert_all(self.tool_schemas.clone());
109    }
110}
111
112/// Entry tracking a spawned service process.
113struct ServiceEntry {
114    config: ServiceConfig,
115    child: Option<Child>,
116    socket_path: PathBuf,
117}
118
119/// Manages the lifecycle of daemon child services.
120pub struct ServiceManager {
121    entries: BTreeMap<String, ServiceEntry>,
122    services_dir: PathBuf,
123    /// Daemon UDS socket path — passed to gateway services via `--daemon`.
124    daemon_socket: PathBuf,
125}
126
127const HANDSHAKE_TIMEOUT: time::Duration = time::Duration::from_secs(10);
128
129impl ServiceManager {
130    /// Create a new manager from config. Does not spawn anything yet.
131    ///
132    /// `daemon_socket` is the daemon's UDS path — forwarded to gateway services
133    /// so they can connect back.
134    pub fn new(
135        configs: &BTreeMap<String, ServiceConfig>,
136        config_dir: &Path,
137        daemon_socket: PathBuf,
138    ) -> Self {
139        let services_dir = config_dir.join("services");
140        let entries = configs
141            .iter()
142            .filter(|(_, c)| c.enabled)
143            .map(|(name, config)| {
144                let socket_path = services_dir.join(format!("{name}.sock"));
145                (
146                    name.clone(),
147                    ServiceEntry {
148                        config: config.clone(),
149                        child: None,
150                        socket_path,
151                    },
152                )
153            })
154            .collect();
155        Self {
156            entries,
157            services_dir,
158            daemon_socket,
159        }
160    }
161
162    /// Spawn all enabled services.
163    ///
164    /// Extension services get `--socket <path>` so they bind a UDS listener.
165    /// Gateway services get `--daemon <path>` and `--config <json>` so they
166    /// can connect back to the daemon.
167    pub async fn spawn_all(&mut self) -> Result<()> {
168        std::fs::create_dir_all(&self.services_dir).context("create services dir")?;
169        let logs_dir = &*wcore::paths::LOGS_DIR;
170        std::fs::create_dir_all(logs_dir).context("create logs dir")?;
171
172        for (name, entry) in &mut self.entries {
173            // Clean up stale socket.
174            if entry.socket_path.exists() {
175                let _ = std::fs::remove_file(&entry.socket_path);
176            }
177
178            // Resolve binary: try ~/.cargo/bin/<krate> first (launchd/systemd
179            // don't inherit the user's shell PATH), fall back to bare name.
180            let cargo_bin = std::env::var("HOME").ok().map(|h| {
181                PathBuf::from(h)
182                    .join(".cargo/bin")
183                    .join(&entry.config.krate)
184            });
185            let binary = match cargo_bin {
186                Some(ref p) if p.exists() => p.as_path(),
187                _ => Path::new(&entry.config.krate),
188            };
189            tracing::info!(
190                service = %name,
191                binary = %binary.display(),
192                kind = ?entry.config.kind,
193                "spawning service"
194            );
195            let mut cmd = tokio::process::Command::new(binary);
196            for (k, v) in &entry.config.env {
197                cmd.env(k, v);
198            }
199
200            // Forward RUST_LOG so child services inherit the daemon's log level.
201            if !entry.config.env.contains_key("RUST_LOG")
202                && let Ok(rust_log) = std::env::var("RUST_LOG")
203            {
204                cmd.env("RUST_LOG", rust_log);
205            }
206
207            // Redirect stdout/stderr to per-service log files.
208            let log_path = logs_dir.join(format!("{name}.log"));
209            let log_file = std::fs::File::create(&log_path)
210                .with_context(|| format!("create log file for '{name}'"))?;
211            cmd.stdout(log_file.try_clone()?);
212            cmd.stderr(log_file);
213
214            cmd.arg("serve");
215            match entry.config.kind {
216                ServiceKind::Extension => {
217                    cmd.arg("--socket").arg(&entry.socket_path);
218                }
219                ServiceKind::Gateway => {
220                    cmd.arg("--daemon").arg(&self.daemon_socket);
221                    let config_json = serde_json::to_string(&entry.config.config)
222                        .unwrap_or_else(|_| "{}".to_owned());
223                    cmd.arg("--config").arg(config_json);
224                }
225            }
226
227            cmd.kill_on_drop(true);
228            let child = cmd.spawn().with_context(|| {
229                format!("spawn service '{name}' (binary: {})", binary.display())
230            })?;
231            tracing::info!(service = %name, pid = child.id(), log = %log_path.display(), "spawned service");
232            entry.child = Some(child);
233        }
234
235        Ok(())
236    }
237
238    /// Connect to all extension services and perform the handshake.
239    /// Returns a `ServiceRegistry` with tool and query mappings.
240    pub async fn handshake_all(&self) -> ServiceRegistry {
241        let mut registry = ServiceRegistry::default();
242
243        for (name, entry) in &self.entries {
244            if !matches!(entry.config.kind, ServiceKind::Extension) {
245                continue;
246            }
247
248            match self
249                .handshake_one(name, &entry.socket_path, &entry.config.config)
250                .await
251            {
252                Ok((handle, schemas)) => {
253                    let handle = Arc::new(handle);
254                    Self::register(&mut registry, &handle);
255                    tracing::info!(
256                        service = %name,
257                        tools = schemas.len(),
258                        "extension registered"
259                    );
260                    registry.tool_schemas.extend(schemas);
261                }
262                Err(e) => {
263                    tracing::warn!(service = %name, error = %e, "extension handshake failed, skipping");
264                }
265            }
266        }
267
268        registry
269    }
270
271    /// Perform handshake with a single extension service.
272    /// Returns the handle and its declared tool schemas.
273    async fn handshake_one(
274        &self,
275        name: &str,
276        socket_path: &Path,
277        config: &serde_json::Value,
278    ) -> Result<(ServiceHandle, Vec<Tool>)> {
279        // Wait for socket file to appear (service may need startup time).
280        let deadline = time::Instant::now() + HANDSHAKE_TIMEOUT;
281        loop {
282            if socket_path.exists() {
283                break;
284            }
285            if time::Instant::now() >= deadline {
286                bail!(
287                    "socket not found after {}s: {}",
288                    HANDSHAKE_TIMEOUT.as_secs(),
289                    socket_path.display()
290                );
291            }
292            time::sleep(time::Duration::from_millis(50)).await;
293        }
294
295        let stream = time::timeout(
296            HANDSHAKE_TIMEOUT,
297            tokio::net::UnixStream::connect(socket_path),
298        )
299        .await
300        .context("connect timeout")?
301        .context("connect")?;
302
303        let (read_half, write_half) = stream.into_split();
304        let writer = Mutex::new(write_half);
305        let reader = Mutex::new(read_half);
306
307        // Hello → Ready
308        let hello = ExtRequest {
309            msg: Some(ext_request::Msg::Hello(ExtHello {
310                version: PROTOCOL_VERSION.to_owned(),
311            })),
312        };
313        {
314            let mut w = writer.lock().await;
315            write_message(&mut *w, &hello)
316                .await
317                .context("write Hello")?;
318        }
319        let ready: ExtResponse = {
320            let mut r = reader.lock().await;
321            time::timeout(HANDSHAKE_TIMEOUT, read_message(&mut *r))
322                .await
323                .context("Ready timeout")?
324                .context("read Ready")?
325        };
326        let (service, capabilities) = match ready.msg {
327            Some(ext_response::Msg::Ready(ExtReady {
328                service,
329                capabilities,
330                ..
331            })) => (service, capabilities),
332            Some(ext_response::Msg::Error(ExtError { message })) => {
333                bail!("service error: {message}")
334            }
335            other => bail!("unexpected response to Hello: {other:?}"),
336        };
337        tracing::debug!(service = %service, "handshake Hello/Ready complete");
338
339        let handle = ServiceHandle {
340            name: service,
341            capabilities,
342            writer,
343            reader,
344            rpc_lock: Mutex::new(()),
345        };
346
347        // Configure → Configured
348        let config_json = serde_json::to_string(config).context("serialize service config")?;
349        let configure_req = ExtRequest {
350            msg: Some(ext_request::Msg::Configure(ExtConfigure {
351                config: config_json,
352            })),
353        };
354        let configure_resp = time::timeout(HANDSHAKE_TIMEOUT, handle.request(&configure_req))
355            .await
356            .context("Configure timeout")?
357            .context("Configure")?;
358        match configure_resp.msg {
359            Some(ext_response::Msg::Configured(ExtConfigured {})) => {}
360            Some(ext_response::Msg::Error(ExtError { message })) => {
361                bail!("Configure error: {message}")
362            }
363            other => bail!("unexpected response to Configure: {other:?}"),
364        }
365        tracing::debug!(service = %name, "handshake Configure/Configured complete");
366
367        // RegisterTools → ToolSchemas
368        let register_tools_req = ExtRequest {
369            msg: Some(ext_request::Msg::RegisterTools(ExtRegisterTools {})),
370        };
371        let resp = time::timeout(HANDSHAKE_TIMEOUT, handle.request(&register_tools_req))
372            .await
373            .context("RegisterTools timeout")?
374            .context("RegisterTools")?;
375        let tool_defs = match resp.msg {
376            Some(ext_response::Msg::ToolSchemas(ExtToolSchemas { tools })) => tools,
377            Some(ext_response::Msg::Error(ExtError { message })) => {
378                bail!("RegisterTools error: {message}")
379            }
380            other => bail!("unexpected response to RegisterTools: {other:?}"),
381        };
382        tracing::debug!(service = %name, tools = tool_defs.len(), "handshake RegisterTools/ToolSchemas complete");
383
384        // Convert ToolDef (proto) → Tool (domain).
385        let tools: Vec<Tool> = tool_defs
386            .into_iter()
387            .map(|td| Tool {
388                name: td.name.to_string(),
389                description: td.description.to_string(),
390                parameters: serde_json::from_slice(&td.parameters).unwrap_or_else(|_| true.into()),
391                strict: td.strict,
392            })
393            .collect();
394
395        Ok((handle, tools))
396    }
397
398    /// Populate the registry from a service handle's capabilities.
399    fn register(registry: &mut ServiceRegistry, handle: &Arc<ServiceHandle>) {
400        for cap in &handle.capabilities {
401            match &cap.cap {
402                Some(capability::Cap::Tools(ToolsList { names })) => {
403                    for tool_name in names {
404                        registry.tools.insert(tool_name.clone(), Arc::clone(handle));
405                    }
406                }
407                Some(capability::Cap::Query(_)) => {
408                    registry
409                        .query
410                        .insert(handle.name.to_string(), Arc::clone(handle));
411                }
412                _ => {}
413            }
414        }
415    }
416
417    /// Graceful shutdown of all services. Signals each child to stop,
418    /// waits up to 5s, then force-kills stragglers.
419    pub async fn shutdown_all(&mut self) {
420        // Signal all children to stop.
421        for (name, entry) in &mut self.entries {
422            if let Some(ref mut child) = entry.child {
423                tracing::debug!(service = %name, pid = child.id(), "stopping service");
424                let _ = child.start_kill();
425            }
426        }
427
428        // Wait for exit, force-kill on timeout.
429        for (name, entry) in &mut self.entries {
430            if let Some(ref mut child) = entry.child {
431                match time::timeout(time::Duration::from_secs(5), child.wait()).await {
432                    Ok(Ok(status)) => {
433                        tracing::debug!(service = %name, %status, "service exited");
434                    }
435                    Ok(Err(e)) => {
436                        tracing::warn!(service = %name, error = %e, "error waiting for service");
437                    }
438                    Err(_) => {
439                        tracing::warn!(service = %name, "service did not exit in 5s, killing");
440                        let _ = child.kill().await;
441                    }
442                }
443            }
444            let _ = std::fs::remove_file(&entry.socket_path);
445        }
446    }
447}