Skip to main content

steer_core/api/
provider.rs

1use async_trait::async_trait;
2use futures_core::Stream;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::error::{ApiError, StreamError};
10use crate::app::SystemContext;
11use crate::app::conversation::{AssistantContent, Message};
12use crate::auth::{AuthStorage, DynAuthenticationFlow};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::{ToolCall, ToolSchema};
15
16#[derive(Debug, Clone)]
17pub enum StreamChunk {
18    TextDelta(String),
19    ThinkingDelta(String),
20    ToolUseStart { id: String, name: String },
21    ToolUseInputDelta { id: String, delta: String },
22    ContentBlockStop { index: usize },
23    MessageComplete(CompletionResponse),
24    Error(StreamError),
25}
26
27pub type CompletionStream = Pin<Box<dyn Stream<Item = StreamChunk> + Send>>;
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30pub struct CompletionResponse {
31    pub content: Vec<AssistantContent>,
32}
33
34impl CompletionResponse {
35    /// Extract all text content from the response
36    pub fn extract_text(&self) -> String {
37        self.content
38            .iter()
39            .filter_map(|block| {
40                if let AssistantContent::Text { text } = block {
41                    Some(text.clone())
42                } else {
43                    None
44                }
45            })
46            .collect::<String>()
47    }
48
49    /// Check if the response contains any tool calls
50    pub fn has_tool_calls(&self) -> bool {
51        self.content
52            .iter()
53            .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
54    }
55
56    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
57        self.content
58            .iter()
59            .filter_map(|block| {
60                if let AssistantContent::ToolCall { tool_call, .. } = block {
61                    Some(tool_call.clone())
62                } else {
63                    None
64                }
65            })
66            .collect()
67    }
68}
69
70#[async_trait]
71pub trait Provider: Send + Sync + 'static {
72    fn name(&self) -> &'static str;
73
74    async fn complete(
75        &self,
76        model_id: &ModelId,
77        messages: Vec<Message>,
78        system: Option<SystemContext>,
79        tools: Option<Vec<ToolSchema>>,
80        call_options: Option<ModelParameters>,
81        token: CancellationToken,
82    ) -> Result<CompletionResponse, ApiError>;
83
84    async fn stream_complete(
85        &self,
86        model_id: &ModelId,
87        messages: Vec<Message>,
88        system: Option<SystemContext>,
89        tools: Option<Vec<ToolSchema>>,
90        call_options: Option<ModelParameters>,
91        token: CancellationToken,
92    ) -> Result<CompletionStream, ApiError> {
93        let response = self
94            .complete(model_id, messages, system, tools, call_options, token)
95            .await?;
96        Ok(Box::pin(futures_util::stream::once(async move {
97            StreamChunk::MessageComplete(response)
98        })))
99    }
100
101    fn create_auth_flow(
102        &self,
103        _storage: Arc<dyn AuthStorage>,
104    ) -> Option<Box<dyn DynAuthenticationFlow>> {
105        None
106    }
107}