Skip to main content

scud/backend/
direct.rs

//! Direct API backend.
//!
//! Wraps the existing `llm::agent::run_agent_loop()` behind the
//! [`AgentBackend`] trait for in-process LLM execution.
5
6use anyhow::Result;
7use async_trait::async_trait;
8
9use super::{AgentBackend, AgentHandle, AgentRequest};
10
11#[cfg(feature = "direct-api")]
12use {
13    super::{AgentEvent, AgentResult, AgentStatus, ToolCallRecord},
14    tokio::sync::mpsc,
15    tokio_util::sync::CancellationToken,
16};
17
/// Default `max_tokens` value handed to the agent loop for each request.
#[cfg(feature = "direct-api")]
const DEFAULT_MAX_TOKENS: u32 = 16_000;

/// Backend that calls LLM APIs directly (in-process).
///
/// Only available with the `direct-api` feature.
#[cfg(feature = "direct-api")]
#[derive(Debug, Clone)]
pub struct DirectApiBackend {
    /// Token budget forwarded to the agent loop on every `execute` call.
    max_tokens: u32,
}

#[cfg(feature = "direct-api")]
impl DirectApiBackend {
    /// Create a backend using the default token budget (`DEFAULT_MAX_TOKENS`, 16 000).
    pub fn new() -> Self {
        Self {
            max_tokens: DEFAULT_MAX_TOKENS,
        }
    }
}

