// swarm_engine_llm/llama_cpp_server.rs

//! llama-server Decider - HTTP API integration
//!
//! Inference backend using llama-server (llama.cpp's HTTP server).
//! Starting the server ahead of time eliminates model-load latency.
//!
//! # Starting the server
//!
//! ```bash
//! # Start llama-server (the model is loaded only once)
//! llama-server -m model.gguf --host 0.0.0.0 --port 8080
//! ```
//!
//! # Features
//!
//! - **Fast startup**: the model is pre-loaded on the server side
//! - **HTTP API**: standard HTTP/JSON interface
//! - **LlmDecider compatible**: can be combined with LlmBatchProcessor
//!
//! # Example
//!
//! ```ignore
//! use swarm_engine_llm::llama_cpp_server::{LlamaCppServerDecider, LlamaCppServerConfig};
//! use swarm_engine_llm::LlmBatchProcessor;
//!
//! let config = LlamaCppServerConfig::default();
//! let decider = LlamaCppServerDecider::new(config)?;
//!
//! // Use as a BatchProcessor
//! let processor = LlmBatchProcessor::new(decider);
//! ```
//!
//! Response parsing is delegated to the `response_parser` module.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;

use reqwest::Client;
use serde::{Deserialize, Serialize};

use swarm_engine_core::learn::lora::EndpointResolver;
use swarm_engine_core::types::LoraConfig;

use crate::debug_channel::{LlmDebugChannel, LlmDebugEvent};
use crate::decider::{DecisionResponse, LlmDecider, LlmError, WorkerDecisionRequest};
use crate::prompt_builder::PromptBuilder;
use crate::response_parser;

50/// llama-server 設定
51#[derive(Debug, Clone)]
52pub struct LlamaCppServerConfig {
53    /// サーバーエンドポイント (e.g., "http://localhost:8080")
54    pub endpoint: String,
55    /// モデル名(表示用、サーバー側で設定済み)
56    pub model_name: String,
57    /// 最大生成トークン数
58    pub max_tokens: usize,
59    /// Temperature
60    pub temperature: f32,
61    /// Top-p
62    pub top_p: f32,
63    /// リクエストタイムアウト(秒)
64    pub timeout_secs: u64,
65    /// Chat template format (Some = 使用, None = 使用しない)
66    pub chat_template: Option<ChatTemplate>,
67}
68
/// Chat template format.
#[derive(Debug, Clone)]
pub enum ChatTemplate {
    /// LFM2.5 style: <|user|>\n{prompt}\n<|assistant|>\n
    Lfm2,
    /// Qwen style
    Qwen,
    /// Llama 3 style
    Llama3,
    /// Custom style
    Custom {
        user_prefix: String,
        user_suffix: String,
        assistant_prefix: String,
    },
}

