// steer_core/app/domain/runtime/interpreter.rs
interpreter.rs1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use tokio::sync::mpsc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::api::provider::{CompletionResponse, StreamChunk};
9use crate::app::SystemContext;
10use crate::app::conversation::Message;
11use crate::app::domain::delta::{StreamDelta, ToolCallDelta};
12use crate::app::domain::types::{MessageId, OpId, SessionId, ToolCallId};
13use crate::config::model::ModelId;
14use crate::tools::{SessionMcpBackends, ToolExecutor};
15use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
16
/// Executes the runtime's side effects: model completions through the API
/// client and tool invocations through the tool executor. May optionally be
/// scoped to a session so tool execution uses the session-aware path.
#[derive(Clone)]
pub struct EffectInterpreter {
    // Client used for model completion calls (streaming).
    api_client: Arc<ApiClient>,
    // Executes tool calls requested by the model.
    tool_executor: Arc<ToolExecutor>,
    // When set, tool execution routes through the session-aware executor path.
    session_id: Option<SessionId>,
    // Optional per-session MCP backends, used as a `BackendResolver` for tool
    // execution and schema discovery.
    session_backends: Option<Arc<SessionMcpBackends>>,
}
24
/// Bundles the channel and identifiers needed to forward streaming deltas
/// (text / thinking / tool-call chunks) to an observer during a completion.
pub(crate) struct DeltaStreamContext {
    // Destination channel for emitted `StreamDelta`s.
    tx: mpsc::Sender<StreamDelta>,
    // The (operation id, message id) pair every emitted delta is tagged with.
    context: (OpId, MessageId),
}
29
30impl DeltaStreamContext {
31 pub(crate) fn new(tx: mpsc::Sender<StreamDelta>, context: (OpId, MessageId)) -> Self {
32 Self { tx, context }
33 }
34}
35
36impl EffectInterpreter {
37 pub fn new(api_client: Arc<ApiClient>, tool_executor: Arc<ToolExecutor>) -> Self {
38 Self {
39 api_client,
40 tool_executor,
41 session_id: None,
42 session_backends: None,
43 }
44 }
45
46 pub fn with_session(mut self, session_id: SessionId) -> Self {
47 self.session_id = Some(session_id);
48 self
49 }
50
51 pub fn with_session_backends(mut self, backends: Arc<SessionMcpBackends>) -> Self {
52 self.session_backends = Some(backends);
53 self
54 }
55
56 pub fn model_context_window_tokens(&self, model: &ModelId) -> Option<u32> {
57 self.api_client.model_context_window_tokens(model)
58 }
59
60 pub fn model_max_output_tokens(&self, model: &ModelId) -> Option<u32> {
61 self.api_client.model_max_output_tokens(model)
62 }
63
64 pub async fn call_model(
65 &self,
66 model: ModelId,
67 messages: Vec<Message>,
68 system_context: Option<SystemContext>,
69 tools: Vec<ToolSchema>,
70 cancel_token: CancellationToken,
71 ) -> Result<CompletionResponse, String> {
72 self.call_model_with_deltas(model, messages, system_context, tools, cancel_token, None)
73 .await
74 }
75
76 pub(crate) async fn call_model_with_deltas(
77 &self,
78 model: ModelId,
79 messages: Vec<Message>,
80 system_context: Option<SystemContext>,
81 tools: Vec<ToolSchema>,
82 cancel_token: CancellationToken,
83 delta_stream: Option<DeltaStreamContext>,
84 ) -> Result<CompletionResponse, String> {
85 let tools_option = if tools.is_empty() { None } else { Some(tools) };
86
87 let mut stream = self
88 .api_client
89 .stream_complete(
90 &model,
91 messages,
92 system_context,
93 tools_option,
94 None,
95 cancel_token,
96 )
97 .await
98 .map_err(|e| e.to_string())?;
99
100 let mut final_response = None;
101 while let Some(chunk) = stream.next().await {
102 match chunk {
103 StreamChunk::TextDelta(text) => {
104 if let Some(delta_stream) = &delta_stream {
105 let (op_id, message_id) = &delta_stream.context;
106 let delta = StreamDelta::TextChunk {
107 op_id: *op_id,
108 message_id: message_id.clone(),
109 delta: text,
110 };
111 let _ = delta_stream.tx.send(delta).await;
112 }
113 }
114 StreamChunk::ThinkingDelta(thinking) => {
115 if let Some(delta_stream) = &delta_stream {
116 let (op_id, message_id) = &delta_stream.context;
117 let delta = StreamDelta::ThinkingChunk {
118 op_id: *op_id,
119 message_id: message_id.clone(),
120 delta: thinking,
121 };
122 let _ = delta_stream.tx.send(delta).await;
123 }
124 }
125 StreamChunk::ToolUseInputDelta { id, delta } => {
126 if let Some(delta_stream) = &delta_stream {
127 let (op_id, message_id) = &delta_stream.context;
128 let delta = StreamDelta::ToolCallChunk {
129 op_id: *op_id,
130 message_id: message_id.clone(),
131 tool_call_id: ToolCallId::from_string(&id),
132 delta: ToolCallDelta::ArgumentChunk(delta),
133 };
134 let _ = delta_stream.tx.send(delta).await;
135 }
136 }
137 StreamChunk::Reset => {
138 if let Some(delta_stream) = &delta_stream {
139 let (op_id, message_id) = &delta_stream.context;
140 let delta = StreamDelta::Reset {
141 op_id: *op_id,
142 message_id: message_id.clone(),
143 };
144 let _ = delta_stream.tx.send(delta).await;
145 }
146 }
147 StreamChunk::MessageComplete(response) => {
148 final_response = Some(response);
149 }
150 StreamChunk::Error(err) => {
151 return Err(err.to_string());
152 }
153 StreamChunk::ToolUseStart { .. } | StreamChunk::ContentBlockStop { .. } => {}
154 }
155 }
156
157 final_response.ok_or_else(|| "Stream ended without MessageComplete".to_string())
158 }
159
160 pub async fn execute_tool(
161 &self,
162 tool_call: ToolCall,
163 invoking_model: Option<ModelId>,
164 cancel_token: CancellationToken,
165 ) -> Result<ToolResult, ToolError> {
166 let resolver = self
167 .session_backends
168 .as_ref()
169 .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
170
171 if let Some(session_id) = self.session_id {
172 self.tool_executor
173 .execute_tool_with_session_resolver(
174 &tool_call,
175 session_id,
176 invoking_model,
177 cancel_token,
178 resolver,
179 )
180 .await
181 } else {
182 self.tool_executor
183 .execute_tool_with_resolver(&tool_call, cancel_token, resolver)
184 .await
185 }
186 }
187
188 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
189 let resolver = self
190 .session_backends
191 .as_ref()
192 .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
193
194 self.tool_executor
195 .get_tool_schemas_with_resolver(resolver)
196 .await
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider, TokenUsage};
    use crate::app::conversation::AssistantContent;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::{ModelId, ModelParameters};
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;

    /// Provider double: ignores all inputs and always answers with a single
    /// "ok" text block plus a fixed token usage.
    #[derive(Clone)]
    struct StubProvider;

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            let content = vec![AssistantContent::Text {
                text: "ok".to_string(),
            }];
            Ok(CompletionResponse {
                content,
                usage: Some(TokenUsage::new(5, 7, 12)),
            })
        }
    }

    /// Builds an API client over empty registries and a bare tool executor.
    async fn create_test_deps() -> (Arc<ApiClient>, Arc<ToolExecutor>) {
        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            providers,
            models,
        ));

        let executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));

        (api_client, executor)
    }

    #[tokio::test]
    async fn call_model_preserves_completion_usage() {
        let (api_client, tool_executor) = create_test_deps().await;
        let provider_id = ProviderId("stub".to_string());
        api_client.insert_test_provider(provider_id.clone(), Arc::new(StubProvider));

        let model = ModelId::new(provider_id, "stub-model");
        let response = EffectInterpreter::new(api_client, tool_executor)
            .call_model(model, vec![], None, vec![], CancellationToken::new())
            .await
            .expect("model call should succeed");

        // Usage reported by the provider must survive the streaming adapter.
        assert_eq!(response.usage, Some(TokenUsage::new(5, 7, 12)));
        assert!(matches!(
            response.content.as_slice(),
            [AssistantContent::Text { text }] if text == "ok"
        ));
    }
}