Skip to main content

tuitbot_core/llm/
mod.rs

1//! LLM provider abstraction and implementations.
2//!
3//! Provides a trait-based abstraction for LLM providers (OpenAI, Anthropic, Ollama)
4//! with typed responses, token usage tracking, and health checking.
5
6pub mod anthropic;
7pub mod embedding;
8pub mod embedding_factory;
9pub mod factory;
10pub mod ollama_embedding;
11pub mod openai_compat;
12pub mod openai_embedding;
13pub mod pricing;
14
15use crate::error::LlmError;
16
/// Token usage information from an LLM completion.
///
/// Serializable via serde so usage records can be stored or transmitted.
/// `Default` yields a zeroed record, suitable as an accumulator seed.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    /// Number of tokens in the input/prompt.
    pub input_tokens: u32,
    /// Number of tokens in the output/completion.
    pub output_tokens: u32,
}
25
26impl TokenUsage {
27    /// Accumulate token counts from another usage record (e.g. across retries).
28    pub fn accumulate(&mut self, other: &TokenUsage) {
29        self.input_tokens += other.input_tokens;
30        self.output_tokens += other.output_tokens;
31    }
32}
33
/// Response from an LLM completion request.
///
/// Bundles the generated text with its token accounting and the model
/// identifier reported for the completion.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// The generated text content.
    pub text: String,
    /// Token usage for this completion.
    pub usage: TokenUsage,
    /// The model that produced this response.
    pub model: String,
}
44
/// Parameters controlling LLM generation behavior.
///
/// `Default` provides `max_tokens = 512`, `temperature = 0.7`, and no
/// system-prompt override.
#[derive(Debug, Clone)]
pub struct GenerationParams {
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature (0.0 = deterministic, 1.0+ = creative).
    pub temperature: f32,
    /// Optional system prompt override. If `Some`, replaces the caller's system prompt.
    pub system_prompt: Option<String>,
}
55
56impl Default for GenerationParams {
57    fn default() -> Self {
58        Self {
59            max_tokens: 512,
60            temperature: 0.7,
61            system_prompt: None,
62        }
63    }
64}
65
/// Trait abstracting all LLM provider operations.
///
/// Implementations include `OpenAiCompatProvider` (for OpenAI and Ollama)
/// and `AnthropicProvider`. The trait is object-safe for use as `Box<dyn LlmProvider>`.
/// The `Send + Sync` bounds allow a provider to be shared across async tasks.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Returns the display name of this provider (e.g., "openai", "anthropic", "ollama").
    fn name(&self) -> &str;

    /// Send a completion request to the LLM.
    ///
    /// If `params.system_prompt` is `Some`, it overrides the `system` parameter.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] if the request fails.
    async fn complete(
        &self,
        system: &str,
        user_message: &str,
        params: &GenerationParams,
    ) -> Result<LlmResponse, LlmError>;

    /// Check if the provider is reachable and configured correctly.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] if the provider is unreachable or misconfigured.
    async fn health_check(&self) -> Result<(), LlmError>;
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor used throughout the tests below.
    fn usage(input: u32, output: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
        }
    }

    #[test]
    fn token_usage_default_is_zero() {
        let fresh = TokenUsage::default();
        assert_eq!((fresh.input_tokens, fresh.output_tokens), (0, 0));
    }

    #[test]
    fn token_usage_accumulate() {
        let mut running = usage(100, 50);
        running.accumulate(&usage(200, 80));
        assert_eq!(running.input_tokens, 300);
        assert_eq!(running.output_tokens, 130);
    }

    #[test]
    fn token_usage_accumulate_multiple() {
        let mut running = TokenUsage::default();
        for step in 1..=5 {
            running.accumulate(&usage(step * 10, step * 5));
        }
        // 10+20+30+40+50 = 150; 5+10+15+20+25 = 75.
        assert_eq!(running.input_tokens, 150);
        assert_eq!(running.output_tokens, 75);
    }

    #[test]
    fn token_usage_accumulate_zero() {
        let mut running = usage(42, 17);
        running.accumulate(&TokenUsage::default());
        assert_eq!(running.input_tokens, 42);
        assert_eq!(running.output_tokens, 17);
    }

    #[test]
    fn generation_params_default() {
        let defaults = GenerationParams::default();
        assert_eq!(defaults.max_tokens, 512);
        assert!((defaults.temperature - 0.7).abs() < f32::EPSILON);
        assert!(defaults.system_prompt.is_none());
    }

    #[test]
    fn generation_params_with_system_prompt() {
        let overridden = GenerationParams {
            system_prompt: Some("You are a helpful assistant.".to_string()),
            ..GenerationParams::default()
        };
        assert_eq!(
            overridden.system_prompt.as_deref(),
            Some("You are a helpful assistant.")
        );
        assert_eq!(overridden.max_tokens, 512);
    }

    #[test]
    fn llm_response_fields() {
        let reply = LlmResponse {
            text: "Hello, world!".to_string(),
            usage: usage(10, 3),
            model: "gpt-4o-mini".to_string(),
        };
        assert_eq!(reply.text, "Hello, world!");
        assert_eq!(reply.usage.input_tokens, 10);
        assert_eq!(reply.usage.output_tokens, 3);
        assert_eq!(reply.model, "gpt-4o-mini");
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let original = usage(100, 50);
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: TokenUsage = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.input_tokens, 100);
        assert_eq!(decoded.output_tokens, 50);
    }

    #[test]
    fn token_usage_clone() {
        let original = usage(42, 17);
        let copy = original.clone();
        assert_eq!(copy.input_tokens, 42);
        assert_eq!(copy.output_tokens, 17);
    }

    #[test]
    fn generation_params_clone() {
        let original = GenerationParams {
            max_tokens: 1024,
            temperature: 0.5,
            system_prompt: Some("test".to_string()),
        };
        let copy = original.clone();
        assert_eq!(copy.max_tokens, 1024);
        assert!((copy.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(copy.system_prompt.as_deref(), Some("test"));
    }
}