pub mod anthropic;
pub mod embedding;
pub mod embedding_factory;
pub mod factory;
pub mod ollama_embedding;
pub mod openai_compat;
pub mod openai_embedding;
pub mod pricing;

use crate::error::LlmError;

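/// Input and output token counts for a single LLM call.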
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}

impl TokenUsage {
    /// Adds the counts from `other` into this tally.
    pub fn accumulate(&mut self, other: &TokenUsage) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
    }
}

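/// The result of a completion call: the generated text, token usage, and
/// the model that produced it.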
#[derive(Debug, Clone)]
pub struct LlmResponse {
    pub text: String,
    pub usage: TokenUsage,
    pub model: String,
}

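/// Tunable generation settings passed along with each request.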
#[derive(Debug, Clone)]
pub struct GenerationParams {
    pub max_tokens: u32,
    pub temperature: f32,
    pub system_prompt: Option<String>,
}

impl Default for GenerationParams {
    fn default() -> Self {
        Self {
            max_tokens: 512,
            temperature: 0.7,
            system_prompt: None,
        }
    }
}

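/// A chat-completion backend (Anthropic, an OpenAI-compatible server, ...)
/// abstracted behind a single async interface.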
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Human-readable name of this provider.
    fn name(&self) -> &str;

    /// Sends a system prompt plus one user message and returns the
    /// completion along with its token usage.
    async fn complete(
        &self,
        system: &str,
        user_message: &str,
        params: &GenerationParams,
    ) -> Result<LlmResponse, LlmError>;

    /// Verifies that the backing service is reachable.
    async fn health_check(&self) -> Result<(), LlmError>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_default_is_zero() {
        let usage = TokenUsage::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn token_usage_accumulate() {
        let mut total = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        let other = TokenUsage {
            input_tokens: 200,
            output_tokens: 80,
        };
        total.accumulate(&other);
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 130);
    }

    #[test]
    fn token_usage_accumulate_multiple() {
        let mut total = TokenUsage::default();
        for i in 1..=5 {
            total.accumulate(&TokenUsage {
                input_tokens: i * 10,
                output_tokens: i * 5,
            });
        }
        // 10 + 20 + 30 + 40 + 50 = 150; 5 + 10 + 15 + 20 + 25 = 75.
        assert_eq!(total.input_tokens, 150);
        assert_eq!(total.output_tokens, 75);
    }

    #[test]
    fn token_usage_accumulate_zero() {
        let mut total = TokenUsage {
            input_tokens: 42,
            output_tokens: 17,
        };
        total.accumulate(&TokenUsage::default());
        assert_eq!(total.input_tokens, 42);
        assert_eq!(total.output_tokens, 17);
    }

    #[test]
    fn generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert!(params.system_prompt.is_none());
    }

    #[test]
    fn generation_params_with_system_prompt() {
        let params = GenerationParams {
            system_prompt: Some("You are a helpful assistant.".to_string()),
            ..Default::default()
        };
        assert_eq!(
            params.system_prompt.as_deref(),
            Some("You are a helpful assistant.")
        );
        assert_eq!(params.max_tokens, 512);
    }

    #[test]
    fn llm_response_fields() {
        let response = LlmResponse {
            text: "Hello, world!".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 3,
            },
            model: "gpt-4o-mini".to_string(),
        };
        assert_eq!(response.text, "Hello, world!");
        assert_eq!(response.usage.input_tokens, 10);
        assert_eq!(response.usage.output_tokens, 3);
        assert_eq!(response.model, "gpt-4o-mini");
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        let json = serde_json::to_string(&usage).expect("serialize");
        let deserialized: TokenUsage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.input_tokens, 100);
        assert_eq!(deserialized.output_tokens, 50);
    }

    #[test]
    fn token_usage_clone() {
        let usage = TokenUsage {
            input_tokens: 42,
            output_tokens: 17,
        };
        let cloned = usage.clone();
        assert_eq!(cloned.input_tokens, 42);
        assert_eq!(cloned.output_tokens, 17);
    }

    #[test]
    fn generation_params_clone() {
        let params = GenerationParams {
            max_tokens: 1024,
            temperature: 0.5,
            system_prompt: Some("test".to_string()),
        };
        let cloned = params.clone();
        assert_eq!(cloned.max_tokens, 1024);
        assert!((cloned.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(cloned.system_prompt.as_deref(), Some("test"));
    }
}