// skill_runtime/generation/llm_provider.rs

//! Provider-agnostic LLM abstraction: a common request/response model, an
//! `LlmProvider` trait, and feature-gated Ollama and OpenAI backends.

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures_util::Stream;
use std::pin::Pin;
use std::sync::Arc;

use crate::search_config::{AiIngestionConfig, AiProvider};

/// A complete (non-streaming) response from an LLM provider.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// The generated text.
    pub content: String,
    /// The model that produced the response.
    pub model: String,
    /// Token accounting, if the backend reports it.
    pub usage: Option<TokenUsage>,
    /// Why generation stopped (e.g. "stop", "length"), if reported.
    pub finish_reason: Option<String>,
}

/// Token counts reported by a provider for a single request.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One increment of a streamed response.
#[derive(Debug, Clone)]
pub struct LlmChunk {
    /// The text delta carried by this chunk.
    pub delta: String,
    /// True for the last chunk of the stream.
    pub is_final: bool,
}

/// A chat message with an OpenAI-style role ("system", "user", or "assistant").
#[derive(Debug, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}
impl ChatMessage {
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
        }
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
        }
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: "assistant".to_string(),
            content: content.into(),
        }
    }
}

/// A provider-agnostic completion request; see the builder methods below.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    pub messages: Vec<ChatMessage>,
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub stop: Option<Vec<String>>,
}

impl CompletionRequest {
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            messages: vec![ChatMessage::user(prompt)],
            temperature: None,
            max_tokens: None,
            stop: None,
        }
    }

    pub fn with_system(system: impl Into<String>, user: impl Into<String>) -> Self {
        Self {
            messages: vec![ChatMessage::system(system), ChatMessage::user(user)],
            temperature: None,
            max_tokens: None,
            stop: None,
        }
    }

    /// Sets the sampling temperature, clamped to the common 0.0..=2.0 range.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp.clamp(0.0, 2.0));
        self
    }

    pub fn max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    pub fn stop(mut self, sequences: Vec<String>) -> Self {
        self.stop = Some(sequences);
        self
    }
}

/// A unified interface over chat-completion backends. Implementations must be
/// `Send + Sync` so a provider can be shared behind an `Arc` across tasks.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Short backend identifier, e.g. "ollama" or "openai".
    fn name(&self) -> &str;

    /// The model this provider instance targets.
    fn model(&self) -> &str;

    /// Runs the request to completion and returns the full response.
    async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse>;

    /// Runs the request and yields the response incrementally as chunks.
    async fn complete_stream(
        &self,
        request: &CompletionRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>>;
}
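
// Illustrative helper (a sketch, not part of the original API): `LlmProvider`
// is object-safe, so callers can drive any backend through `&dyn LlmProvider`.
// The name `run_prompt` and the temperature value are hypothetical.
pub async fn run_prompt(provider: &dyn LlmProvider, prompt: &str) -> Result<String> {
    let request = CompletionRequest::new(prompt).temperature(0.2);
    let response = provider.complete(&request).await?;
    Ok(response.content)
}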

/// Ollama-backed provider, enabled with the `ollama` feature.
#[cfg(feature = "ollama")]
pub mod ollama {
    use super::*;
    use ollama_rs::generation::chat::request::ChatMessageRequest;
    use ollama_rs::generation::chat::{ChatMessage as OllamaMessage, MessageRole};
    use ollama_rs::generation::options::GenerationOptions;
    use ollama_rs::Ollama;

    pub struct OllamaProvider {
        client: Ollama,
        model: String,
    }

    impl OllamaProvider {
        pub fn new(host: &str, model: &str) -> Result<Self> {
            let url = url::Url::parse(host)
                .with_context(|| format!("Invalid Ollama host URL: {}", host))?;

            let host_str = url.host_str().unwrap_or("localhost");
            let port = url.port().unwrap_or(11434);

            // Preserve the configured scheme rather than assuming http.
            let client = Ollama::new(format!("{}://{}", url.scheme(), host_str), port);

            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        pub fn from_config(config: &AiIngestionConfig) -> Result<Self> {
            let model = config.get_model().to_string();
            Self::new(&config.ollama.host, &model)
        }
    }

    #[async_trait]
    impl LlmProvider for OllamaProvider {
        fn name(&self) -> &str {
            "ollama"
        }

        fn model(&self) -> &str {
            &self.model
        }

        async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse> {
            let messages: Vec<OllamaMessage> = request
                .messages
                .iter()
                .map(|m| {
                    let role = match m.role.as_str() {
                        "system" => MessageRole::System,
                        "user" => MessageRole::User,
                        "assistant" => MessageRole::Assistant,
                        // Unknown roles fall back to user.
                        _ => MessageRole::User,
                    };
                    OllamaMessage::new(role, m.content.clone())
                })
                .collect();

            let mut chat_request = ChatMessageRequest::new(self.model.clone(), messages);

            // Note: only temperature is mapped; `max_tokens` and `stop` are not
            // forwarded to Ollama here.
            if let Some(temp) = request.temperature {
                let options = GenerationOptions::default().temperature(temp as f64);
                chat_request = chat_request.options(options);
            }

            let response = self
                .client
                .send_chat_messages(chat_request)
                .await
                .context("Ollama chat request failed")?;

            let content = response.message.map(|m| m.content).unwrap_or_default();

            Ok(LlmResponse {
                content,
                model: self.model.clone(),
                // ollama-rs does not surface token usage or a finish reason here.
                usage: None,
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn complete_stream(
            &self,
            request: &CompletionRequest,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>> {
            use futures_util::StreamExt;
            use tokio_stream::wrappers::ReceiverStream;

            let messages: Vec<OllamaMessage> = request
                .messages
                .iter()
                .map(|m| {
                    let role = match m.role.as_str() {
                        "system" => MessageRole::System,
                        "user" => MessageRole::User,
                        "assistant" => MessageRole::Assistant,
                        // Unknown roles fall back to user.
                        _ => MessageRole::User,
                    };
                    OllamaMessage::new(role, m.content.clone())
                })
                .collect();

            let mut chat_request = ChatMessageRequest::new(self.model.clone(), messages);

            if let Some(temp) = request.temperature {
                let options = GenerationOptions::default().temperature(temp as f64);
                chat_request = chat_request.options(options);
            }

            // Bridge the provider stream through a channel so the returned
            // stream is 'static and does not borrow `self` or the request.
            let (tx, rx) = tokio::sync::mpsc::channel::<Result<LlmChunk>>(100);

            let client = self.client.clone();

            tokio::spawn(async move {
                let mut stream = match client.send_chat_messages_stream(chat_request).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                        return;
                    }
                };

                while let Some(chunk_result) = stream.next().await {
                    match chunk_result {
                        Ok(chunk) => {
                            let content = chunk.message.map(|m| m.content).unwrap_or_default();
                            let is_final = chunk.done;

                            // A closed receiver means the consumer is gone; stop early.
                            if tx
                                .send(Ok(LlmChunk {
                                    delta: content,
                                    is_final,
                                }))
                                .await
                                .is_err()
                            {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(anyhow::anyhow!("Chunk error: {}", e))).await;
                            break;
                        }
                    }
                }
            });

            Ok(Box::pin(ReceiverStream::new(rx)))
        }
    }
}

/// OpenAI-backed provider, enabled with the `openai` feature.
#[cfg(feature = "openai")]
pub mod openai {
    use super::*;
    use async_openai::{
        types::{
            ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
            ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
            ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
            ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
        },
        Client,
    };

    pub struct OpenAIProvider {
        client: Client<async_openai::config::OpenAIConfig>,
        model: String,
    }

    impl OpenAIProvider {
        /// Uses the default client, which reads `OPENAI_API_KEY` from the environment.
        pub fn new(model: &str) -> Result<Self> {
            let client = Client::new();
            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        pub fn with_api_key(api_key: &str, model: &str) -> Result<Self> {
            let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
            let client = Client::with_config(config);
            Ok(Self {
                client,
                model: model.to_string(),
            })
        }

        pub fn from_config(config: &AiIngestionConfig) -> Result<Self> {
            let model = config.get_model().to_string();

            // Prefer an explicitly configured key variable; otherwise fall back
            // to the default environment-based client.
            if let Some(ref env_var) = config.openai.api_key_env {
                if let Ok(key) = std::env::var(env_var) {
                    return Self::with_api_key(&key, &model);
                }
            }

            Self::new(&model)
        }
    }

    #[async_trait]
    impl LlmProvider for OpenAIProvider {
        fn name(&self) -> &str {
            "openai"
        }

        fn model(&self) -> &str {
            &self.model
        }

        async fn complete(&self, request: &CompletionRequest) -> Result<LlmResponse> {
            let messages: Vec<ChatCompletionRequestMessage> = request
                .messages
                .iter()
                .map(|m| match m.role.as_str() {
                    "system" => ChatCompletionRequestMessage::System(
                        ChatCompletionRequestSystemMessage {
                            content: ChatCompletionRequestSystemMessageContent::Text(
                                m.content.clone(),
                            ),
                            name: None,
                        },
                    ),
                    "assistant" => ChatCompletionRequestMessage::Assistant(
                        ChatCompletionRequestAssistantMessage {
                            content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                                m.content.clone(),
                            )),
                            name: None,
                            tool_calls: None,
                            refusal: None,
                            audio: None,
                        },
                    ),
                    // Unknown roles fall back to user.
                    _ => ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text(m.content.clone()),
                        name: None,
                    }),
                })
                .collect();

            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.model(&self.model).messages(messages);

            if let Some(temp) = request.temperature {
                builder.temperature(temp);
            }
            if let Some(max) = request.max_tokens {
                builder.max_completion_tokens(max);
            }
            if let Some(ref stop) = request.stop {
                builder.stop(stop.clone());
            }

            let req = builder.build()?;
            let response = self.client.chat().create(req).await?;

            let choice = response
                .choices
                .first()
                .context("No completion choices returned")?;

            let content = choice.message.content.clone().unwrap_or_default();

            let usage = response.usage.map(|u| TokenUsage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            });

            Ok(LlmResponse {
                content,
                model: response.model,
                usage,
                finish_reason: choice.finish_reason.as_ref().map(|r| format!("{:?}", r)),
            })
        }

        async fn complete_stream(
            &self,
            request: &CompletionRequest,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>> {
            use futures_util::StreamExt;
            use tokio_stream::wrappers::ReceiverStream;

            let messages: Vec<ChatCompletionRequestMessage> = request
                .messages
                .iter()
                .map(|m| match m.role.as_str() {
                    "system" => ChatCompletionRequestMessage::System(
                        ChatCompletionRequestSystemMessage {
                            content: ChatCompletionRequestSystemMessageContent::Text(
                                m.content.clone(),
                            ),
                            name: None,
                        },
                    ),
                    "assistant" => ChatCompletionRequestMessage::Assistant(
                        ChatCompletionRequestAssistantMessage {
                            content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                                m.content.clone(),
                            )),
                            name: None,
                            tool_calls: None,
                            refusal: None,
                            audio: None,
                        },
                    ),
                    // Unknown roles fall back to user.
                    _ => ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text(m.content.clone()),
                        name: None,
                    }),
                })
                .collect();

            let mut builder = CreateChatCompletionRequestArgs::default();
            builder.model(&self.model).messages(messages);

            if let Some(temp) = request.temperature {
                builder.temperature(temp);
            }
            if let Some(max) = request.max_tokens {
                builder.max_completion_tokens(max);
            }
            // Forward stop sequences here too, matching the non-streaming path.
            if let Some(ref stop) = request.stop {
                builder.stop(stop.clone());
            }

            let req = builder.build()?;

            // Bridge the provider stream through a channel so the returned
            // stream is 'static and does not borrow `self` or the request.
            let (tx, rx) = tokio::sync::mpsc::channel::<Result<LlmChunk>>(100);

            let client = self.client.clone();

            tokio::spawn(async move {
                let mut stream = match client.chat().create_stream(req).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
                        return;
                    }
                };

                while let Some(result) = stream.next().await {
                    match result {
                        Ok(response) => {
                            if let Some(choice) = response.choices.first() {
                                let delta = choice.delta.content.clone().unwrap_or_default();
                                let is_final = choice.finish_reason.is_some();

                                if tx.send(Ok(LlmChunk { delta, is_final })).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(anyhow::anyhow!("Chunk error: {}", e))).await;
                            break;
                        }
                    }
                }
            });

            Ok(Box::pin(ReceiverStream::new(rx)))
        }
    }
}

/// Builds the provider selected by `config`. Backends are compiled in behind
/// Cargo features, so an unselected feature yields a descriptive error.
pub fn create_llm_provider(config: &AiIngestionConfig) -> Result<Arc<dyn LlmProvider>> {
    match config.provider {
        #[cfg(feature = "ollama")]
        AiProvider::Ollama => {
            let provider = ollama::OllamaProvider::from_config(config)?;
            Ok(Arc::new(provider))
        }
        #[cfg(not(feature = "ollama"))]
        AiProvider::Ollama => {
            anyhow::bail!("Ollama support not enabled. Rebuild with --features ollama")
        }

        #[cfg(feature = "openai")]
        AiProvider::OpenAi => {
            let provider = openai::OpenAIProvider::from_config(config)?;
            Ok(Arc::new(provider))
        }
        #[cfg(not(feature = "openai"))]
        AiProvider::OpenAi => {
            anyhow::bail!("OpenAI support not enabled. Rebuild with --features openai")
        }

        AiProvider::Anthropic => {
            anyhow::bail!(
                "Anthropic provider not yet implemented. Use 'ollama' or 'openai' instead. \
                 You can use Claude models through OpenRouter with the 'openai' provider."
            )
        }
    }
}
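
// Illustrative usage (a sketch; `config` stands in for an `AiIngestionConfig`
// loaded elsewhere). Pairs the factory with streaming consumption:
//
//     use futures_util::StreamExt;
//
//     let provider = create_llm_provider(&config)?;
//     let request = CompletionRequest::with_system("Be terse.", "Summarize the log.");
//     let mut stream = provider.complete_stream(&request).await?;
//     while let Some(chunk) = stream.next().await {
//         print!("{}", chunk?.delta);
//     }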

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_creation() {
        let system = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system.role, "system");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, "assistant");
    }

    #[test]
    fn test_completion_request() {
        let req = CompletionRequest::new("Test prompt")
            .temperature(0.7)
            .max_tokens(1000)
            .stop(vec!["###".to_string()]);

        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, "user");
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.max_tokens, Some(1000));
        assert!(req.stop.is_some());
    }

    #[test]
    fn test_completion_request_with_system() {
        let req = CompletionRequest::with_system("You are a CLI expert", "How do I list files?");

        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, "system");
        assert_eq!(req.messages[1].role, "user");
    }

    #[test]
    fn test_temperature_clamping() {
        let req = CompletionRequest::new("test").temperature(5.0);
        assert_eq!(req.temperature, Some(2.0));

        let req = CompletionRequest::new("test").temperature(-1.0);
        assert_eq!(req.temperature, Some(0.0));
    }
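
    // Added test (not in the original file): a bare request carries no
    // sampling overrides until the builder methods set them.
    #[test]
    fn test_new_request_defaults() {
        let req = CompletionRequest::new("test");
        assert_eq!(req.temperature, None);
        assert_eq!(req.max_tokens, None);
        assert!(req.stop.is_none());
    }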
}