Skip to main content

steer_core/tools/
model_caller_impl.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio_util::sync::CancellationToken;
5
6use crate::api::Client as ApiClient;
7use crate::app::SystemContext;
8use crate::app::conversation::{Message, MessageData};
9use crate::config::model::ModelId;
10use crate::tools::services::{ModelCallError, ModelCaller};
11
12pub struct DefaultModelCaller {
13    api_client: Arc<ApiClient>,
14}
15
16impl DefaultModelCaller {
17    pub fn new(api_client: Arc<ApiClient>) -> Self {
18        Self { api_client }
19    }
20}
21
22#[async_trait]
23impl ModelCaller for DefaultModelCaller {
24    async fn call(
25        &self,
26        model: &ModelId,
27        messages: Vec<Message>,
28        system_context: Option<SystemContext>,
29        cancel_token: CancellationToken,
30    ) -> Result<Message, ModelCallError> {
31        let response = self
32            .api_client
33            .complete(model, messages, system_context, None, None, cancel_token)
34            .await
35            .map_err(|e| ModelCallError::Api(e.to_string()))?;
36
37        let timestamp = Message::current_timestamp();
38        Ok(Message {
39            timestamp,
40            id: Message::generate_id("assistant", timestamp),
41            parent_message_id: None,
42            data: MessageData::Assistant {
43                content: response.content,
44            },
45        })
46    }
47}