// steer_core/app/domain/runtime/interpreter.rs
1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use tokio::sync::mpsc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::api::provider::{CompletionResponse, StreamChunk};
9use crate::app::SystemContext;
10use crate::app::conversation::Message;
11use crate::app::domain::delta::{StreamDelta, ToolCallDelta};
12use crate::app::domain::types::{MessageId, OpId, SessionId, ToolCallId};
13use crate::config::model::ModelId;
14use crate::tools::{SessionMcpBackends, ToolExecutor};
15use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
16
/// Executes the side-effecting operations requested by the runtime:
/// streamed model completions and tool invocations.
///
/// Cheap to clone: all heavy state lives behind `Arc`s.
#[derive(Clone)]
pub struct EffectInterpreter {
    // Shared client used for completions and model metadata lookups.
    api_client: Arc<ApiClient>,
    // Executes tool calls, optionally consulting the MCP backend resolver.
    tool_executor: Arc<ToolExecutor>,
    // When set, tool execution goes through the session-aware path.
    session_id: Option<SessionId>,
    // Per-session MCP backends, cast to a `BackendResolver` when resolving tools.
    session_backends: Option<Arc<SessionMcpBackends>>,
}
24
/// Destination for incremental streaming output: a channel plus the
/// (operation, message) identifiers every emitted delta is tagged with.
pub(crate) struct DeltaStreamContext {
    // Channel that receives each `StreamDelta` as it is produced.
    tx: mpsc::Sender<StreamDelta>,
    // (op_id, message_id) pair copied/cloned into every delta.
    context: (OpId, MessageId),
}
29
30impl DeltaStreamContext {
31    pub(crate) fn new(tx: mpsc::Sender<StreamDelta>, context: (OpId, MessageId)) -> Self {
32        Self { tx, context }
33    }
34}
35
36impl EffectInterpreter {
37    pub fn new(api_client: Arc<ApiClient>, tool_executor: Arc<ToolExecutor>) -> Self {
38        Self {
39            api_client,
40            tool_executor,
41            session_id: None,
42            session_backends: None,
43        }
44    }
45
46    pub fn with_session(mut self, session_id: SessionId) -> Self {
47        self.session_id = Some(session_id);
48        self
49    }
50
51    pub fn with_session_backends(mut self, backends: Arc<SessionMcpBackends>) -> Self {
52        self.session_backends = Some(backends);
53        self
54    }
55
56    pub fn model_context_window_tokens(&self, model: &ModelId) -> Option<u32> {
57        self.api_client.model_context_window_tokens(model)
58    }
59
60    pub fn model_max_output_tokens(&self, model: &ModelId) -> Option<u32> {
61        self.api_client.model_max_output_tokens(model)
62    }
63
64    pub async fn call_model(
65        &self,
66        model: ModelId,
67        messages: Vec<Message>,
68        system_context: Option<SystemContext>,
69        tools: Vec<ToolSchema>,
70        cancel_token: CancellationToken,
71    ) -> Result<CompletionResponse, String> {
72        self.call_model_with_deltas(model, messages, system_context, tools, cancel_token, None)
73            .await
74    }
75
76    pub(crate) async fn call_model_with_deltas(
77        &self,
78        model: ModelId,
79        messages: Vec<Message>,
80        system_context: Option<SystemContext>,
81        tools: Vec<ToolSchema>,
82        cancel_token: CancellationToken,
83        delta_stream: Option<DeltaStreamContext>,
84    ) -> Result<CompletionResponse, String> {
85        let tools_option = if tools.is_empty() { None } else { Some(tools) };
86
87        let mut stream = self
88            .api_client
89            .stream_complete(
90                &model,
91                messages,
92                system_context,
93                tools_option,
94                None,
95                cancel_token,
96            )
97            .await
98            .map_err(|e| e.to_string())?;
99
100        let mut final_response = None;
101        while let Some(chunk) = stream.next().await {
102            match chunk {
103                StreamChunk::TextDelta(text) => {
104                    if let Some(delta_stream) = &delta_stream {
105                        let (op_id, message_id) = &delta_stream.context;
106                        let delta = StreamDelta::TextChunk {
107                            op_id: *op_id,
108                            message_id: message_id.clone(),
109                            delta: text,
110                        };
111                        let _ = delta_stream.tx.send(delta).await;
112                    }
113                }
114                StreamChunk::ThinkingDelta(thinking) => {
115                    if let Some(delta_stream) = &delta_stream {
116                        let (op_id, message_id) = &delta_stream.context;
117                        let delta = StreamDelta::ThinkingChunk {
118                            op_id: *op_id,
119                            message_id: message_id.clone(),
120                            delta: thinking,
121                        };
122                        let _ = delta_stream.tx.send(delta).await;
123                    }
124                }
125                StreamChunk::ToolUseInputDelta { id, delta } => {
126                    if let Some(delta_stream) = &delta_stream {
127                        let (op_id, message_id) = &delta_stream.context;
128                        let delta = StreamDelta::ToolCallChunk {
129                            op_id: *op_id,
130                            message_id: message_id.clone(),
131                            tool_call_id: ToolCallId::from_string(&id),
132                            delta: ToolCallDelta::ArgumentChunk(delta),
133                        };
134                        let _ = delta_stream.tx.send(delta).await;
135                    }
136                }
137                StreamChunk::Reset => {
138                    if let Some(delta_stream) = &delta_stream {
139                        let (op_id, message_id) = &delta_stream.context;
140                        let delta = StreamDelta::Reset {
141                            op_id: *op_id,
142                            message_id: message_id.clone(),
143                        };
144                        let _ = delta_stream.tx.send(delta).await;
145                    }
146                }
147                StreamChunk::MessageComplete(response) => {
148                    final_response = Some(response);
149                }
150                StreamChunk::Error(err) => {
151                    return Err(err.to_string());
152                }
153                StreamChunk::ToolUseStart { .. } | StreamChunk::ContentBlockStop { .. } => {}
154            }
155        }
156
157        final_response.ok_or_else(|| "Stream ended without MessageComplete".to_string())
158    }
159
160    pub async fn execute_tool(
161        &self,
162        tool_call: ToolCall,
163        invoking_model: Option<ModelId>,
164        cancel_token: CancellationToken,
165    ) -> Result<ToolResult, ToolError> {
166        let resolver = self
167            .session_backends
168            .as_ref()
169            .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
170
171        if let Some(session_id) = self.session_id {
172            self.tool_executor
173                .execute_tool_with_session_resolver(
174                    &tool_call,
175                    session_id,
176                    invoking_model,
177                    cancel_token,
178                    resolver,
179                )
180                .await
181        } else {
182            self.tool_executor
183                .execute_tool_with_resolver(&tool_call, cancel_token, resolver)
184                .await
185        }
186    }
187
188    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
189        let resolver = self
190            .session_backends
191            .as_ref()
192            .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
193
194        self.tool_executor
195            .get_tool_schemas_with_resolver(resolver)
196            .await
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider, TokenUsage};
    use crate::app::conversation::AssistantContent;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::{ModelId, ModelParameters};
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;

    /// Provider double: every completion yields the text "ok" together with
    /// a fixed token usage of (5, 7, 12).
    #[derive(Clone)]
    struct StubProvider;

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            let content = vec![AssistantContent::Text {
                text: "ok".to_string(),
            }];
            Ok(CompletionResponse {
                content,
                usage: Some(TokenUsage::new(5, 7, 12)),
            })
        }
    }

    /// Builds an API client over empty registries plus a tool executor with
    /// no backends or validators registered.
    async fn create_test_deps() -> (Arc<ApiClient>, Arc<ToolExecutor>) {
        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let client = Arc::new(ApiClient::new_with_deps(
            crate::test_utils::test_llm_config_provider().unwrap(),
            providers,
            models,
        ));

        let executor = Arc::new(ToolExecutor::with_components(
            Arc::new(BackendRegistry::new()),
            Arc::new(ValidatorRegistry::new()),
        ));

        (client, executor)
    }

    #[tokio::test]
    async fn call_model_preserves_completion_usage() {
        let (client, executor) = create_test_deps().await;
        let provider_id = ProviderId("stub".to_string());
        client.insert_test_provider(provider_id.clone(), Arc::new(StubProvider));

        let interpreter = EffectInterpreter::new(client, executor);
        let response = interpreter
            .call_model(
                ModelId::new(provider_id, "stub-model"),
                vec![],
                None,
                vec![],
                CancellationToken::new(),
            )
            .await
            .expect("model call should succeed");

        // The usage reported by the provider must survive the streaming path.
        assert_eq!(response.usage, Some(TokenUsage::new(5, 7, 12)));
        assert!(matches!(
            response.content.as_slice(),
            [AssistantContent::Text { text }] if text == "ok"
        ));
    }
}