steer_core/api/
provider.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt::Debug;
4use std::sync::Arc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::error::ApiError;
8use crate::app::conversation::{AssistantContent, Message};
9use crate::auth::{AuthStorage, DynAuthenticationFlow};
10use steer_tools::{ToolCall, ToolSchema};
11
12use super::Model;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CompletionResponse {
17 pub content: Vec<AssistantContent>,
18}
19
20impl CompletionResponse {
21 pub fn extract_text(&self) -> String {
23 self.content
24 .iter()
25 .filter_map(|block| {
26 if let AssistantContent::Text { text } = block {
27 Some(text.clone())
28 } else {
29 None
30 }
31 })
32 .collect::<Vec<String>>()
33 .join("")
34 }
35
36 pub fn has_tool_calls(&self) -> bool {
38 self.content
39 .iter()
40 .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
41 }
42
43 pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
44 self.content
45 .iter()
46 .filter_map(|block| {
47 if let AssistantContent::ToolCall { tool_call } = block {
48 Some(tool_call.clone())
49 } else {
50 None
51 }
52 })
53 .collect()
54 }
55}
56
57#[async_trait]
59pub trait Provider: Send + Sync + 'static {
60 fn name(&self) -> &'static str;
62
63 async fn complete(
65 &self,
66 model: Model,
67 messages: Vec<Message>,
68 system: Option<String>,
69 tools: Option<Vec<ToolSchema>>,
70 token: CancellationToken,
71 ) -> Result<CompletionResponse, ApiError>;
72
73 fn create_auth_flow(
76 &self,
77 _storage: Arc<dyn AuthStorage>,
78 ) -> Option<Box<dyn DynAuthenticationFlow>> {
79 None
80 }
81}