// `new()` takes no arguments, so `Default` should agree with it (clippy `new_without_default`).
#[cfg(feature = "direct-api")]
impl Default for DirectApiBackend {
    fn default() -> Self {
        Self::new()
    }
}
32
#[cfg(feature = "direct-api")]
#[async_trait]
impl AgentBackend for DirectApiBackend {
    /// Start the in-process agent loop for `req` and return a handle streaming its events.
    ///
    /// Two tasks are spawned:
    /// * the agent task runs `llm::agent::run_agent_loop`, which emits `StreamEvent`s;
    /// * the bridge task converts those into `AgentEvent`s on the handle's channel and ends
    ///   with a single `AgentEvent::Complete` (unless the handle's receiver is already gone).
    ///
    /// Cancelling the handle's token stops the agent loop and yields a `Cancelled` result.
    /// Dropping the handle's receiver also stops the agent loop.
    ///
    /// # Errors
    ///
    /// Returns the error from `AgentProvider::from_provider_str` when it rejects
    /// `req.provider`.
    async fn execute(&self, req: AgentRequest) -> Result<AgentHandle> {
        use crate::commands::spawn::headless::events::{StreamEvent, StreamEventKind};
        use crate::llm::agent;
        use crate::llm::provider::AgentProvider;

        // Resolve the provider first so an invalid one fails before anything is spawned.
        let provider = if let Some(ref p) = req.provider {
            AgentProvider::from_provider_str(p)?
        } else {
            AgentProvider::Anthropic
        };

        let model = req.model.clone();
        let max_tokens = self.max_tokens;
        let prompt = req.prompt.clone();
        let working_dir = req.working_dir.clone();
        let system_prompt = req.system_prompt.clone();

        let (event_tx, rx) = mpsc::channel(1000);
        let cancel = CancellationToken::new();
        let bridge_cancel = cancel.clone();
        // The agent task watches a child token. It fires when the caller cancels the handle,
        // and the bridge can also fire it (when nobody is listening any more) without
        // flipping the caller-visible token.
        let loop_cancel = cancel.child_token();
        let bridge_loop_cancel = loop_cancel.clone();

        // Bridge: run_agent_loop emits StreamEvent, we convert to AgentEvent.
        let (stream_tx, mut stream_rx) = mpsc::channel::<StreamEvent>(1000);
        // Kept alive by the agent task until it exits. The bridge therefore only sees the
        // stream close once that task has finished (normally, by cancellation, or by panic).
        let stream_tx_err = stream_tx.clone();

        let agent_task = tokio::spawn(async move {
            let outcome = tokio::select! {
                // Dropping the in-flight run_agent_loop future stops further LLM/tool work.
                _ = loop_cancel.cancelled() => return,
                res = agent::run_agent_loop(
                    &prompt,
                    system_prompt.as_deref(),
                    &working_dir,
                    model.as_deref(),
                    max_tokens,
                    stream_tx,
                    &provider,
                ) => res,
            };
            if let Err(e) = outcome {
                // Alternate formatting includes the context chain when `e` is an
                // `anyhow::Error`; for other error types it matches plain Display.
                let _ = stream_tx_err
                    .send(StreamEvent::error(&format!("{e:#}")))
                    .await;
                let _ = stream_tx_err.send(StreamEvent::complete(false)).await;
            }
        });

        // Bridge task: StreamEvent -> AgentEvent
        tokio::spawn(async move {
            let mut text = String::new();
            let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
            // Message of the most recent Error event, used to explain a failed completion.
            let mut last_error: Option<String> = None;

            loop {
                tokio::select! {
                    // Poll cancellation first. A cancelled agent task drops its senders, and an
                    // unbiased select could observe the closed stream instead. That would
                    // report a cancelled run as Completed.
                    biased;

                    _ = bridge_cancel.cancelled() => {
                        let _ = event_tx
                            .send(completion_event(text, AgentStatus::Cancelled, tool_calls))
                            .await;
                        break;
                    }
                    event = stream_rx.recv() => {
                        let stream_event = match event {
                            Some(stream_event) => stream_event,
                            None => {
                                // Every sender is gone, so the agent task has exited. Do not
                                // report a panic in that task as a successful completion.
                                let status = match agent_task.await {
                                    Err(join_err) if join_err.is_panic() => {
                                        AgentStatus::Failed("Agent task panicked".into())
                                    }
                                    _ => AgentStatus::Completed,
                                };
                                let _ = event_tx
                                    .send(completion_event(text, status, tool_calls))
                                    .await;
                                break;
                            }
                        };

                        let agent_event = match &stream_event.kind {
                            StreamEventKind::TextDelta { text: delta } => {
                                text.push_str(delta);
                                AgentEvent::TextDelta(delta.clone())
                            }
                            StreamEventKind::ToolStart { tool_name, tool_id, .. } => {
                                tool_calls.push(ToolCallRecord {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                    output: String::new(),
                                });
                                AgentEvent::ToolCallStart {
                                    id: tool_id.clone(),
                                    name: tool_name.clone(),
                                }
                            }
                            StreamEventKind::ToolResult { tool_id, success, .. } => {
                                let outcome = if *success { "ok" } else { "error" };
                                if let Some(record) =
                                    tool_calls.iter_mut().find(|record| record.id == *tool_id)
                                {
                                    record.output = outcome.into();
                                }
                                AgentEvent::ToolCallEnd {
                                    id: tool_id.clone(),
                                    output: outcome.into(),
                                }
                            }
                            StreamEventKind::Complete { success } => {
                                let status = if *success {
                                    AgentStatus::Completed
                                } else {
                                    // Prefer the agent's own error message (for example, the
                                    // one sent above when run_agent_loop returns Err) over
                                    // a generic one.
                                    AgentStatus::Failed(
                                        last_error
                                            .take()
                                            .unwrap_or_else(|| "Agent reported failure".into()),
                                    )
                                };
                                let _ = event_tx
                                    .send(completion_event(text, status, tool_calls))
                                    .await;
                                break;
                            }
                            StreamEventKind::Error { message } => {
                                last_error = Some(message.clone());
                                AgentEvent::Error(message.clone())
                            }
                            StreamEventKind::SessionAssigned { .. } => continue,
                        };

                        if event_tx.send(agent_event).await.is_err() {
                            // The handle's receiver was dropped and nobody will observe this
                            // run, so stop the agent loop instead of letting it keep working.
                            bridge_loop_cancel.cancel();
                            break;
                        }
                    }
                }
            }
        });

        Ok(AgentHandle { events: rx, cancel })
    }
}

/// Build the terminal `AgentEvent::Complete` from the output accumulated by the bridge.
#[cfg(feature = "direct-api")]
fn completion_event(
    text: String,
    status: AgentStatus,
    tool_calls: Vec<ToolCallRecord>,
) -> AgentEvent {
    AgentEvent::Complete(AgentResult {
        text,
        status,
        tool_calls,
        usage: None,
    })
}
164
// Stub when the `direct-api` feature is not enabled.

/// Placeholder backend used when the crate is built without the `direct-api` feature.
///
/// Its `AgentBackend::execute` implementation always returns an error naming the missing
/// feature.
#[cfg(not(feature = "direct-api"))]
#[derive(Debug, Clone, Default)]
pub struct DirectApiBackend;

#[cfg(not(feature = "direct-api"))]
impl DirectApiBackend {
    /// Create the stub backend (equivalent to `DirectApiBackend::default()`).
    pub fn new() -> Self {
        Self
    }
}
175
176#[cfg(not(feature = "direct-api"))]
177#[async_trait]
178impl AgentBackend for DirectApiBackend {
179    async fn execute(&self, _req: AgentRequest) -> Result<AgentHandle> {
180        anyhow::bail!("Direct API backend requires the 'direct-api' feature to be enabled")
181    }
182}