steer_core/api/
provider.rs1use async_trait::async_trait;
2use futures_core::Stream;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::pin::Pin;
6use std::sync::Arc;
7use tokio_util::sync::CancellationToken;
8
9use crate::api::error::{ApiError, StreamError};
10use crate::app::SystemContext;
11use crate::app::conversation::{AssistantContent, Message};
12use crate::auth::{AuthStorage, DynAuthenticationFlow};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::{ToolCall, ToolSchema};
15
16#[derive(Debug, Clone)]
17pub enum StreamChunk {
18 TextDelta(String),
19 ThinkingDelta(String),
20 ToolUseStart { id: String, name: String },
21 ToolUseInputDelta { id: String, delta: String },
22 ContentBlockStop { index: usize },
23 MessageComplete(CompletionResponse),
24 Error(StreamError),
25}
26
27pub type CompletionStream = Pin<Box<dyn Stream<Item = StreamChunk> + Send>>;
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
30pub struct TokenUsage {
31 pub input_tokens: u32,
32 pub output_tokens: u32,
33 pub total_tokens: u32,
34}
35
36impl TokenUsage {
37 pub const fn new(input_tokens: u32, output_tokens: u32, total_tokens: u32) -> Self {
38 Self {
39 input_tokens,
40 output_tokens,
41 total_tokens,
42 }
43 }
44
45 pub const fn from_input_output(input_tokens: u32, output_tokens: u32) -> Self {
46 Self::new(
47 input_tokens,
48 output_tokens,
49 input_tokens.saturating_add(output_tokens),
50 )
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55pub struct CompletionResponse {
56 pub content: Vec<AssistantContent>,
57 #[serde(default, skip_serializing_if = "Option::is_none")]
58 pub usage: Option<TokenUsage>,
59}
60
61impl CompletionResponse {
62 pub fn new(content: Vec<AssistantContent>) -> Self {
63 Self {
64 content,
65 usage: None,
66 }
67 }
68
69 pub fn with_usage(mut self, usage: TokenUsage) -> Self {
70 self.usage = Some(usage);
71 self
72 }
73
74 pub fn extract_text(&self) -> String {
76 self.content
77 .iter()
78 .filter_map(|block| {
79 if let AssistantContent::Text { text } = block {
80 Some(text.clone())
81 } else {
82 None
83 }
84 })
85 .collect::<String>()
86 }
87
88 pub fn has_tool_calls(&self) -> bool {
90 self.content
91 .iter()
92 .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
93 }
94
95 pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
96 self.content
97 .iter()
98 .filter_map(|block| {
99 if let AssistantContent::ToolCall { tool_call, .. } = block {
100 Some(tool_call.clone())
101 } else {
102 None
103 }
104 })
105 .collect()
106 }
107}
108
109#[async_trait]
110pub trait Provider: Send + Sync + 'static {
111 fn name(&self) -> &'static str;
112
113 async fn complete(
114 &self,
115 model_id: &ModelId,
116 messages: Vec<Message>,
117 system: Option<SystemContext>,
118 tools: Option<Vec<ToolSchema>>,
119 call_options: Option<ModelParameters>,
120 token: CancellationToken,
121 ) -> Result<CompletionResponse, ApiError>;
122
123 async fn stream_complete(
124 &self,
125 model_id: &ModelId,
126 messages: Vec<Message>,
127 system: Option<SystemContext>,
128 tools: Option<Vec<ToolSchema>>,
129 call_options: Option<ModelParameters>,
130 token: CancellationToken,
131 ) -> Result<CompletionStream, ApiError> {
132 let response = self
133 .complete(model_id, messages, system, tools, call_options, token)
134 .await?;
135 Ok(Box::pin(futures_util::stream::once(async move {
136 StreamChunk::MessageComplete(response)
137 })))
138 }
139
140 fn create_auth_flow(
141 &self,
142 _storage: Arc<dyn AuthStorage>,
143 ) -> Option<Box<dyn DynAuthenticationFlow>> {
144 None
145 }
146}