Skip to main content

steer_core/api/
provider.rs

1use async_trait::async_trait;
2use futures_core::Stream;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::error::{ApiError, StreamError};
10use crate::app::SystemContext;
11use crate::app::conversation::{AssistantContent, Message};
12use crate::auth::{AuthStorage, DynAuthenticationFlow};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::{ToolCall, ToolSchema};
15
16#[derive(Debug, Clone)]
17pub enum StreamChunk {
18    TextDelta(String),
19    ThinkingDelta(String),
20    ToolUseStart {
21        id: String,
22        name: String,
23    },
24    ToolUseInputDelta {
25        id: String,
26        delta: String,
27    },
28    ContentBlockStop {
29        index: usize,
30    },
31    /// Signal to clear any in-progress streamed content for this message before restarting.
32    Reset,
33    MessageComplete(CompletionResponse),
34    Error(StreamError),
35}
36
37pub type CompletionStream = Pin<Box<dyn Stream<Item = StreamChunk> + Send>>;
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
40pub struct TokenUsage {
41    pub input_tokens: u32,
42    pub output_tokens: u32,
43    pub total_tokens: u32,
44}
45
46impl TokenUsage {
47    pub const fn new(input_tokens: u32, output_tokens: u32, total_tokens: u32) -> Self {
48        Self {
49            input_tokens,
50            output_tokens,
51            total_tokens,
52        }
53    }
54
55    pub const fn from_input_output(input_tokens: u32, output_tokens: u32) -> Self {
56        Self::new(
57            input_tokens,
58            output_tokens,
59            input_tokens.saturating_add(output_tokens),
60        )
61    }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub struct CompletionResponse {
66    pub content: Vec<AssistantContent>,
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub usage: Option<TokenUsage>,
69}
70
71impl CompletionResponse {
72    pub fn new(content: Vec<AssistantContent>) -> Self {
73        Self {
74            content,
75            usage: None,
76        }
77    }
78
79    pub fn with_usage(mut self, usage: TokenUsage) -> Self {
80        self.usage = Some(usage);
81        self
82    }
83
84    /// Extract all text content from the response
85    pub fn extract_text(&self) -> String {
86        self.content
87            .iter()
88            .filter_map(|block| {
89                if let AssistantContent::Text { text } = block {
90                    Some(text.clone())
91                } else {
92                    None
93                }
94            })
95            .collect::<String>()
96    }
97
98    /// Check if the response contains any tool calls
99    pub fn has_tool_calls(&self) -> bool {
100        self.content
101            .iter()
102            .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
103    }
104
105    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
106        self.content
107            .iter()
108            .filter_map(|block| {
109                if let AssistantContent::ToolCall { tool_call, .. } = block {
110                    Some(tool_call.clone())
111                } else {
112                    None
113                }
114            })
115            .collect()
116    }
117}
118
119#[async_trait]
120pub trait Provider: Send + Sync + 'static {
121    fn name(&self) -> &'static str;
122
123    async fn complete(
124        &self,
125        model_id: &ModelId,
126        messages: Vec<Message>,
127        system: Option<SystemContext>,
128        tools: Option<Vec<ToolSchema>>,
129        call_options: Option<ModelParameters>,
130        token: CancellationToken,
131    ) -> Result<CompletionResponse, ApiError>;
132
133    async fn stream_complete(
134        &self,
135        model_id: &ModelId,
136        messages: Vec<Message>,
137        system: Option<SystemContext>,
138        tools: Option<Vec<ToolSchema>>,
139        call_options: Option<ModelParameters>,
140        token: CancellationToken,
141    ) -> Result<CompletionStream, ApiError> {
142        let response = self
143            .complete(model_id, messages, system, tools, call_options, token)
144            .await?;
145        Ok(Box::pin(futures_util::stream::once(async move {
146            StreamChunk::MessageComplete(response)
147        })))
148    }
149
150    fn create_auth_flow(
151        &self,
152        _storage: Arc<dyn AuthStorage>,
153    ) -> Option<Box<dyn DynAuthenticationFlow>> {
154        None
155    }
156}