impl ChatTemplate {
    /// Wrap the prompt in this chat template.
    pub fn format(&self, prompt: &str) -> String {
        match self {
            Self::Lfm2 => format!("<|user|>\n{prompt}\n<|assistant|>\n"),
            Self::Qwen => {
                format!("<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n")
            }
            Self::Llama3 => {
                format!("<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
            }
            Self::Custom {
                user_prefix,
                user_suffix,
                assistant_prefix,
            } => format!("{user_prefix}{prompt}{user_suffix}{assistant_prefix}"),
        }
    }

    /// Stop tokens associated with this template.
    ///
    /// Returns a static slice to avoid allocating on every call.
    pub fn stop_tokens(&self) -> &'static [&'static str] {
        match self {
            Self::Lfm2 => &["<|user|>", "<|endoftext|>"],
            Self::Qwen => &["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
            Self::Llama3 => &["<|eot_id|>", "<|start_header_id|>"],
            // Custom templates are expected to provide stop tokens themselves.
            Self::Custom { .. } => &[],
        }
    }
}

128impl Default for LlamaCppServerConfig {
129    fn default() -> Self {
130        Self {
131            endpoint: "http://localhost:8080".to_string(),
132            model_name: "llama-server".to_string(),
133            max_tokens: 256,
134            temperature: 0.7,
135            top_p: 0.9,
136            timeout_secs: 30,
137            chat_template: Some(ChatTemplate::Lfm2), // LFM2.5 がデフォルト
138        }
139    }
140}
141
142impl LlamaCppServerConfig {
143    /// 新しい設定を作成
144    pub fn new(endpoint: impl Into<String>) -> Self {
145        Self {
146            endpoint: endpoint.into(),
147            ..Default::default()
148        }
149    }
150
151    /// モデル名を設定
152    pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
153        self.model_name = name.into();
154        self
155    }
156
157    /// 最大トークン数を設定
158    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
159        self.max_tokens = max_tokens;
160        self
161    }
162
163    /// Temperature を設定
164    pub fn with_temperature(mut self, temperature: f32) -> Self {
165        self.temperature = temperature;
166        self
167    }
168
169    /// Top-p を設定
170    pub fn with_top_p(mut self, top_p: f32) -> Self {
171        self.top_p = top_p;
172        self
173    }
174
175    /// タイムアウトを設定
176    pub fn with_timeout(mut self, secs: u64) -> Self {
177        self.timeout_secs = secs;
178        self
179    }
180
181    /// Chat template を設定
182    pub fn with_chat_template(mut self, template: ChatTemplate) -> Self {
183        self.chat_template = Some(template);
184        self
185    }
186
187    /// Chat template を無効化
188    pub fn without_chat_template(mut self) -> Self {
189        self.chat_template = None;
190        self
191    }
192}
193
194/// llama-server LoRA adapter リクエスト
195///
196/// llama.cpp の per-request LoRA 指定用。
197/// `--lora-init-without-apply` で起動時にロードした LoRA を指定。
198#[derive(Debug, Serialize)]
199struct LoraAdapterRequest {
200    /// LoRA アダプター ID(ロード順)
201    id: u32,
202    /// 適用強度(0.0〜1.0)
203    scale: f32,
204}
205
206impl From<&LoraConfig> for LoraAdapterRequest {
207    fn from(config: &LoraConfig) -> Self {
208        Self {
209            id: config.id,
210            scale: config.scale,
211        }
212    }
213}
214
215/// llama-server completion API リクエスト
216#[derive(Debug, Serialize)]
217struct CompletionRequest {
218    prompt: String,
219    n_predict: usize,
220    temperature: f32,
221    top_p: f32,
222    stream: bool,
223    #[serde(skip_serializing_if = "Vec::is_empty")]
224    stop: Vec<String>,
225    /// LoRA アダプター設定(per-request LoRA)
226    ///
227    /// 空の場合は LoRA なし(ベースモデルのみ)。
228    /// `--lora-init-without-apply` で起動した場合に有効。
229    #[serde(skip_serializing_if = "Vec::is_empty")]
230    lora: Vec<LoraAdapterRequest>,
231}
232
233/// llama-server completion API レスポンス
234#[derive(Debug, Deserialize)]
235struct CompletionResponse {
236    content: String,
237    /// Whether generation stopped due to a stop token (unused but kept for debugging)
238    #[serde(default)]
239    _stopped_eos: bool,
240}
241
242/// llama-server health API レスポンス
243#[derive(Debug, Deserialize)]
244struct HealthResponse {
245    status: String,
246}
247
248/// llama-server Decider
249///
250/// LlmDecider trait を実装。LlmBatchProcessor と組み合わせて使用可能。
251///
252/// ## 動的エンドポイント解決
253///
254/// `with_endpoint_resolver()` で `EndpointResolver` を設定すると、
255/// リクエストごとにエンドポイントを動的に解決する。
256/// Blue-Green デプロイメントでダウンタイムなしの切り替えに使用。
257pub struct LlamaCppServerDecider {
258    config: LlamaCppServerConfig,
259    client: Arc<Client>,
260    prompt_builder: PromptBuilder,
261    /// 動的エンドポイント解決(Blue-Green 用)
262    endpoint_resolver: Option<Arc<dyn EndpointResolver>>,
263}
264
265impl Clone for LlamaCppServerDecider {
266    fn clone(&self) -> Self {
267        Self {
268            config: self.config.clone(),
269            client: Arc::clone(&self.client),
270            prompt_builder: self.prompt_builder.clone(),
271            endpoint_resolver: self.endpoint_resolver.clone(),
272        }
273    }
274}
275
276impl LlamaCppServerDecider {
277    /// 新しい LlamaCppServerDecider を作成
278    pub fn new(config: LlamaCppServerConfig) -> Result<Self, LlmError> {
279        let client = Client::builder()
280            .timeout(std::time::Duration::from_secs(config.timeout_secs))
281            .build()
282            .map_err(|e| LlmError::permanent(format!("Failed to create HTTP client: {}", e)))?;
283
284        Ok(Self {
285            config,
286            client: Arc::new(client),
287            prompt_builder: PromptBuilder::new(),
288            endpoint_resolver: None,
289        })
290    }
291
292    /// EndpointResolver を設定(Blue-Green デプロイメント用)
293    ///
294    /// 設定すると、リクエストごとに `resolver.current_endpoint()` からエンドポイントを取得。
295    /// `config.endpoint` より優先される。
296    pub fn with_endpoint_resolver(mut self, resolver: Arc<dyn EndpointResolver>) -> Self {
297        self.endpoint_resolver = Some(resolver);
298        self
299    }
300
301    /// 現在のエンドポイントを取得
302    fn current_endpoint(&self) -> String {
303        if let Some(ref resolver) = self.endpoint_resolver {
304            resolver.current_endpoint()
305        } else {
306            self.config.endpoint.clone()
307        }
308    }
309
310    /// llama-server API を呼び出し
311    ///
312    /// # Arguments
313    /// * `prompt` - 送信するプロンプト
314    /// * `lora` - LoRA 設定(None の場合はベースモデルのみ)
315    ///
316    /// # Returns
317    /// (response_content, formatted_prompt, latency_ms)
318    async fn call_server(
319        &self,
320        prompt: &str,
321        lora: Option<&LoraConfig>,
322    ) -> Result<(String, String, u64), LlmError> {
323        let start = Instant::now();
324
325        // Chat template でフォーマット
326        let (formatted_prompt, stop_tokens) = if let Some(ref template) = self.config.chat_template
327        {
328            let stop = template
329                .stop_tokens()
330                .iter()
331                .map(|s| s.to_string())
332                .collect();
333            (template.format(prompt), stop)
334        } else {
335            (prompt.to_string(), vec![])
336        };
337
338        // LoRA 設定を変換
339        let lora_adapters: Vec<LoraAdapterRequest> = lora
340            .map(|l| vec![LoraAdapterRequest::from(l)])
341            .unwrap_or_default();
342
343        let request = CompletionRequest {
344            prompt: formatted_prompt.clone(),
345            n_predict: self.config.max_tokens,
346            temperature: self.config.temperature,
347            top_p: self.config.top_p,
348            stream: false,
349            stop: stop_tokens,
350            lora: lora_adapters,
351        };
352
353        // 動的エンドポイント解決(Blue-Green 対応)
354        let endpoint = self.current_endpoint();
355        let url = format!("{}/completion", endpoint);
356
357        let response = self
358            .client
359            .post(&url)
360            .json(&request)
361            .send()
362            .await
363            .map_err(|e| {
364                if e.is_timeout() {
365                    LlmError::transient(format!("Request timeout: {}", e))
366                } else if e.is_connect() {
367                    LlmError::transient(format!("Connection error: {}", e))
368                } else {
369                    LlmError::permanent(format!("HTTP error: {}", e))
370                }
371            })?;
372
373        if !response.status().is_success() {
374            let status = response.status();
375            let body = response.text().await.unwrap_or_default();
376            return Err(LlmError::permanent(format!(
377                "Server error {}: {}",
378                status, body
379            )));
380        }
381
382        let completion: CompletionResponse = response
383            .json()
384            .await
385            .map_err(|e| LlmError::permanent(format!("Failed to parse response: {}", e)))?;
386
387        let latency_ms = start.elapsed().as_millis() as u64;
388
389        Ok((completion.content, formatted_prompt, latency_ms))
390    }
391
392    /// デバッグイベントを発行
393    fn emit_debug_event(&self, event: LlmDebugEvent) {
394        LlmDebugChannel::global().emit(event);
395    }
396}
397
398impl LlmDecider for LlamaCppServerDecider {
399    fn decide(
400        &self,
401        request: WorkerDecisionRequest,
402    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>> {
403        // 動的エンドポイント解決(Blue-Green 対応)
404        let current_endpoint = self.current_endpoint();
405
406        Box::pin(async move {
407            // PromptBuilder を使ってプロンプトを生成
408            let prompt = self.prompt_builder.build(&request.context);
409            let worker_id = request.worker_id.0;
410            let lora = request.lora.as_ref();
411
412            // LLM呼び出し(LoRA 設定を渡す)
413            let (raw_response, _formatted_prompt, latency_ms) =
414                match self.call_server(&prompt, lora).await {
415                    Ok(result) => result,
416                    Err(e) => {
417                        // エラー時のデバッグイベント
418                        self.emit_debug_event(
419                            LlmDebugEvent::new("decide", &self.config.model_name)
420                                .worker_id(worker_id)
421                                .endpoint(&current_endpoint)
422                                .prompt(&prompt)
423                                .lora_opt(request.lora.clone())
424                                .error(e.message()),
425                        );
426                        return Err(e);
427                    }
428                };
429
430            let candidate_names = response_parser::candidate_names(&request.context.candidates);
431
432            // Parse response
433            match response_parser::parse_response(&raw_response, &candidate_names) {
434                Ok(mut d) => {
435                    // 成功時のデバッグイベント
436                    self.emit_debug_event(
437                        LlmDebugEvent::new("decide", &self.config.model_name)
438                            .worker_id(worker_id)
439                            .endpoint(&current_endpoint)
440                            .prompt(&prompt)
441                            .response(&raw_response)
442                            .lora_opt(request.lora.clone())
443                            .latency_ms(latency_ms),
444                    );
445
446                    d.prompt = Some(prompt);
447                    d.raw_response = Some(raw_response);
448                    Ok(d)
449                }
450                Err(e) => {
451                    // パースエラー時のデバッグイベント
452                    self.emit_debug_event(
453                        LlmDebugEvent::new("decide", &self.config.model_name)
454                            .worker_id(worker_id)
455                            .endpoint(&current_endpoint)
456                            .prompt(&prompt)
457                            .response(&raw_response)
458                            .lora_opt(request.lora.clone())
459                            .error(e.message())
460                            .latency_ms(latency_ms),
461                    );
462
463                    tracing::warn!(error = %e, "Parse error");
464                    tracing::debug!(raw = %raw_response, "Raw response");
465                    Err(e)
466                }
467            }
468        })
469    }
470
471    fn call_raw(
472        &self,
473        prompt: &str,
474        lora: Option<&LoraConfig>,
475    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
476        let prompt = prompt.to_string();
477        let lora_owned = lora.cloned();
478        // 動的エンドポイント解決(Blue-Green 対応)
479        let current_endpoint = self.current_endpoint();
480
481        Box::pin(async move {
482            // LoRA 設定を渡す
483            match self.call_server(&prompt, lora_owned.as_ref()).await {
484                Ok((response, _formatted_prompt, latency_ms)) => {
485                    // 成功時のデバッグイベント
486                    self.emit_debug_event(
487                        LlmDebugEvent::new("call_raw", &self.config.model_name)
488                            .endpoint(&current_endpoint)
489                            .prompt(&prompt)
490                            .response(&response)
491                            .lora_opt(lora_owned.clone())
492                            .latency_ms(latency_ms),
493                    );
494                    Ok(response)
495                }
496                Err(e) => {
497                    // エラー時のデバッグイベント
498                    self.emit_debug_event(
499                        LlmDebugEvent::new("call_raw", &self.config.model_name)
500                            .endpoint(&current_endpoint)
501                            .prompt(&prompt)
502                            .lora_opt(lora_owned)
503                            .error(e.message()),
504                    );
505                    Err(e)
506                }
507            }
508        })
509    }
510
511    fn model_name(&self) -> &str {
512        &self.config.model_name
513    }
514
515    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
516        let client = Arc::clone(&self.client);
517        // 動的エンドポイント解決(Blue-Green 対応)
518        let endpoint = self.current_endpoint();
519
520        Box::pin(async move {
521            let url = format!("{}/health", endpoint);
522            match client.get(&url).send().await {
523                Ok(response) => {
524                    if let Ok(health) = response.json::<HealthResponse>().await {
525                        health.status == "ok"
526                    } else {
527                        false
528                    }
529                }
530                Err(_) => false,
531            }
532        })
533    }
534
535    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
536        let client = Arc::clone(&self.client);
537        // 動的エンドポイント解決(Blue-Green 対応)
538        let endpoint = self.current_endpoint();
539
540        Box::pin(async move {
541            let url = format!("{}/slots", endpoint);
542            match client.get(&url).send().await {
543                Ok(response) => {
544                    if let Ok(slots) = response.json::<Vec<serde_json::Value>>().await {
545                        Some(slots.len())
546                    } else {
547                        None
548                    }
549                }
550                Err(_) => None,
551            }
552        })
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // Config tests
    // =========================================================================

    #[test]
    fn test_config_default() {
        let config = LlamaCppServerConfig::default();
        assert_eq!(config.endpoint, "http://localhost:8080");
        assert_eq!(config.max_tokens, 256);
        assert!(matches!(config.chat_template, Some(ChatTemplate::Lfm2)));
    }

    #[test]
    fn test_config_builder() {
        let config = LlamaCppServerConfig::new("http://192.168.1.100:9000")
            .with_model_name("my-model")
            .with_max_tokens(512)
            .with_temperature(0.5)
            .with_top_p(0.95)
            .with_timeout(60);

        assert_eq!(config.endpoint, "http://192.168.1.100:9000");
        assert_eq!(config.model_name, "my-model");
        assert_eq!(config.max_tokens, 512);
        assert!((config.temperature - 0.5).abs() < f32::EPSILON);
        assert!((config.top_p - 0.95).abs() < f32::EPSILON);
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_config_chat_template() {
        // Setting a template.
        let with_template = LlamaCppServerConfig::default().with_chat_template(ChatTemplate::Qwen);
        assert!(matches!(with_template.chat_template, Some(ChatTemplate::Qwen)));

        // Disabling templating entirely.
        let without_template = LlamaCppServerConfig::default().without_chat_template();
        assert!(without_template.chat_template.is_none());
    }

    // =========================================================================
    // Chat template tests
    // =========================================================================

    #[test]
    fn test_chat_template_lfm2() {
        let formatted = ChatTemplate::Lfm2.format("Hello");
        assert_eq!(formatted, "<|user|>\nHello\n<|assistant|>\n");
    }

    #[test]
    fn test_chat_template_qwen() {
        let formatted = ChatTemplate::Qwen.format("Hello");
        for marker in ["<|im_start|>user", "<|im_end|>", "<|im_start|>assistant"] {
            assert!(formatted.contains(marker));
        }
    }

    #[test]
    fn test_chat_template_llama3() {
        let formatted = ChatTemplate::Llama3.format("Hello");
        assert!(formatted.contains("<|start_header_id|>user"));
        assert!(formatted.contains("<|eot_id|>"));
    }

    #[test]
    fn test_chat_template_custom() {
        let template = ChatTemplate::Custom {
            user_prefix: "[USER]".to_string(),
            user_suffix: "[/USER]".to_string(),
            assistant_prefix: "[ASSISTANT]".to_string(),
        };
        assert_eq!(template.format("Hello"), "[USER]Hello[/USER][ASSISTANT]");
    }

    #[test]
    fn test_chat_template_stop_tokens() {
        // LFM2 stop tokens.
        let stop = ChatTemplate::Lfm2.stop_tokens();
        assert!(stop.contains(&"<|user|>"));
        assert!(stop.contains(&"<|endoftext|>"));

        // Qwen stop tokens.
        assert!(ChatTemplate::Qwen.stop_tokens().contains(&"<|im_end|>"));

        // Custom templates have no default stop tokens.
        let custom = ChatTemplate::Custom {
            user_prefix: "[U]".to_string(),
            user_suffix: "[/U]".to_string(),
            assistant_prefix: "[A]".to_string(),
        };
        assert!(custom.stop_tokens().is_empty());
    }

    // Note: JSON parse/fuzzy repair tests are now in response_parser module
}