Skip to main content

steer_core/app/domain/runtime/
interpreter.rs

1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use tokio::sync::mpsc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::api::provider::{CompletionResponse, StreamChunk};
9use crate::app::SystemContext;
10use crate::app::conversation::Message;
11use crate::app::domain::delta::{StreamDelta, ToolCallDelta};
12use crate::app::domain::types::{MessageId, OpId, SessionId, ToolCallId};
13use crate::config::model::ModelId;
14use crate::tools::{SessionMcpBackends, ToolExecutor};
15use steer_tools::{ToolCall, ToolError, ToolResult, ToolSchema};
16
17#[derive(Clone)]
18pub struct EffectInterpreter {
19    api_client: Arc<ApiClient>,
20    tool_executor: Arc<ToolExecutor>,
21    session_id: Option<SessionId>,
22    session_backends: Option<Arc<SessionMcpBackends>>,
23}
24
25pub(crate) struct DeltaStreamContext {
26    tx: mpsc::Sender<StreamDelta>,
27    context: (OpId, MessageId),
28}
29
30impl DeltaStreamContext {
31    pub(crate) fn new(tx: mpsc::Sender<StreamDelta>, context: (OpId, MessageId)) -> Self {
32        Self { tx, context }
33    }
34}
35
36impl EffectInterpreter {
37    pub fn new(api_client: Arc<ApiClient>, tool_executor: Arc<ToolExecutor>) -> Self {
38        Self {
39            api_client,
40            tool_executor,
41            session_id: None,
42            session_backends: None,
43        }
44    }
45
46    pub fn with_session(mut self, session_id: SessionId) -> Self {
47        self.session_id = Some(session_id);
48        self
49    }
50
51    pub fn with_session_backends(mut self, backends: Arc<SessionMcpBackends>) -> Self {
52        self.session_backends = Some(backends);
53        self
54    }
55
56    pub fn model_context_window_tokens(&self, model: &ModelId) -> Option<u32> {
57        self.api_client.model_context_window_tokens(model)
58    }
59
60    pub async fn call_model(
61        &self,
62        model: ModelId,
63        messages: Vec<Message>,
64        system_context: Option<SystemContext>,
65        tools: Vec<ToolSchema>,
66        cancel_token: CancellationToken,
67    ) -> Result<CompletionResponse, String> {
68        self.call_model_with_deltas(model, messages, system_context, tools, cancel_token, None)
69            .await
70    }
71
72    pub(crate) async fn call_model_with_deltas(
73        &self,
74        model: ModelId,
75        messages: Vec<Message>,
76        system_context: Option<SystemContext>,
77        tools: Vec<ToolSchema>,
78        cancel_token: CancellationToken,
79        delta_stream: Option<DeltaStreamContext>,
80    ) -> Result<CompletionResponse, String> {
81        let tools_option = if tools.is_empty() { None } else { Some(tools) };
82
83        let mut stream = self
84            .api_client
85            .stream_complete(
86                &model,
87                messages,
88                system_context,
89                tools_option,
90                None,
91                cancel_token,
92            )
93            .await
94            .map_err(|e| e.to_string())?;
95
96        let mut final_response = None;
97        while let Some(chunk) = stream.next().await {
98            match chunk {
99                StreamChunk::TextDelta(text) => {
100                    if let Some(delta_stream) = &delta_stream {
101                        let (op_id, message_id) = &delta_stream.context;
102                        let delta = StreamDelta::TextChunk {
103                            op_id: *op_id,
104                            message_id: message_id.clone(),
105                            delta: text,
106                        };
107                        let _ = delta_stream.tx.send(delta).await;
108                    }
109                }
110                StreamChunk::ThinkingDelta(thinking) => {
111                    if let Some(delta_stream) = &delta_stream {
112                        let (op_id, message_id) = &delta_stream.context;
113                        let delta = StreamDelta::ThinkingChunk {
114                            op_id: *op_id,
115                            message_id: message_id.clone(),
116                            delta: thinking,
117                        };
118                        let _ = delta_stream.tx.send(delta).await;
119                    }
120                }
121                StreamChunk::ToolUseInputDelta { id, delta } => {
122                    if let Some(delta_stream) = &delta_stream {
123                        let (op_id, message_id) = &delta_stream.context;
124                        let delta = StreamDelta::ToolCallChunk {
125                            op_id: *op_id,
126                            message_id: message_id.clone(),
127                            tool_call_id: ToolCallId::from_string(&id),
128                            delta: ToolCallDelta::ArgumentChunk(delta),
129                        };
130                        let _ = delta_stream.tx.send(delta).await;
131                    }
132                }
133                StreamChunk::MessageComplete(response) => {
134                    final_response = Some(response);
135                }
136                StreamChunk::Error(err) => {
137                    return Err(err.to_string());
138                }
139                StreamChunk::ToolUseStart { .. } | StreamChunk::ContentBlockStop { .. } => {}
140            }
141        }
142
143        final_response.ok_or_else(|| "Stream ended without MessageComplete".to_string())
144    }
145
146    pub async fn execute_tool(
147        &self,
148        tool_call: ToolCall,
149        cancel_token: CancellationToken,
150    ) -> Result<ToolResult, ToolError> {
151        let resolver = self
152            .session_backends
153            .as_ref()
154            .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
155
156        if let Some(session_id) = self.session_id {
157            self.tool_executor
158                .execute_tool_with_session_resolver(&tool_call, session_id, cancel_token, resolver)
159                .await
160        } else {
161            self.tool_executor
162                .execute_tool_with_resolver(&tool_call, cancel_token, resolver)
163                .await
164        }
165    }
166
167    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
168        let resolver = self
169            .session_backends
170            .as_ref()
171            .map(|b| b.as_ref() as &dyn crate::tools::BackendResolver);
172
173        self.tool_executor
174            .get_tool_schemas_with_resolver(resolver)
175            .await
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::error::ApiError;
    use crate::api::provider::{CompletionResponse, Provider, TokenUsage};
    use crate::app::conversation::AssistantContent;
    use crate::app::validation::ValidatorRegistry;
    use crate::auth::ProviderRegistry;
    use crate::config::model::{ModelId, ModelParameters};
    use crate::config::provider::ProviderId;
    use crate::model_registry::ModelRegistry;
    use crate::tools::BackendRegistry;
    use async_trait::async_trait;

    /// Provider that ignores its inputs and always answers with the text
    /// `"ok"` and a fixed token usage.
    #[derive(Clone)]
    struct StubProvider;

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            let reply = AssistantContent::Text {
                text: "ok".to_string(),
            };
            let usage = TokenUsage::new(5, 7, 12);

            Ok(CompletionResponse {
                content: vec![reply],
                usage: Some(usage),
            })
        }
    }

    /// Builds an API client and a tool executor from empty registries.
    async fn create_test_deps() -> (Arc<ApiClient>, Arc<ToolExecutor>) {
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let llm_config = crate::test_utils::test_llm_config_provider().unwrap();
        let client = ApiClient::new_with_deps(llm_config, provider_registry, model_registry);

        let backends = Arc::new(BackendRegistry::new());
        let validators = Arc::new(ValidatorRegistry::new());
        let executor = ToolExecutor::with_components(backends, validators);

        (Arc::new(client), Arc::new(executor))
    }

    /// The response returned by `call_model` must carry the provider's usage
    /// and content through unchanged.
    #[tokio::test]
    async fn call_model_preserves_completion_usage() {
        let (api_client, tool_executor) = create_test_deps().await;
        let stub_id = ProviderId("stub".to_string());
        api_client.insert_test_provider(stub_id.clone(), Arc::new(StubProvider));

        let interpreter = EffectInterpreter::new(api_client, tool_executor);
        let model = ModelId::new(stub_id, "stub-model");
        let outcome = interpreter
            .call_model(
                model,
                Vec::new(),
                None,
                Vec::new(),
                CancellationToken::new(),
            )
            .await;
        let response = outcome.expect("model call should succeed");

        let expected_usage = Some(TokenUsage::new(5, 7, 12));
        assert_eq!(response.usage, expected_usage);

        let content_is_ok = matches!(
            response.content.as_slice(),
            [AssistantContent::Text { text }] if text == "ok"
        );
        assert!(content_is_ok);
    }
}