steer_core/app/domain/runtime/
interpreter.rs1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use tokio::sync::mpsc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::api::provider::{CompletionResponse, StreamChunk};
9use crate::app::SystemContext;
10use crate::app::conversation::Message;
11use crate::app::domain::delta::{StreamDelta, ToolCallDelta};
12use crate::app::domain::types::{MessageId, OpId, SessionId, ToolCallId};
13use crate::config::model::ModelId;
14use crate::tools::{SessionMcpBackends, ToolExecutor};
15use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
16
/// Executes the side-effecting steps of a runtime operation: model
/// completions through the API client and tool invocations through the
/// tool executor. Cheap to clone (all state is behind `Arc`s).
#[derive(Clone)]
pub struct EffectInterpreter {
    api_client: Arc<ApiClient>,
    tool_executor: Arc<ToolExecutor>,
    // Optional session scoping, set via `with_session` / `with_session_backends`;
    // both default to `None` in `new`.
    session_id: Option<SessionId>,
    session_backends: Option<Arc<SessionMcpBackends>>,
}
24
/// Destination channel for streaming deltas, plus the (operation id,
/// message id) pair that every forwarded delta is tagged with.
pub(crate) struct DeltaStreamContext {
    tx: mpsc::Sender<StreamDelta>,
    context: (OpId, MessageId),
}
29
30impl DeltaStreamContext {
31 pub(crate) fn new(tx: mpsc::Sender<StreamDelta>, context: (OpId, MessageId)) -> Self {
32 Self { tx, context }
33 }
34}
35
36impl EffectInterpreter {
37 pub fn new(api_client: Arc<ApiClient>, tool_executor: Arc<ToolExecutor>) -> Self {
38 Self {
39 api_client,
40 tool_executor,
41 session_id: None,
42 session_backends: None,
43 }
44 }
45
46 pub fn with_session(mut self, session_id: SessionId) -> Self {
47 self.session_id = Some(session_id);
48 self
49 }
50
51 pub fn with_session_backends(mut self, backends: Arc<SessionMcpBackends>) -> Self {
52 self.session_backends = Some(backends);
53 self
54 }
55
56 pub fn model_context_window_tokens(&self, model: &ModelId) -> Option<u32> {
57 self.api_client.model_context_window_tokens(model)
58 }
59
60 pub async fn call_model(
61 &self,
62 model: ModelId,
63 messages: Vec<Message>,
64 system_context: Option<SystemContext>,
65 tools: Vec<ToolSchema>,
66 cancel_token: CancellationToken,
67 ) -> Result<CompletionResponse, String> {
68 self.call_model_with_deltas(model, messages, system_context, tools, cancel_token, None)
69 .await
70 }
71
72 pub(crate) async fn call_model_with_deltas(
73 &self,
74 model: ModelId,
75 messages: Vec<Message>,
76 system_context: Option<SystemContext>,
77 tools: Vec<ToolSchema>,
78 cancel_token: CancellationToken,
79 delta_stream: Option<DeltaStreamContext>,
80 ) -> Result<CompletionResponse, String> {
81 let tools_option = if tools.is_empty() { None } else { Some(tools) };
82
83 let mut stream = self
84 .api_client
85 .stream_complete(
86 &model,
87 messages,
88 system_context,
89 tools_option,
90 None,
91 cancel_token,
92 )
93 .await
94 .map_err(|e| e.to_string())?;
95
96 let mut final_response = None;
97 while let Some(chunk) = stream.next().await {
98 match chunk {
99 StreamChunk::TextDelta(text) => {
100 if let Some(delta_stream) = &delta_stream {
101 let (op_id, message_id) = &delta_stream.context;
102 let delta = StreamDelta::TextChunk {
103 op_id: *op_id,
104 message_id: message_id.clone(),
105 delta: text,
106 };
107 let _ = delta_stream.tx.send(delta).await;
108 }
109 }
110 StreamChunk::ThinkingDelta(thinking) => {
111 if let Some(delta_stream) = &delta_stream {
112 let (op_id, message_id) = &delta_stream.context;
113 let delta = StreamDelta::ThinkingChunk {
114 op_id: *op_id,
115 message_id: message_id.clone(),
116 delta: thinking,
117 };
118 let _ = delta_stream.tx.send(delta).await;
119 }
120 }
121 StreamChunk::ToolUseInputDelta { id, delta } => {
122 if let Some(delta_stream) = &delta_stream {
123 let (op_id, message_id) = &delta_stream.context;
124 let delta = StreamDelta::ToolCallChunk {
125 op_id: *op_id,
126 message_id: message_id.clone(),
127 tool_call_id: ToolCallId::from_string(&id),
128 delta: ToolCallDelta::ArgumentChunk(delta),
129 };
130 let _ = delta_stream.tx.send(delta).await;
131 }
132 }
133 StreamChunk::MessageComplete(response) => {
134 final_response = Some(response);
135 }
136 StreamChunk::Error(err) => {
137 return Err(err.to_string());
138 }
139 StreamChunk::ToolUseStart { .. } | StreamChunk::ContentBlockStop { .. } => {}
140 }
141 }
142
143 final_response.ok_or_else(|| "Stream ended without MessageComplete".to_string())
144 }
145
146 pub async fn execute_tool(
147 &self,
148 tool_call: ToolCall,
149 cancel_token: CancellationToken,
150 ) -> Result<ToolResult, ToolError> {
151 let resolver = self
152 .session_backends
153 .as_ref()
154 .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
155
156 if let Some(session_id) = self.session_id {
157 self.tool_executor
158 .execute_tool_with_session_resolver(&tool_call, session_id, cancel_token, resolver)
159 .await
160 } else {
161 self.tool_executor
162 .execute_tool_with_resolver(&tool_call, cancel_token, resolver)
163 .await
164 }
165 }
166
167 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
168 let resolver = self
169 .session_backends
170 .as_ref()
171 .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
172
173 self.tool_executor
174 .get_tool_schemas_with_resolver(resolver)
175 .await
176 }
177}
178
179#[cfg(test)]
180mod tests {
181 use super::*;
182 use crate::api::error::ApiError;
183 use crate::api::provider::{CompletionResponse, Provider, TokenUsage};
184 use crate::app::conversation::AssistantContent;
185 use crate::app::validation::ValidatorRegistry;
186 use crate::auth::ProviderRegistry;
187 use crate::config::model::{ModelId, ModelParameters};
188 use crate::config::provider::ProviderId;
189 use crate::model_registry::ModelRegistry;
190 use crate::tools::BackendRegistry;
191 use async_trait::async_trait;
192
    /// Provider stub that returns a canned completion, letting tests
    /// exercise the interpreter without any network access.
    #[derive(Clone)]
    struct StubProvider;
195
196 #[async_trait]
197 impl Provider for StubProvider {
198 fn name(&self) -> &'static str {
199 "stub"
200 }
201
202 async fn complete(
203 &self,
204 _model_id: &ModelId,
205 _messages: Vec<Message>,
206 _system: Option<SystemContext>,
207 _tools: Option<Vec<ToolSchema>>,
208 _call_options: Option<ModelParameters>,
209 _token: CancellationToken,
210 ) -> Result<CompletionResponse, ApiError> {
211 Ok(CompletionResponse {
212 content: vec![AssistantContent::Text {
213 text: "ok".to_string(),
214 }],
215 usage: Some(TokenUsage::new(5, 7, 12)),
216 })
217 }
218 }
219
220 async fn create_test_deps() -> (Arc<ApiClient>, Arc<ToolExecutor>) {
221 let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
222 let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
223 let api_client = Arc::new(ApiClient::new_with_deps(
224 crate::test_utils::test_llm_config_provider().unwrap(),
225 provider_registry,
226 model_registry,
227 ));
228
229 let tool_executor = Arc::new(ToolExecutor::with_components(
230 Arc::new(BackendRegistry::new()),
231 Arc::new(ValidatorRegistry::new()),
232 ));
233
234 (api_client, tool_executor)
235 }
236
237 #[tokio::test]
238 async fn call_model_preserves_completion_usage() {
239 let (api_client, tool_executor) = create_test_deps().await;
240 let provider_id = ProviderId("stub".to_string());
241 api_client.insert_test_provider(provider_id.clone(), Arc::new(StubProvider));
242
243 let interpreter = EffectInterpreter::new(api_client, tool_executor);
244 let result = interpreter
245 .call_model(
246 ModelId::new(provider_id, "stub-model"),
247 vec![],
248 None,
249 vec![],
250 CancellationToken::new(),
251 )
252 .await
253 .expect("model call should succeed");
254
255 assert_eq!(result.usage, Some(TokenUsage::new(5, 7, 12)));
256 assert!(matches!(
257 result.content.as_slice(),
258 [AssistantContent::Text { text }] if text == "ok"
259 ));
260 }
261}