Skip to main content

steer_core/api/
provider.rs

1use async_trait::async_trait;
2use futures_core::Stream;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::error::{ApiError, StreamError};
10use crate::app::SystemContext;
11use crate::app::conversation::{AssistantContent, Message};
12use crate::auth::{AuthStorage, DynAuthenticationFlow};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::{ToolCall, ToolSchema};
15
16#[derive(Debug, Clone)]
17pub enum StreamChunk {
18    TextDelta(String),
19    ThinkingDelta(String),
20    ToolUseStart { id: String, name: String },
21    ToolUseInputDelta { id: String, delta: String },
22    ContentBlockStop { index: usize },
23    MessageComplete(CompletionResponse),
24    Error(StreamError),
25}
26
27pub type CompletionStream = Pin<Box<dyn Stream<Item = StreamChunk> + Send>>;
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
30pub struct TokenUsage {
31    pub input_tokens: u32,
32    pub output_tokens: u32,
33    pub total_tokens: u32,
34}
35
36impl TokenUsage {
37    pub const fn new(input_tokens: u32, output_tokens: u32, total_tokens: u32) -> Self {
38        Self {
39            input_tokens,
40            output_tokens,
41            total_tokens,
42        }
43    }
44
45    pub const fn from_input_output(input_tokens: u32, output_tokens: u32) -> Self {
46        Self::new(
47            input_tokens,
48            output_tokens,
49            input_tokens.saturating_add(output_tokens),
50        )
51    }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55pub struct CompletionResponse {
56    pub content: Vec<AssistantContent>,
57    #[serde(default, skip_serializing_if = "Option::is_none")]
58    pub usage: Option<TokenUsage>,
59}
60
61impl CompletionResponse {
62    pub fn new(content: Vec<AssistantContent>) -> Self {
63        Self {
64            content,
65            usage: None,
66        }
67    }
68
69    pub fn with_usage(mut self, usage: TokenUsage) -> Self {
70        self.usage = Some(usage);
71        self
72    }
73
74    /// Extract all text content from the response
75    pub fn extract_text(&self) -> String {
76        self.content
77            .iter()
78            .filter_map(|block| {
79                if let AssistantContent::Text { text } = block {
80                    Some(text.clone())
81                } else {
82                    None
83                }
84            })
85            .collect::<String>()
86    }
87
88    /// Check if the response contains any tool calls
89    pub fn has_tool_calls(&self) -> bool {
90        self.content
91            .iter()
92            .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
93    }
94
95    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
96        self.content
97            .iter()
98            .filter_map(|block| {
99                if let AssistantContent::ToolCall { tool_call, .. } = block {
100                    Some(tool_call.clone())
101                } else {
102                    None
103                }
104            })
105            .collect()
106    }
107}
108
109#[async_trait]
110pub trait Provider: Send + Sync + 'static {
111    fn name(&self) -> &'static str;
112
113    async fn complete(
114        &self,
115        model_id: &ModelId,
116        messages: Vec<Message>,
117        system: Option<SystemContext>,
118        tools: Option<Vec<ToolSchema>>,
119        call_options: Option<ModelParameters>,
120        token: CancellationToken,
121    ) -> Result<CompletionResponse, ApiError>;
122
123    async fn stream_complete(
124        &self,
125        model_id: &ModelId,
126        messages: Vec<Message>,
127        system: Option<SystemContext>,
128        tools: Option<Vec<ToolSchema>>,
129        call_options: Option<ModelParameters>,
130        token: CancellationToken,
131    ) -> Result<CompletionStream, ApiError> {
132        let response = self
133            .complete(model_id, messages, system, tools, call_options, token)
134            .await?;
135        Ok(Box::pin(futures_util::stream::once(async move {
136            StreamChunk::MessageComplete(response)
137        })))
138    }
139
140    fn create_auth_flow(
141        &self,
142        _storage: Arc<dyn AuthStorage>,
143    ) -> Option<Box<dyn DynAuthenticationFlow>> {
144        None
145    }
146}