steer_core/api/provider.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt::Debug;
4use std::sync::Arc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::error::ApiError;
8use crate::app::conversation::{AssistantContent, Message};
9use crate::auth::{AuthStorage, DynAuthenticationFlow};
10use crate::config::model::{ModelId, ModelParameters};
11use steer_tools::{ToolCall, ToolSchema};
12
13/// Response from the provider's completion API
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct CompletionResponse {
16    pub content: Vec<AssistantContent>,
17}
18
19impl CompletionResponse {
20    /// Extract all text content from the response
21    pub fn extract_text(&self) -> String {
22        self.content
23            .iter()
24            .filter_map(|block| {
25                if let AssistantContent::Text { text } = block {
26                    Some(text.clone())
27                } else {
28                    None
29                }
30            })
31            .collect::<Vec<String>>()
32            .join("")
33    }
34
35    /// Check if the response contains any tool calls
36    pub fn has_tool_calls(&self) -> bool {
37        self.content
38            .iter()
39            .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
40    }
41
42    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
43        self.content
44            .iter()
45            .filter_map(|block| {
46                if let AssistantContent::ToolCall { tool_call } = block {
47                    Some(tool_call.clone())
48                } else {
49                    None
50                }
51            })
52            .collect()
53    }
54}
55
56/// Provider trait that all LLM providers must implement
57#[async_trait]
58pub trait Provider: Send + Sync + 'static {
59    /// Get the name of the provider
60    fn name(&self) -> &'static str;
61
62    /// Complete a prompt with the LLM
63    async fn complete(
64        &self,
65        model_id: &ModelId,
66        messages: Vec<Message>,
67        system: Option<String>,
68        tools: Option<Vec<ToolSchema>>,
69        call_options: Option<ModelParameters>,
70        token: CancellationToken,
71    ) -> Result<CompletionResponse, ApiError>;
72
73    /// Create an authentication flow for this provider
74    /// Returns None if the provider doesn't support authentication
75    fn create_auth_flow(
76        &self,
77        _storage: Arc<dyn AuthStorage>,
78    ) -> Option<Box<dyn DynAuthenticationFlow>> {
79        None
80    }
81}