Skip to main content

swarm_engine_llm/
decider.rs

1//! LLM Decider - Action 選択のための LLM 抽象
2//!
3//! 軽量LLM(Qwen2.5-Coder 1.5B等)による高速な Action 選択
4//!
5//! # 概念
6//!
7//! - [`LlmDecider`]: LLM への問い合わせ抽象(非同期、バッチ対応)
8//!
9//! # 型の統一
10//!
11//! LLM層はCore層の型を直接使用:
12//! - `WorkerDecisionRequest` - リクエスト
13//! - `DecisionResponse` - レスポンス
14
15use std::future::Future;
16use std::pin::Pin;
17
18/// バッチ決定の戻り値型(clippy::type_complexity 対策)
19pub type BatchDecisionFuture<'a> =
20    Pin<Box<dyn Future<Output = Vec<Result<DecisionResponse, LlmError>>> + Send + 'a>>;
21
22// Core の型を再エクスポート
23pub use swarm_engine_core::agent::{
24    ActionCandidate, ActionParam, DecisionResponse, ResolvedContext, WorkerDecisionRequest,
25};
26pub use swarm_engine_core::types::LoraConfig;
27
28/// LLM エラー
29#[derive(Debug, Clone, thiserror::Error)]
30pub enum LlmError {
31    /// 一時的エラー(リトライ可能)
32    #[error("LLM error (transient): {0}")]
33    Transient(String),
34
35    /// 恒久的エラー(リトライ不可)
36    #[error("LLM error: {0}")]
37    Permanent(String),
38}
39
40impl LlmError {
41    pub fn transient(message: impl Into<String>) -> Self {
42        Self::Transient(message.into())
43    }
44
45    pub fn permanent(message: impl Into<String>) -> Self {
46        Self::Permanent(message.into())
47    }
48
49    pub fn is_transient(&self) -> bool {
50        matches!(self, Self::Transient(_))
51    }
52
53    pub fn message(&self) -> &str {
54        match self {
55            Self::Transient(msg) => msg,
56            Self::Permanent(msg) => msg,
57        }
58    }
59}
60
61impl From<swarm_engine_core::error::SwarmError> for LlmError {
62    fn from(err: swarm_engine_core::error::SwarmError) -> Self {
63        if err.is_transient() {
64            Self::Transient(err.message())
65        } else {
66            Self::Permanent(err.message())
67        }
68    }
69}
70
71impl From<LlmError> for swarm_engine_core::error::SwarmError {
72    fn from(err: LlmError) -> Self {
73        match err {
74            LlmError::Transient(message) => {
75                swarm_engine_core::error::SwarmError::LlmTransient { message }
76            }
77            LlmError::Permanent(message) => {
78                swarm_engine_core::error::SwarmError::LlmPermanent { message }
79            }
80        }
81    }
82}
83
84/// LLM Decider trait
85///
86/// Core の `WorkerDecisionRequest` を受け取り、`DecisionResponse` を返す。
87pub trait LlmDecider: Send + Sync {
88    /// 単一の決定
89    fn decide(
90        &self,
91        request: WorkerDecisionRequest,
92    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>>;
93
94    /// 生のプロンプトを送信し、生のレスポンスを取得
95    ///
96    /// DependencyGraph 生成など、Action 選択以外の用途に使用。
97    /// デフォルト実装はエラーを返す(未対応)。
98    ///
99    /// # Arguments
100    /// * `prompt` - 送信するプロンプト
101    /// * `lora` - LoRA 設定(None の場合はベースモデルのみ)
102    fn call_raw(
103        &self,
104        _prompt: &str,
105        _lora: Option<&LoraConfig>,
106    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
107        Box::pin(async { Err(LlmError::permanent("call_raw not implemented")) })
108    }
109
110    /// バッチ決定(100+ Agent 対応)
111    fn decide_batch(&self, requests: Vec<WorkerDecisionRequest>) -> BatchDecisionFuture<'_> {
112        // デフォルト実装: 順次処理
113        Box::pin(async move {
114            let mut results = Vec::with_capacity(requests.len());
115            for req in requests {
116                results.push(self.decide(req).await);
117            }
118            results
119        })
120    }
121
122    /// モデル名
123    fn model_name(&self) -> &str;
124
125    /// ヘルスチェック
126    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;
127
128    /// 最大同時実行数を取得(サーバーのスロット数等)
129    ///
130    /// デフォルトはNone(無制限)。
131    /// 実装側でサーバーに問い合わせてスロット数を返すことができる。
132    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
133        Box::pin(async { None })
134    }
135}
136
137/// Decider 設定
138#[derive(Debug, Clone)]
139pub struct LlmDeciderConfig {
140    /// モデル名
141    pub model: String,
142    /// エンドポイント
143    pub endpoint: String,
144    /// タイムアウト(ミリ秒)
145    pub timeout_ms: u64,
146    /// 最大バッチサイズ
147    pub max_batch_size: usize,
148    /// Temperature
149    pub temperature: f32,
150    /// カスタムシステムプロンプト(テンプレート変数: {query}, {candidates}, {world_state})
151    pub system_prompt: Option<String>,
152}
153
154impl Default for LlmDeciderConfig {
155    fn default() -> Self {
156        Self {
157            model: "qwen2.5-coder:1.5b".to_string(),
158            endpoint: "http://localhost:11434".to_string(),
159            timeout_ms: 5000,
160            max_batch_size: 100,
161            temperature: 0.1,
162            system_prompt: None,
163        }
164    }
165}
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170
171    #[test]
172    fn test_llm_error_transient() {
173        let err = LlmError::transient("connection timeout");
174        assert!(err.is_transient());
175        assert_eq!(err.message(), "connection timeout");
176        assert_eq!(
177            format!("{}", err),
178            "LLM error (transient): connection timeout"
179        );
180    }
181
182    #[test]
183    fn test_llm_error_permanent() {
184        let err = LlmError::permanent("invalid model");
185        assert!(!err.is_transient());
186        assert_eq!(err.message(), "invalid model");
187    }
188
189    #[test]
190    fn test_llm_decider_config_default() {
191        let config = LlmDeciderConfig::default();
192        assert_eq!(config.model, "qwen2.5-coder:1.5b");
193        assert_eq!(config.endpoint, "http://localhost:11434");
194        assert_eq!(config.timeout_ms, 5000);
195        assert_eq!(config.max_batch_size, 100);
196        assert!((config.temperature - 0.1).abs() < 0.001);
197    }
198}