// tuitbot_core/llm/mod.rs
1//! LLM provider abstraction and implementations.
2//!
3//! Provides a trait-based abstraction for LLM providers (OpenAI, Anthropic, Ollama)
4//! with typed responses, token usage tracking, and health checking.
5
6pub mod anthropic;
7pub mod factory;
8pub mod openai_compat;
9pub mod pricing;
10
11use crate::error::LlmError;
12
13/// Token usage information from an LLM completion.
14#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
15pub struct TokenUsage {
16    /// Number of tokens in the input/prompt.
17    pub input_tokens: u32,
18    /// Number of tokens in the output/completion.
19    pub output_tokens: u32,
20}
21
22impl TokenUsage {
23    /// Accumulate token counts from another usage record (e.g. across retries).
24    pub fn accumulate(&mut self, other: &TokenUsage) {
25        self.input_tokens += other.input_tokens;
26        self.output_tokens += other.output_tokens;
27    }
28}
29
30/// Response from an LLM completion request.
31#[derive(Debug, Clone)]
32pub struct LlmResponse {
33    /// The generated text content.
34    pub text: String,
35    /// Token usage for this completion.
36    pub usage: TokenUsage,
37    /// The model that produced this response.
38    pub model: String,
39}
40
/// Knobs controlling how the LLM generates a completion.
#[derive(Debug, Clone)]
pub struct GenerationParams {
    /// Upper bound on the number of tokens the model may generate.
    pub max_tokens: u32,
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative).
    pub temperature: f32,
    /// Optional system prompt override. If `Some`, replaces the caller's system prompt.
    pub system_prompt: Option<String>,
}

impl Default for GenerationParams {
    /// Conservative defaults: a modest 512-token budget, balanced sampling,
    /// and no system-prompt override.
    fn default() -> Self {
        GenerationParams {
            max_tokens: 512,
            temperature: 0.7,
            system_prompt: None,
        }
    }
}
61
/// Trait abstracting all LLM provider operations.
///
/// Implementations include `OpenAiCompatProvider` (for OpenAI and Ollama)
/// and `AnthropicProvider`. The trait is object-safe for use as `Box<dyn LlmProvider>`.
/// `Send + Sync` bounds allow a provider to be shared across async tasks.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Returns the display name of this provider (e.g., "openai", "anthropic", "ollama").
    fn name(&self) -> &str;

    /// Send a completion request to the LLM.
    ///
    /// If `params.system_prompt` is `Some`, it overrides the `system` parameter.
    ///
    /// # Errors
    ///
    /// Returns an `LlmError` when the request cannot be completed
    /// (implementation-specific: e.g. transport, auth, or provider-side failure).
    async fn complete(
        &self,
        system: &str,
        user_message: &str,
        params: &GenerationParams,
    ) -> Result<LlmResponse, LlmError>;

    /// Check if the provider is reachable and configured correctly.
    ///
    /// # Errors
    ///
    /// Returns an `LlmError` if the provider is unreachable or misconfigured.
    async fn health_check(&self) -> Result<(), LlmError>;
}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_default_is_zero() {
        // Destructure the default record; both counters start at zero.
        let TokenUsage { input_tokens, output_tokens } = TokenUsage::default();
        assert_eq!((input_tokens, output_tokens), (0, 0));
    }

    #[test]
    fn token_usage_accumulate() {
        let mut total = TokenUsage { input_tokens: 100, output_tokens: 50 };
        total.accumulate(&TokenUsage { input_tokens: 200, output_tokens: 80 });
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 130);
    }

    #[test]
    fn token_usage_accumulate_multiple() {
        // Fold five records into a running total.
        let total = (1u32..=5).fold(TokenUsage::default(), |mut acc, i| {
            acc.accumulate(&TokenUsage {
                input_tokens: i * 10,
                output_tokens: i * 5,
            });
            acc
        });
        assert_eq!(total.input_tokens, 150); // 10+20+30+40+50
        assert_eq!(total.output_tokens, 75); // 5+10+15+20+25
    }

    #[test]
    fn token_usage_accumulate_zero() {
        // Adding an all-zero record must leave the total untouched.
        let mut total = TokenUsage { input_tokens: 42, output_tokens: 17 };
        total.accumulate(&TokenUsage::default());
        assert_eq!((total.input_tokens, total.output_tokens), (42, 17));
    }

    #[test]
    fn generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert!(params.system_prompt.is_none());
    }

    #[test]
    fn generation_params_with_system_prompt() {
        let params = GenerationParams {
            system_prompt: Some(String::from("You are a helpful assistant.")),
            ..GenerationParams::default()
        };
        assert_eq!(
            params.system_prompt.as_deref(),
            Some("You are a helpful assistant.")
        );
        assert_eq!(params.max_tokens, 512);
    }

    #[test]
    fn llm_response_fields() {
        let usage = TokenUsage { input_tokens: 10, output_tokens: 3 };
        let response = LlmResponse {
            text: String::from("Hello, world!"),
            usage,
            model: String::from("gpt-4o-mini"),
        };
        assert_eq!(response.text, "Hello, world!");
        assert_eq!(response.usage.input_tokens, 10);
        assert_eq!(response.usage.output_tokens, 3);
        assert_eq!(response.model, "gpt-4o-mini");
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let original = TokenUsage { input_tokens: 100, output_tokens: 50 };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: TokenUsage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.input_tokens, 100);
        assert_eq!(restored.output_tokens, 50);
    }

    #[test]
    fn token_usage_clone() {
        let original = TokenUsage { input_tokens: 42, output_tokens: 17 };
        let copy = original.clone();
        assert_eq!((copy.input_tokens, copy.output_tokens), (42, 17));
    }

    #[test]
    fn generation_params_clone() {
        let original = GenerationParams {
            max_tokens: 1024,
            temperature: 0.5,
            system_prompt: Some(String::from("test")),
        };
        let copy = original.clone();
        assert_eq!(copy.max_tokens, 1024);
        assert!((copy.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(copy.system_prompt.as_deref(), Some("test"));
    }
}