steer_core/api/
provider.rs1use async_trait::async_trait;
2use futures_core::Stream;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::error::{ApiError, StreamError};
10use crate::app::SystemContext;
11use crate::app::conversation::{AssistantContent, Message};
12use crate::auth::{AuthStorage, DynAuthenticationFlow};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::{ToolCall, ToolSchema};
15
16#[derive(Debug, Clone)]
17pub enum StreamChunk {
18 TextDelta(String),
19 ThinkingDelta(String),
20 ToolUseStart {
21 id: String,
22 name: String,
23 },
24 ToolUseInputDelta {
25 id: String,
26 delta: String,
27 },
28 ContentBlockStop {
29 index: usize,
30 },
31 Reset,
33 MessageComplete(CompletionResponse),
34 Error(StreamError),
35}
36
37pub type CompletionStream = Pin<Box<dyn Stream<Item = StreamChunk> + Send>>;
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
40pub struct TokenUsage {
41 pub input_tokens: u32,
42 pub output_tokens: u32,
43 pub total_tokens: u32,
44}
45
46impl TokenUsage {
47 pub const fn new(input_tokens: u32, output_tokens: u32, total_tokens: u32) -> Self {
48 Self {
49 input_tokens,
50 output_tokens,
51 total_tokens,
52 }
53 }
54
55 pub const fn from_input_output(input_tokens: u32, output_tokens: u32) -> Self {
56 Self::new(
57 input_tokens,
58 output_tokens,
59 input_tokens.saturating_add(output_tokens),
60 )
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub struct CompletionResponse {
66 pub content: Vec<AssistantContent>,
67 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub usage: Option<TokenUsage>,
69}
70
71impl CompletionResponse {
72 pub fn new(content: Vec<AssistantContent>) -> Self {
73 Self {
74 content,
75 usage: None,
76 }
77 }
78
79 pub fn with_usage(mut self, usage: TokenUsage) -> Self {
80 self.usage = Some(usage);
81 self
82 }
83
84 pub fn extract_text(&self) -> String {
86 self.content
87 .iter()
88 .filter_map(|block| {
89 if let AssistantContent::Text { text } = block {
90 Some(text.clone())
91 } else {
92 None
93 }
94 })
95 .collect::<String>()
96 }
97
98 pub fn has_tool_calls(&self) -> bool {
100 self.content
101 .iter()
102 .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
103 }
104
105 pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
106 self.content
107 .iter()
108 .filter_map(|block| {
109 if let AssistantContent::ToolCall { tool_call, .. } = block {
110 Some(tool_call.clone())
111 } else {
112 None
113 }
114 })
115 .collect()
116 }
117}
118
119#[async_trait]
120pub trait Provider: Send + Sync + 'static {
121 fn name(&self) -> &'static str;
122
123 async fn complete(
124 &self,
125 model_id: &ModelId,
126 messages: Vec<Message>,
127 system: Option<SystemContext>,
128 tools: Option<Vec<ToolSchema>>,
129 call_options: Option<ModelParameters>,
130 token: CancellationToken,
131 ) -> Result<CompletionResponse, ApiError>;
132
133 async fn stream_complete(
134 &self,
135 model_id: &ModelId,
136 messages: Vec<Message>,
137 system: Option<SystemContext>,
138 tools: Option<Vec<ToolSchema>>,
139 call_options: Option<ModelParameters>,
140 token: CancellationToken,
141 ) -> Result<CompletionStream, ApiError> {
142 let response = self
143 .complete(model_id, messages, system, tools, call_options, token)
144 .await?;
145 Ok(Box::pin(futures_util::stream::once(async move {
146 StreamChunk::MessageComplete(response)
147 })))
148 }
149
150 fn create_auth_flow(
151 &self,
152 _storage: Arc<dyn AuthStorage>,
153 ) -> Option<Box<dyn DynAuthenticationFlow>> {
154 None
155 }
156}