Skip to main content

spider_agent/llm/
mod.rs

1//! LLM provider abstractions for spider_agent.
2
3#[cfg(feature = "openai")]
4mod openai;
5
6#[cfg(feature = "openai")]
7#[allow(unused_imports)]
8pub use openai::{OpenAIProvider, OpenAiApiMode};
9
10use crate::error::AgentResult;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
14/// LLM provider trait for abstracting different LLM APIs.
15#[async_trait]
16pub trait LLMProvider: Send + Sync {
17    /// Send a completion request and return the response text.
18    async fn complete(
19        &self,
20        messages: Vec<Message>,
21        options: &CompletionOptions,
22        client: &reqwest::Client,
23    ) -> AgentResult<CompletionResponse>;
24
25    /// Provider name for logging/debugging.
26    fn provider_name(&self) -> &'static str;
27
28    /// Check if the provider is properly configured.
29    fn is_configured(&self) -> bool;
30}
31
32/// A message in a conversation.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Message {
35    /// Role: "system", "user", or "assistant".
36    pub role: String,
37    /// Message content.
38    pub content: MessageContent,
39}
40
41impl Message {
42    /// Create a system message.
43    pub fn system(content: impl Into<String>) -> Self {
44        Self {
45            role: "system".to_string(),
46            content: MessageContent::Text(content.into()),
47        }
48    }
49
50    /// Create a user message.
51    pub fn user(content: impl Into<String>) -> Self {
52        Self {
53            role: "user".to_string(),
54            content: MessageContent::Text(content.into()),
55        }
56    }
57
58    /// Create an assistant message.
59    pub fn assistant(content: impl Into<String>) -> Self {
60        Self {
61            role: "assistant".to_string(),
62            content: MessageContent::Text(content.into()),
63        }
64    }
65
66    /// Create a user message with an image.
67    pub fn user_with_image(text: impl Into<String>, image_base64: impl Into<String>) -> Self {
68        Self {
69            role: "user".to_string(),
70            content: MessageContent::MultiPart(vec![
71                ContentPart::Text { text: text.into() },
72                ContentPart::ImageUrl {
73                    image_url: ImageUrl {
74                        url: format!("data:image/png;base64,{}", image_base64.into()),
75                    },
76                },
77            ]),
78        }
79    }
80}
81
82/// Message content - either text or multi-part.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(untagged)]
85pub enum MessageContent {
86    /// Plain text content.
87    Text(String),
88    /// Multi-part content (text + images).
89    MultiPart(Vec<ContentPart>),
90}
91
92impl MessageContent {
93    /// Get text content as a string reference.
94    ///
95    /// For multi-part content, returns the concatenated text parts.
96    pub fn as_text(&self) -> &str {
97        match self {
98            Self::Text(s) => s,
99            Self::MultiPart(parts) => {
100                // Return first text part, or empty string
101                for part in parts {
102                    if let ContentPart::Text { text } = part {
103                        return text;
104                    }
105                }
106                ""
107            }
108        }
109    }
110
111    /// Get the full text from all text parts.
112    pub fn full_text(&self) -> String {
113        match self {
114            Self::Text(s) => s.clone(),
115            Self::MultiPart(parts) => parts
116                .iter()
117                .filter_map(|p| {
118                    if let ContentPart::Text { text } = p {
119                        Some(text.as_str())
120                    } else {
121                        None
122                    }
123                })
124                .collect::<Vec<_>>()
125                .join(" "),
126        }
127    }
128
129    /// Check if this is text-only content.
130    pub fn is_text(&self) -> bool {
131        matches!(self, Self::Text(_))
132    }
133
134    /// Check if this contains images.
135    pub fn has_images(&self) -> bool {
136        match self {
137            Self::Text(_) => false,
138            Self::MultiPart(parts) => parts
139                .iter()
140                .any(|p| matches!(p, ContentPart::ImageUrl { .. })),
141        }
142    }
143}
144
145/// A part of multi-part content.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147#[serde(tag = "type", rename_all = "snake_case")]
148pub enum ContentPart {
149    /// Text part.
150    Text { text: String },
151    /// Image URL part.
152    ImageUrl { image_url: ImageUrl },
153}
154
155/// Image URL for vision models.
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct ImageUrl {
158    /// URL (can be data URL with base64).
159    pub url: String,
160}
161
/// Per-request generation settings handed to [`LLMProvider::complete`].
#[derive(Clone, Debug)]
pub struct CompletionOptions {
    /// Sampling temperature, expected to lie between 0.0 and 2.0.
    pub temperature: f32,
    /// Upper limit on the number of tokens the model may generate.
    pub max_tokens: u16,
    /// When `true`, JSON-formatted output is requested.
    pub json_mode: bool,
}
172
173impl Default for CompletionOptions {
174    fn default() -> Self {
175        Self {
176            temperature: 0.1,
177            max_tokens: 4096,
178            json_mode: true,
179        }
180    }
181}
182
183/// Response from a completion request.
184#[derive(Debug, Clone)]
185pub struct CompletionResponse {
186    /// The generated text.
187    pub content: String,
188    /// Token usage.
189    pub usage: TokenUsage,
190}
191
192/// Token usage from a completion.
193#[derive(Debug, Clone, Default, Serialize, Deserialize)]
194pub struct TokenUsage {
195    /// Prompt tokens.
196    pub prompt_tokens: u32,
197    /// Completion tokens.
198    pub completion_tokens: u32,
199    /// Total tokens.
200    pub total_tokens: u32,
201}
202
203impl TokenUsage {
204    /// Accumulate usage from another.
205    pub fn accumulate(&mut self, other: &TokenUsage) {
206        self.prompt_tokens += other.prompt_tokens;
207        self.completion_tokens += other.completion_tokens;
208        self.total_tokens += other.total_tokens;
209    }
210}