// steer_core/api/provider.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt::Debug;
4use std::sync::Arc;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::error::ApiError;
8use crate::app::conversation::{AssistantContent, Message};
9use crate::auth::{AuthStorage, DynAuthenticationFlow};
10use steer_tools::{ToolCall, ToolSchema};
11
12use super::Model;
13
14/// Response from the provider's completion API
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CompletionResponse {
17    pub content: Vec<AssistantContent>,
18}
19
20impl CompletionResponse {
21    /// Extract all text content from the response
22    pub fn extract_text(&self) -> String {
23        self.content
24            .iter()
25            .filter_map(|block| {
26                if let AssistantContent::Text { text } = block {
27                    Some(text.clone())
28                } else {
29                    None
30                }
31            })
32            .collect::<Vec<String>>()
33            .join("")
34    }
35
36    /// Check if the response contains any tool calls
37    pub fn has_tool_calls(&self) -> bool {
38        self.content
39            .iter()
40            .any(|block| matches!(block, AssistantContent::ToolCall { .. }))
41    }
42
43    pub fn extract_tool_calls(&self) -> Vec<ToolCall> {
44        self.content
45            .iter()
46            .filter_map(|block| {
47                if let AssistantContent::ToolCall { tool_call } = block {
48                    Some(tool_call.clone())
49                } else {
50                    None
51                }
52            })
53            .collect()
54    }
55}
56
57/// Provider trait that all LLM providers must implement
58#[async_trait]
59pub trait Provider: Send + Sync + 'static {
60    /// Get the name of the provider
61    fn name(&self) -> &'static str;
62
63    /// Complete a prompt with the LLM
64    async fn complete(
65        &self,
66        model: Model,
67        messages: Vec<Message>,
68        system: Option<String>,
69        tools: Option<Vec<ToolSchema>>,
70        token: CancellationToken,
71    ) -> Result<CompletionResponse, ApiError>;
72
73    /// Create an authentication flow for this provider
74    /// Returns None if the provider doesn't support authentication
75    fn create_auth_flow(
76        &self,
77        _storage: Arc<dyn AuthStorage>,
78    ) -> Option<Box<dyn DynAuthenticationFlow>> {
79        None
80    }
81}