steer_core/tools/
model_caller_impl.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio_util::sync::CancellationToken;
5
6use crate::api::Client as ApiClient;
7use crate::app::SystemContext;
8use crate::app::conversation::{Message, MessageData};
9use crate::config::model::ModelId;
10use crate::tools::services::{ModelCallError, ModelCaller};
11
12pub struct DefaultModelCaller {
13 api_client: Arc<ApiClient>,
14}
15
16impl DefaultModelCaller {
17 pub fn new(api_client: Arc<ApiClient>) -> Self {
18 Self { api_client }
19 }
20}
21
22#[async_trait]
23impl ModelCaller for DefaultModelCaller {
24 async fn call(
25 &self,
26 model: &ModelId,
27 messages: Vec<Message>,
28 system_context: Option<SystemContext>,
29 cancel_token: CancellationToken,
30 ) -> Result<Message, ModelCallError> {
31 let response = self
32 .api_client
33 .complete(model, messages, system_context, None, None, cancel_token)
34 .await
35 .map_err(|e| ModelCallError::Api(e.to_string()))?;
36
37 let timestamp = Message::current_timestamp();
38 Ok(Message {
39 timestamp,
40 id: Message::generate_id("assistant", timestamp),
41 parent_message_id: None,
42 data: MessageData::Assistant {
43 content: response.content,
44 },
45 })
46 }
47}