skill_runtime/generation/llm_provider.rs

//! LLM Provider abstraction for AI-powered generation
//!
//! Provides a unified interface for multiple LLM providers (Ollama, OpenAI, Anthropic)
//! with streaming support.

use anyhow::{Context, Result};
use async_trait::async_trait;
use std::pin::Pin;
use futures_util::Stream;

use crate::search_config::{AiIngestionConfig, AiProvider};

/// Response from an LLM completion
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Generated text content
    pub content: String,
    /// Model that generated the response
    pub model: String,
    /// Token usage statistics (if available)
    pub usage: Option<TokenUsage>,
    /// Completion finish reason
    pub finish_reason: Option<String>,
}

/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Prompt tokens used
    pub prompt_tokens: u32,
    /// Completion tokens generated
    pub completion_tokens: u32,
    /// Total tokens
    pub total_tokens: u32,
}

/// A chunk from a streaming completion
#[derive(Debug, Clone)]
pub struct LlmChunk {
    /// Delta text content
    pub delta: String,
    /// Whether this is the final chunk
    pub is_final: bool,
}

/// Chat message for multi-turn conversations
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// Role (system, user, assistant)
    pub role: String,
    /// Message content
    pub content: String,
}

impl ChatMessage {
    /// Create a system message
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
        }
    }

    /// Create a user message
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
        }
    }

    /// Create an assistant message
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: "assistant".to_string(),
            content: content.into(),
        }
    }
}

/// LLM completion request
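///
/// A short construction sketch (illustrative values; uses only the builder
/// methods defined below):
///
/// ```ignore
/// let req = CompletionRequest::with_system("You are a CLI expert", "How do I list files?")
///     .temperature(0.2)
///     .max_tokens(512)
///     .stop(vec!["###".to_string()]);
/// ```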
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Messages for chat completion
    pub messages: Vec<ChatMessage>,
    /// Temperature (0.0-2.0)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Stop sequences
    pub stop: Option<Vec<String>>,
}

impl CompletionRequest {
    /// Create a new completion request with a single user prompt
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            messages: vec![ChatMessage::user(prompt)],
            temperature: None,
            max_tokens: None,
            stop: None,
        }
    }

    /// Create a request with system prompt and user message
    pub fn with_system(system: impl Into<String>, user: impl Into<String>) -> Self {
        Self {
            messages: vec![
                ChatMessage::system(system),
                ChatMessage::user(user),
            ],
            temperature: None,
            max_tokens: None,
            stop: None,
        }
    }

    /// Set temperature
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp.clamp(0.0, 2.0));
        self
    }

    /// Set max tokens
    pub fn max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Add stop sequences
    pub fn stop(mut self, sequences: Vec<String>) -> Self {
        self.stop = Some(sequences);
        self
    }
}

/// Trait for LLM providers
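///
/// A minimal consumption sketch (hedged: `provider` is assumed to be any
/// `Arc<dyn LlmProvider>`; marked `ignore` because it needs a live backend):
///
/// ```ignore
/// use futures_util::StreamExt;
///
/// let request = CompletionRequest::new("Explain async streams in one line");
/// let mut stream = provider.complete_stream(&request).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     print!("{}", chunk.delta);
///     if chunk.is_final {
///         break;
///     }
/// }
/// ```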
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Get provider name
    fn name(&self) -> &str;

    /// Get model name
    fn model(&self) -> &str;

    /// Generate a completion (non-streaming)
    async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse>;

    /// Generate a streaming completion
    async fn complete_stream(
        &self,
        request: &CompletionRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>>;
}

// =============================================================================
// Ollama Provider
// =============================================================================

#[cfg(feature = "ollama")]
pub mod ollama {
    use super::*;
    use ollama_rs::generation::chat::request::ChatMessageRequest;
    use ollama_rs::generation::chat::ChatMessage as OllamaMessage;
    use ollama_rs::Ollama;

    /// Ollama LLM provider for local model inference
    pub struct OllamaProvider {
        client: Ollama,
        model: String,
    }

    impl OllamaProvider {
        /// Create a new Ollama provider
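        ///
        /// A hedged construction sketch (host URL and model name are
        /// illustrative, not defaults shipped with this crate):
        ///
        /// ```ignore
        /// let provider = OllamaProvider::new("http://localhost:11434", "llama3")?;
        /// ```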
        pub fn new(host: &str, model: &str) -> Result<Self> {
            // Parse host URL
            let url = url::Url::parse(host)
                .with_context(|| format!("Invalid Ollama host URL: {}", host))?;

            let host_str = url.host_str().unwrap_or("localhost");
            let port = url.port().unwrap_or(11434);

            // Preserve the caller's scheme instead of hardcoding "http"
            let client = Ollama::new(format!("{}://{}", url.scheme(), host_str), port);

            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        /// Create from config
        pub fn from_config(config: &AiIngestionConfig) -> Result<Self> {
            let model = config.get_model().to_string();
            Self::new(&config.ollama.host, &model)
        }
    }

    #[async_trait]
    impl LlmProvider for OllamaProvider {
        fn name(&self) -> &str {
            "ollama"
        }

        fn model(&self) -> &str {
            &self.model
        }

        async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse> {
            // Convert messages to Ollama format
            let messages: Vec<OllamaMessage> = request
                .messages
                .iter()
                .map(|m| {
                    let role = match m.role.as_str() {
                        "system" => ollama_rs::generation::chat::MessageRole::System,
                        "user" => ollama_rs::generation::chat::MessageRole::User,
                        "assistant" => ollama_rs::generation::chat::MessageRole::Assistant,
                        _ => ollama_rs::generation::chat::MessageRole::User,
                    };
                    OllamaMessage::new(role, m.content.clone())
                })
                .collect();

            let mut chat_request = ChatMessageRequest::new(self.model.clone(), messages);

            // Apply options
            if let Some(temp) = request.temperature {
                let options = ollama_rs::generation::options::GenerationOptions::default()
                    .temperature(temp as f64);
                chat_request = chat_request.options(options);
            }

            let response = self.client.send_chat_messages(chat_request).await
                .context("Ollama chat request failed")?;

            let content = response.message.map(|m| m.content).unwrap_or_default();

            Ok(LlmResponse {
                content,
                model: self.model.clone(),
                usage: None, // Ollama doesn't provide token counts in basic response
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn complete_stream(
            &self,
            request: &CompletionRequest,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>> {
            use futures_util::StreamExt;
            use tokio_stream::wrappers::ReceiverStream;

            let messages: Vec<OllamaMessage> = request
                .messages
                .iter()
                .map(|m| {
                    let role = match m.role.as_str() {
                        "system" => ollama_rs::generation::chat::MessageRole::System,
                        "user" => ollama_rs::generation::chat::MessageRole::User,
                        "assistant" => ollama_rs::generation::chat::MessageRole::Assistant,
                        _ => ollama_rs::generation::chat::MessageRole::User,
                    };
                    OllamaMessage::new(role, m.content.clone())
                })
                .collect();

            let mut chat_request = ChatMessageRequest::new(self.model.clone(), messages);

            if let Some(temp) = request.temperature {
                let options = ollama_rs::generation::options::GenerationOptions::default()
                    .temperature(temp as f64);
                chat_request = chat_request.options(options);
            }

            let (tx, rx) = tokio::sync::mpsc::channel::<Result<LlmChunk>>(100);

            // Clone for the async task
            let client = self.client.clone();

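            // Bridge the backend stream into a bounded channel so the returned
            // stream is 'static + Send; if the receiver is dropped, sends fail
            // and the task exits early.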
            tokio::spawn(async move {
                let mut stream = match client.send_chat_messages_stream(chat_request).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                        return;
                    }
                };

                while let Some(chunk_result) = stream.next().await {
                    match chunk_result {
                        Ok(chunk) => {
                            let content = chunk.message.map(|m| m.content).unwrap_or_default();
                            let is_final = chunk.done;

                            if tx.send(Ok(LlmChunk {
                                delta: content,
                                is_final,
                            })).await.is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(anyhow::anyhow!("Chunk error: {}", e))).await;
                            break;
                        }
                    }
                }
            });

            Ok(Box::pin(ReceiverStream::new(rx)))
        }
    }
}

// =============================================================================
// OpenAI Provider
// =============================================================================

#[cfg(feature = "openai")]
pub mod openai {
    use super::*;
    use async_openai::{
        types::{
            ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
            ChatCompletionRequestUserMessage, ChatCompletionRequestAssistantMessage,
            CreateChatCompletionRequestArgs,
        },
        Client,
    };

    /// OpenAI LLM provider
    pub struct OpenAIProvider {
        client: Client<async_openai::config::OpenAIConfig>,
        model: String,
    }

    impl OpenAIProvider {
        /// Create a new OpenAI provider
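        ///
        /// Reads `OPENAI_API_KEY` from the environment; a hedged sketch
        /// (the model name is illustrative):
        ///
        /// ```ignore
        /// let provider = OpenAIProvider::new("gpt-4o-mini")?;
        /// ```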
        pub fn new(model: &str) -> Result<Self> {
            // Uses OPENAI_API_KEY from environment by default
            let client = Client::new();
            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        /// Create with custom API key
        pub fn with_api_key(api_key: &str, model: &str) -> Result<Self> {
            let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
            let client = Client::with_config(config);
            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        /// Create from config
        pub fn from_config(config: &AiIngestionConfig) -> Result<Self> {
            let model = config.get_model().to_string();

            // Check for API key in environment
            if let Some(ref env_var) = config.openai.api_key_env {
                if let Ok(key) = std::env::var(env_var) {
                    return Self::with_api_key(&key, &model);
                }
            }

            // Fallback to default OPENAI_API_KEY
            Self::new(&model)
        }
    }

    #[async_trait]
    impl LlmProvider for OpenAIProvider {
        fn name(&self) -> &str {
            "openai"
        }

        fn model(&self) -> &str {
            &self.model
        }

        async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse> {
            let messages: Vec<ChatCompletionRequestMessage> = request
                .messages
                .iter()
                .map(|m| match m.role.as_str() {
                    "system" => ChatCompletionRequestMessage::System(
                        ChatCompletionRequestSystemMessage {
                            content: async_openai::types::ChatCompletionRequestSystemMessageContent::Text(m.content.clone()),
                            name: None,
                        }
                    ),
                    "assistant" => ChatCompletionRequestMessage::Assistant(
                        ChatCompletionRequestAssistantMessage {
                            content: Some(async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(m.content.clone())),
                            name: None,
                            tool_calls: None,
                            refusal: None,
                            audio: None,
                        }
                    ),
                    _ => ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(m.content.clone()),
                            name: None,
                        }
                    ),
                })
                .collect();

            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.model(&self.model).messages(messages);

            if let Some(temp) = request.temperature {
                builder.temperature(temp);
            }
            if let Some(max) = request.max_tokens {
                builder.max_completion_tokens(max);
            }
            if let Some(ref stop) = request.stop {
                builder.stop(stop.clone());
            }

            let req = builder.build()?;
            let response = self.client.chat().create(req).await?;

            let choice = response.choices.first()
                .context("No completion choices returned")?;

            let content = choice.message.content.clone().unwrap_or_default();

            let usage = response.usage.map(|u| TokenUsage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            });

            Ok(LlmResponse {
                content,
                model: response.model,
                usage,
                finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
            })
        }

        async fn complete_stream(
            &self,
            request: &CompletionRequest,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>> {
            use futures_util::StreamExt;
            use tokio_stream::wrappers::ReceiverStream;

            let messages: Vec<ChatCompletionRequestMessage> = request
                .messages
                .iter()
                .map(|m| match m.role.as_str() {
                    "system" => ChatCompletionRequestMessage::System(
                        ChatCompletionRequestSystemMessage {
                            content: async_openai::types::ChatCompletionRequestSystemMessageContent::Text(m.content.clone()),
                            name: None,
                        }
                    ),
                    "assistant" => ChatCompletionRequestMessage::Assistant(
                        ChatCompletionRequestAssistantMessage {
                            content: Some(async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(m.content.clone())),
                            name: None,
                            tool_calls: None,
                            refusal: None,
                            audio: None,
                        }
                    ),
                    _ => ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(m.content.clone()),
                            name: None,
                        }
                    ),
                })
                .collect();

            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.model(&self.model).messages(messages);

            if let Some(temp) = request.temperature {
                builder.temperature(temp);
            }
            if let Some(max) = request.max_tokens {
                builder.max_completion_tokens(max);
            }

            let req = builder.build()?;
            let (tx, rx) = tokio::sync::mpsc::channel::<Result<LlmChunk>>(100);

            let client = self.client.clone();

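            // Same bridging pattern as the Ollama provider: forward streamed
            // deltas through a bounded channel so the returned stream owns its data.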
            tokio::spawn(async move {
                let mut stream = match client.chat().create_stream(req).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                        return;
                    }
                };

                while let Some(result) = stream.next().await {
                    match result {
                        Ok(response) => {
                            if let Some(choice) = response.choices.first() {
                                let delta = choice.delta.content.clone().unwrap_or_default();
                                let is_final = choice.finish_reason.is_some();

                                if tx.send(Ok(LlmChunk { delta, is_final })).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(anyhow::anyhow!("Chunk error: {}", e))).await;
                            break;
                        }
                    }
                }
            });

            Ok(Box::pin(ReceiverStream::new(rx)))
        }
    }
}

// =============================================================================
// Provider Factory
// =============================================================================

use std::sync::Arc;

/// Create an LLM provider from configuration
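///
/// Dispatch is feature-gated per provider. An illustrative call (assuming a
/// populated `AiIngestionConfig`):
///
/// ```ignore
/// let provider = create_llm_provider(&config)?;
/// let response = provider.complete(&CompletionRequest::new("Hello")).await?;
/// println!("[{}] {}", provider.model(), response.content);
/// ```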
pub fn create_llm_provider(config: &AiIngestionConfig) -> Result<Arc<dyn LlmProvider>> {
    match config.provider {
        #[cfg(feature = "ollama")]
        AiProvider::Ollama => {
            let provider = ollama::OllamaProvider::from_config(config)?;
            Ok(Arc::new(provider))
        }
        #[cfg(not(feature = "ollama"))]
        AiProvider::Ollama => {
            anyhow::bail!("Ollama support not enabled. Rebuild with --features ollama")
        }

        #[cfg(feature = "openai")]
        AiProvider::OpenAi => {
            let provider = openai::OpenAIProvider::from_config(config)?;
            Ok(Arc::new(provider))
        }
        #[cfg(not(feature = "openai"))]
        AiProvider::OpenAi => {
            anyhow::bail!("OpenAI support not enabled. Rebuild with --features openai")
        }

        AiProvider::Anthropic => {
            // Anthropic's native Messages API isn't wired up yet; until it is,
            // point callers at the working providers (Claude models remain
            // reachable via OpenAI-compatible gateways such as OpenRouter).
            anyhow::bail!(
                "Anthropic provider not yet implemented. Use 'ollama' or 'openai' instead. \
                You can use Claude models through OpenRouter with the 'openai' provider."
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_creation() {
        let system = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system.role, "system");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, "assistant");
    }

    #[test]
    fn test_completion_request() {
        let req = CompletionRequest::new("Test prompt")
            .temperature(0.7)
            .max_tokens(1000)
            .stop(vec!["###".to_string()]);

        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, "user");
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.max_tokens, Some(1000));
        assert!(req.stop.is_some());
    }

    #[test]
    fn test_completion_request_with_system() {
        let req = CompletionRequest::with_system(
            "You are a CLI expert",
            "How do I list files?"
        );

        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, "system");
        assert_eq!(req.messages[1].role, "user");
    }

    #[test]
    fn test_temperature_clamping() {
        let req = CompletionRequest::new("test").temperature(5.0);
        assert_eq!(req.temperature, Some(2.0));

        let req = CompletionRequest::new("test").temperature(-1.0);
        assert_eq!(req.temperature, Some(0.0));
    }
}