use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;

use reqwest::Client;
use serde::{Deserialize, Serialize};

use swarm_engine_core::learn::lora::EndpointResolver;
use swarm_engine_core::types::LoraConfig;

use crate::debug_channel::{LlmDebugChannel, LlmDebugEvent};
use crate::decider::{DecisionResponse, LlmDecider, LlmError, WorkerDecisionRequest};
use crate::prompt_builder::PromptBuilder;
use crate::response_parser;

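/// Configuration for talking to a running `llama-server` (llama.cpp) instance:
/// the HTTP endpoint, a display name for the model, sampling parameters
/// (`max_tokens`, `temperature`, `top_p`), a request timeout, and an optional
/// chat template used to wrap prompts.
///
/// # Example
///
/// A minimal sketch; the `use` path below is a placeholder for wherever this
/// module is exposed in the crate.
///
/// ```ignore
/// // Placeholder path -- adjust to the actual module layout.
/// use crate::llamacpp_server::{ChatTemplate, LlamaCppServerConfig};
///
/// let config = LlamaCppServerConfig::new("http://localhost:8080")
///     .with_model_name("my-model")
///     .with_max_tokens(128)
///     .with_chat_template(ChatTemplate::Qwen);
/// ```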
#[derive(Debug, Clone)]
pub struct LlamaCppServerConfig {
    pub endpoint: String,
    pub model_name: String,
    pub max_tokens: usize,
    pub temperature: f32,
    pub top_p: f32,
    pub timeout_secs: u64,
    pub chat_template: Option<ChatTemplate>,
}

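/// Prompt formats for wrapping a raw prompt in model-specific chat markers.
/// Built-in variants cover LFM2, Qwen (ChatML-style), and Llama 3 templates;
/// `Custom` lets callers supply their own prefix/suffix strings.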
#[derive(Debug, Clone)]
pub enum ChatTemplate {
    Lfm2,
    Qwen,
    Llama3,
    Custom {
        user_prefix: String,
        user_suffix: String,
        assistant_prefix: String,
    },
}

impl ChatTemplate {
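    /// Wraps `prompt` in the template's user/assistant markers, returning the
    /// full string to send as the completion prompt (the assistant turn is
    /// left open so the model continues from it).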
    pub fn format(&self, prompt: &str) -> String {
        match self {
            ChatTemplate::Lfm2 => {
                format!("<|user|>\n{}\n<|assistant|>\n", prompt)
            }
            ChatTemplate::Qwen => {
                format!(
                    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
                    prompt
                )
            }
            ChatTemplate::Llama3 => {
                format!("<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", prompt)
            }
            ChatTemplate::Custom {
                user_prefix,
                user_suffix,
                assistant_prefix,
            } => {
                format!(
                    "{}{}{}{}",
                    user_prefix, prompt, user_suffix, assistant_prefix
                )
            }
        }
    }

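    /// Stop sequences that mark the end of the assistant turn for this
    /// template. `Custom` templates return an empty slice, so requests built
    /// from them send no stop tokens and rely on the server's own
    /// end-of-sequence handling.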
    pub fn stop_tokens(&self) -> &'static [&'static str] {
        match self {
            ChatTemplate::Lfm2 => &["<|user|>", "<|endoftext|>"],
            ChatTemplate::Qwen => &["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
            ChatTemplate::Llama3 => &["<|eot_id|>", "<|start_header_id|>"],
            ChatTemplate::Custom { .. } => &[],
        }
    }
}

impl Default for LlamaCppServerConfig {
    fn default() -> Self {
        Self {
            endpoint: "http://localhost:8080".to_string(),
            model_name: "llama-server".to_string(),
            max_tokens: 256,
            temperature: 0.7,
            top_p: 0.9,
            timeout_secs: 30,
            chat_template: Some(ChatTemplate::Lfm2),
        }
    }
}

impl LlamaCppServerConfig {
    pub fn new(endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            ..Default::default()
        }
    }

    pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
        self.model_name = name.into();
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = top_p;
        self
    }

    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs;
        self
    }

    pub fn with_chat_template(mut self, template: ChatTemplate) -> Self {
        self.chat_template = Some(template);
        self
    }

    pub fn without_chat_template(mut self) -> Self {
        self.chat_template = None;
        self
    }
}

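/// Per-request LoRA adapter selection serialized into the `lora` field of the
/// completion request: the adapter id (as loaded by the server) and its
/// scaling factor.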
#[derive(Debug, Serialize)]
struct LoraAdapterRequest {
    id: u32,
    scale: f32,
}

impl From<&LoraConfig> for LoraAdapterRequest {
    fn from(config: &LoraConfig) -> Self {
        Self {
            id: config.id,
            scale: config.scale,
        }
    }
}

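/// Request body for the server's `/completion` endpoint. Empty `stop` and
/// `lora` lists are omitted from the JSON entirely.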
#[derive(Debug, Serialize)]
struct CompletionRequest {
    prompt: String,
    n_predict: usize,
    temperature: f32,
    top_p: f32,
    stream: bool,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    lora: Vec<LoraAdapterRequest>,
}

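/// Subset of the `/completion` response that this decider actually reads.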
#[derive(Debug, Deserialize)]
struct CompletionResponse {
    content: String,
    #[serde(default)]
    _stopped_eos: bool,
}

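/// Response body of the server's `/health` endpoint.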
#[derive(Debug, Deserialize)]
struct HealthResponse {
    status: String,
}

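/// `LlmDecider` implementation backed by a llama.cpp `llama-server` HTTP
/// endpoint. Prompts are built with `PromptBuilder`, optionally wrapped in a
/// `ChatTemplate`, sent to `/completion`, and parsed with `response_parser`;
/// every call also emits an `LlmDebugEvent`.
///
/// # Example
///
/// A minimal sketch, assuming an async context and a reachable server; the
/// `use` paths are placeholders for wherever these items are exposed.
///
/// ```ignore
/// // Placeholder paths -- adjust to the actual module layout.
/// use crate::decider::LlmDecider;
/// use crate::llamacpp_server::{LlamaCppServerConfig, LlamaCppServerDecider};
///
/// let decider = LlamaCppServerDecider::new(LlamaCppServerConfig::default())?;
/// if decider.is_healthy().await {
///     let reply = decider.call_raw("Say hello.", None).await?;
///     println!("{reply}");
/// }
/// ```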
pub struct LlamaCppServerDecider {
    config: LlamaCppServerConfig,
    client: Arc<Client>,
    prompt_builder: PromptBuilder,
    endpoint_resolver: Option<Arc<dyn EndpointResolver>>,
}

impl Clone for LlamaCppServerDecider {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            client: Arc::clone(&self.client),
            prompt_builder: self.prompt_builder.clone(),
            endpoint_resolver: self.endpoint_resolver.clone(),
        }
    }
}

impl LlamaCppServerDecider {
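    /// Builds the decider and its underlying HTTP client, applying
    /// `timeout_secs` from the config. Fails with a permanent `LlmError` if
    /// the client cannot be constructed.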
    pub fn new(config: LlamaCppServerConfig) -> Result<Self, LlmError> {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(config.timeout_secs))
            .build()
            .map_err(|e| LlmError::permanent(format!("Failed to create HTTP client: {}", e)))?;

        Ok(Self {
            config,
            client: Arc::new(client),
            prompt_builder: PromptBuilder::new(),
            endpoint_resolver: None,
        })
    }

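    /// Installs a resolver that supplies the endpoint dynamically per request,
    /// overriding the static `endpoint` in the config.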
    pub fn with_endpoint_resolver(mut self, resolver: Arc<dyn EndpointResolver>) -> Self {
        self.endpoint_resolver = Some(resolver);
        self
    }

    /// Returns the endpoint to use for the next request, preferring the
    /// resolver (if configured) over the static endpoint in the config.
    fn current_endpoint(&self) -> String {
        if let Some(ref resolver) = self.endpoint_resolver {
            resolver.current_endpoint()
        } else {
            self.config.endpoint.clone()
        }
    }

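    /// Sends a single non-streaming completion request, applying the chat
    /// template (if any) and an optional LoRA adapter. Returns the generated
    /// text, the formatted prompt that was actually sent, and the request
    /// latency in milliseconds. Connection errors and timeouts are reported
    /// as transient `LlmError`s; everything else is permanent.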
    async fn call_server(
        &self,
        prompt: &str,
        lora: Option<&LoraConfig>,
    ) -> Result<(String, String, u64), LlmError> {
        let start = Instant::now();

        let (formatted_prompt, stop_tokens) = if let Some(ref template) = self.config.chat_template
        {
            let stop = template
                .stop_tokens()
                .iter()
                .map(|s| s.to_string())
                .collect();
            (template.format(prompt), stop)
        } else {
            (prompt.to_string(), vec![])
        };

        let lora_adapters: Vec<LoraAdapterRequest> = lora
            .map(|l| vec![LoraAdapterRequest::from(l)])
            .unwrap_or_default();

        let request = CompletionRequest {
            prompt: formatted_prompt.clone(),
            n_predict: self.config.max_tokens,
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            stream: false,
            stop: stop_tokens,
            lora: lora_adapters,
        };

        let endpoint = self.current_endpoint();
        let url = format!("{}/completion", endpoint);

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                if e.is_timeout() {
                    LlmError::transient(format!("Request timeout: {}", e))
                } else if e.is_connect() {
                    LlmError::transient(format!("Connection error: {}", e))
                } else {
                    LlmError::permanent(format!("HTTP error: {}", e))
                }
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::permanent(format!(
                "Server error {}: {}",
                status, body
            )));
        }

        let completion: CompletionResponse = response
            .json()
            .await
            .map_err(|e| LlmError::permanent(format!("Failed to parse response: {}", e)))?;

        let latency_ms = start.elapsed().as_millis() as u64;

        Ok((completion.content, formatted_prompt, latency_ms))
    }

    fn emit_debug_event(&self, event: LlmDebugEvent) {
        LlmDebugChannel::global().emit(event);
    }
}

impl LlmDecider for LlamaCppServerDecider {
    fn decide(
        &self,
        request: WorkerDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>> {
        let current_endpoint = self.current_endpoint();

        Box::pin(async move {
            let prompt = self.prompt_builder.build(&request.context);
            let worker_id = request.worker_id.0;
            let lora = request.lora.as_ref();

            let (raw_response, _formatted_prompt, latency_ms) =
                match self.call_server(&prompt, lora).await {
                    Ok(result) => result,
                    Err(e) => {
                        self.emit_debug_event(
                            LlmDebugEvent::new("decide", &self.config.model_name)
                                .worker_id(worker_id)
                                .endpoint(&current_endpoint)
                                .prompt(&prompt)
                                .lora_opt(request.lora.clone())
                                .error(e.message()),
                        );
                        return Err(e);
                    }
                };

            let candidate_names = response_parser::candidate_names(&request.context.candidates);

            match response_parser::parse_response(&raw_response, &candidate_names) {
                Ok(mut d) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("decide", &self.config.model_name)
                            .worker_id(worker_id)
                            .endpoint(&current_endpoint)
                            .prompt(&prompt)
                            .response(&raw_response)
                            .lora_opt(request.lora.clone())
                            .latency_ms(latency_ms),
                    );

                    d.prompt = Some(prompt);
                    d.raw_response = Some(raw_response);
                    Ok(d)
                }
                Err(e) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("decide", &self.config.model_name)
                            .worker_id(worker_id)
                            .endpoint(&current_endpoint)
                            .prompt(&prompt)
                            .response(&raw_response)
                            .lora_opt(request.lora.clone())
                            .error(e.message())
                            .latency_ms(latency_ms),
                    );

                    tracing::warn!(error = %e, "Parse error");
                    tracing::debug!(raw = %raw_response, "Raw response");
                    Err(e)
                }
            }
        })
    }

    fn call_raw(
        &self,
        prompt: &str,
        lora: Option<&LoraConfig>,
    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
        let prompt = prompt.to_string();
        let lora_owned = lora.cloned();
        let current_endpoint = self.current_endpoint();

        Box::pin(async move {
            match self.call_server(&prompt, lora_owned.as_ref()).await {
                Ok((response, _formatted_prompt, latency_ms)) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("call_raw", &self.config.model_name)
                            .endpoint(&current_endpoint)
                            .prompt(&prompt)
                            .response(&response)
                            .lora_opt(lora_owned.clone())
                            .latency_ms(latency_ms),
                    );
                    Ok(response)
                }
                Err(e) => {
                    self.emit_debug_event(
                        LlmDebugEvent::new("call_raw", &self.config.model_name)
                            .endpoint(&current_endpoint)
                            .prompt(&prompt)
                            .lora_opt(lora_owned)
                            .error(e.message()),
                    );
                    Err(e)
                }
            }
        })
    }

    fn model_name(&self) -> &str {
        &self.config.model_name
    }

    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let client = Arc::clone(&self.client);
        let endpoint = self.current_endpoint();

        Box::pin(async move {
            let url = format!("{}/health", endpoint);
            match client.get(&url).send().await {
                Ok(response) => {
                    if let Ok(health) = response.json::<HealthResponse>().await {
                        health.status == "ok"
                    } else {
                        false
                    }
                }
                Err(_) => false,
            }
        })
    }

    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
        let client = Arc::clone(&self.client);
        let endpoint = self.current_endpoint();

        Box::pin(async move {
            let url = format!("{}/slots", endpoint);
            match client.get(&url).send().await {
                Ok(response) => {
                    if let Ok(slots) = response.json::<Vec<serde_json::Value>>().await {
                        Some(slots.len())
                    } else {
                        None
                    }
                }
                Err(_) => None,
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = LlamaCppServerConfig::default();
        assert_eq!(config.endpoint, "http://localhost:8080");
        assert_eq!(config.max_tokens, 256);
        assert!(matches!(config.chat_template, Some(ChatTemplate::Lfm2)));
    }

    #[test]
    fn test_config_builder() {
        let config = LlamaCppServerConfig::new("http://192.168.1.100:9000")
            .with_model_name("my-model")
            .with_max_tokens(512)
            .with_temperature(0.5)
            .with_top_p(0.95)
            .with_timeout(60);

        assert_eq!(config.endpoint, "http://192.168.1.100:9000");
        assert_eq!(config.model_name, "my-model");
        assert_eq!(config.max_tokens, 512);
        assert!((config.temperature - 0.5).abs() < f32::EPSILON);
        assert!((config.top_p - 0.95).abs() < f32::EPSILON);
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_config_chat_template() {
        let config = LlamaCppServerConfig::default().with_chat_template(ChatTemplate::Qwen);
        assert!(matches!(config.chat_template, Some(ChatTemplate::Qwen)));

        let config = LlamaCppServerConfig::default().without_chat_template();
        assert!(config.chat_template.is_none());
    }

    #[test]
    fn test_chat_template_lfm2() {
        let template = ChatTemplate::Lfm2;
        let formatted = template.format("Hello");
        assert_eq!(formatted, "<|user|>\nHello\n<|assistant|>\n");
    }

    #[test]
    fn test_chat_template_qwen() {
        let template = ChatTemplate::Qwen;
        let formatted = template.format("Hello");
        assert!(formatted.contains("<|im_start|>user"));
        assert!(formatted.contains("<|im_end|>"));
        assert!(formatted.contains("<|im_start|>assistant"));
    }

    #[test]
    fn test_chat_template_llama3() {
        let template = ChatTemplate::Llama3;
        let formatted = template.format("Hello");
        assert!(formatted.contains("<|start_header_id|>user"));
        assert!(formatted.contains("<|eot_id|>"));
    }

    #[test]
    fn test_chat_template_custom() {
        let template = ChatTemplate::Custom {
            user_prefix: "[USER]".to_string(),
            user_suffix: "[/USER]".to_string(),
            assistant_prefix: "[ASSISTANT]".to_string(),
        };
        let formatted = template.format("Hello");
        assert_eq!(formatted, "[USER]Hello[/USER][ASSISTANT]");
    }

    #[test]
    fn test_chat_template_stop_tokens() {
        let lfm2 = ChatTemplate::Lfm2;
        let stop = lfm2.stop_tokens();
        assert!(stop.contains(&"<|user|>"));
        assert!(stop.contains(&"<|endoftext|>"));

        let qwen = ChatTemplate::Qwen;
        let stop = qwen.stop_tokens();
        assert!(stop.contains(&"<|im_end|>"));

        let custom = ChatTemplate::Custom {
            user_prefix: "[U]".to_string(),
            user_suffix: "[/U]".to_string(),
            assistant_prefix: "[A]".to_string(),
        };
        assert!(custom.stop_tokens().is_empty());
    }
}