Skip to main content

swarm_engine_llm/
strategy_advisor.rs

1//! Strategy Advisor - LLM による探索戦略アドバイス
2//!
3//! 探索の状態に基づいて最適な Selection 戦略を推奨する。
4//!
5//! # 設計
6//!
7//! ```text
8//! StrategyContext (探索状態)
9//!         │
10//!         ▼
11//! StrategyPromptBuilder.build()
12//!         │
13//!         ▼
14//! LlmDecider.call_raw() ──────── 同期ブロッキング (~100ms)
15//!         │
16//!         ▼
17//! StrategyResponseParser.parse()
18//!         │
19//!         ▼
20//! StrategyAdvice (推奨戦略)
21//! ```
22//!
23//! # 使用例
24//!
25//! ```ignore
26//! use swarm_engine_llm::strategy_advisor::{LlmStrategyAdvisor, StrategyContext};
27//! use swarm_engine_core::exploration::SelectionKind;
28//!
29//! let advisor = LlmStrategyAdvisor::new(decider, runtime);
30//! let context = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
31//! let advice = advisor.advise(&context)?;
32//! ```
33
34use std::sync::Arc;
35
36use fuzzy_parser::distance::{find_closest, Algorithm};
37use fuzzy_parser::{repair_object_fields, sanitize_json, ObjectSchema};
38
39// Core 層から型をインポート
40pub use swarm_engine_core::exploration::{
41    SelectionKind, StrategyAdvice, StrategyAdviceError, StrategyAdvisor, StrategyContext,
42};
43
44use crate::decider::{LlmDecider, LlmError, LoraConfig};
45use crate::json_prompt::strategy_selection_template;
46
47// ============================================================================
48// SelectionKind 拡張 - 文字列からの fuzzy パース
49// ============================================================================
50
51/// SelectionKind の文字列からの fuzzy パース
52pub fn parse_selection_kind_fuzzy(s: &str) -> Option<SelectionKind> {
53    // 完全一致(大文字小文字無視)
54    let upper = s.to_uppercase();
55    match upper.as_str() {
56        "FIFO" => return Some(SelectionKind::Fifo),
57        "UCB1" => return Some(SelectionKind::Ucb1),
58        "GREEDY" => return Some(SelectionKind::Greedy),
59        "THOMPSON" => return Some(SelectionKind::Thompson),
60        _ => {}
61    }
62
63    // Fuzzy match
64    let candidates = ["FIFO", "UCB1", "Greedy", "Thompson"];
65    if let Some(m) = find_closest(s, candidates, 0.6, Algorithm::JaroWinkler) {
66        match m.candidate.as_str() {
67            "FIFO" => Some(SelectionKind::Fifo),
68            "UCB1" => Some(SelectionKind::Ucb1),
69            "Greedy" => Some(SelectionKind::Greedy),
70            "Thompson" => Some(SelectionKind::Thompson),
71            _ => None,
72        }
73    } else {
74        None
75    }
76}
77
78// ============================================================================
79// StrategyAdviceError 拡張
80// ============================================================================
81
82impl From<LlmError> for StrategyAdviceError {
83    fn from(e: LlmError) -> Self {
84        Self::LlmError(e.message().to_string())
85    }
86}
87
88// ============================================================================
89// StrategyPromptBuilder - プロンプト生成
90// ============================================================================
91
92/// 戦略アドバイス用プロンプトビルダー
93#[derive(Debug, Clone, Default)]
94pub struct StrategyPromptBuilder;
95
96impl StrategyPromptBuilder {
97    /// 新規作成
98    pub fn new() -> Self {
99        Self
100    }
101
102    /// プロンプトを生成
103    pub fn build(&self, ctx: &StrategyContext) -> String {
104        let depth_info = ctx
105            .avg_depth
106            .map(|d| format!(", depth={:.1}", d))
107            .unwrap_or_default();
108
109        // コンテキスト情報を構築
110        let content = format!(
111            "Strategies: FIFO, UCB1, Greedy, Thompson\n\
112             Guidelines: visits<20→UCB1, failure>30%→Thompson, established+low failure→Greedy\n\
113             User: frontier={}, visits={}, failure={:.0}%{}, current={}",
114            ctx.frontier_count,
115            ctx.total_visits,
116            ctx.failure_rate * 100.0,
117            depth_info,
118            ctx.current_strategy,
119        );
120
121        // 共通テンプレートを使用
122        strategy_selection_template().build(&content)
123    }
124}
125
126// ============================================================================
127// StrategyResponseParser - レスポンスパース
128// ============================================================================
129
130/// 戦略レスポンス用 ObjectSchema
131const STRATEGY_FIELDS: ObjectSchema =
132    ObjectSchema::new(&["strategy", "change", "confidence", "reason"]);
133
134/// 戦略アドバイス用レスポンスパーサー
135#[derive(Debug, Clone, Default)]
136pub struct StrategyResponseParser;
137
138impl StrategyResponseParser {
139    /// 新規作成
140    pub fn new() -> Self {
141        Self
142    }
143
144    /// レスポンスをパース
145    pub fn parse(&self, response: &str) -> Result<StrategyAdvice, StrategyAdviceError> {
146        // JSON を抽出
147        let json_str = self.extract_json(response)?;
148
149        // 構文修復
150        let sanitized = sanitize_json(&json_str);
151        tracing::debug!(sanitized = %sanitized, "Sanitized strategy JSON");
152
153        // パース
154        self.parse_json(&sanitized)
155    }
156
157    /// JSON を抽出(```json ブロック対応、自然言語フォールバック)
158    fn extract_json(&self, text: &str) -> Result<String, StrategyAdviceError> {
159        // ```json ... ``` ブロックを探す
160        if let Some(start) = text.find("```json") {
161            let content_start = start + 7;
162            let remaining = &text[content_start..];
163            if let Some(end) = remaining.find("```") {
164                let json = remaining[..end].trim();
165                if !json.is_empty() {
166                    return Ok(json.to_string());
167                }
168            }
169        }
170
171        // { ... } を探す(バランスを取って)
172        if let Some(json) = self.extract_balanced_json(text) {
173            return Ok(json);
174        }
175
176        // フォールバック: 自然言語から戦略キーワードを抽出して JSON 生成
177        if let Some(json) = self.extract_from_natural_language(text) {
178            tracing::debug!(fallback_json = %json, "Extracted strategy from natural language");
179            return Ok(json);
180        }
181
182        Err(StrategyAdviceError::ParseError(format!(
183            "No JSON found in response: {}",
184            text
185        )))
186    }
187
188    /// 自然言語から戦略キーワードを抽出して JSON 生成
189    fn extract_from_natural_language(&self, text: &str) -> Option<String> {
190        let text_upper = text.to_uppercase();
191
192        // 優先度順に戦略を検索(推奨を示す文脈を考慮)
193        let recommend_patterns = ["RECOMMEND", "SUGGEST", "USE ", "PREFER", "OPTIMAL", "BEST"];
194        let strategies = [
195            ("THOMPSON", "Thompson"),
196            ("UCB1", "UCB1"),
197            ("UCB", "UCB1"),
198            ("GREEDY", "Greedy"),
199            ("FIFO", "FIFO"),
200        ];
201
202        // まず推奨文脈付きの戦略を探す
203        for pattern in &recommend_patterns {
204            if let Some(pos) = text_upper.find(pattern) {
205                // パターン後の50文字以内で戦略を探す
206                let search_range = &text_upper[pos..std::cmp::min(pos + 50, text_upper.len())];
207                for (keyword, strategy) in &strategies {
208                    if search_range.contains(keyword) {
209                        return Some(format!(
210                            r#"{{"strategy":"{}","change":true,"confidence":0.6,"reason":"Extracted from natural language response"}}"#,
211                            strategy
212                        ));
213                    }
214                }
215            }
216        }
217
218        // 推奨文脈がなければ、最初に出現した戦略を使用
219        let mut first_match: Option<(usize, &str)> = None;
220        for (keyword, strategy) in &strategies {
221            if let Some(pos) = text_upper.find(keyword) {
222                if first_match.is_none() || pos < first_match.unwrap().0 {
223                    first_match = Some((pos, strategy));
224                }
225            }
226        }
227
228        first_match.map(|(_, strategy)| {
229            format!(
230                r#"{{"strategy":"{}","change":false,"confidence":0.5,"reason":"Extracted from natural language response"}}"#,
231                strategy
232            )
233        })
234    }
235
236    /// バランスの取れた JSON を抽出
237    fn extract_balanced_json(&self, text: &str) -> Option<String> {
238        let start = text.find('{')?;
239        let chars: Vec<char> = text[start..].chars().collect();
240        let mut depth = 0;
241        let mut in_string = false;
242        let mut escape_next = false;
243
244        for (i, &ch) in chars.iter().enumerate() {
245            if escape_next {
246                escape_next = false;
247                continue;
248            }
249
250            match ch {
251                '\\' if in_string => escape_next = true,
252                '"' => in_string = !in_string,
253                '{' if !in_string => depth += 1,
254                '}' if !in_string => {
255                    depth -= 1;
256                    if depth == 0 {
257                        return Some(chars[..=i].iter().collect());
258                    }
259                }
260                _ => {}
261            }
262        }
263
264        None
265    }
266
267    /// JSON をパース(fuzzy repair 対応)
268    fn parse_json(&self, json: &str) -> Result<StrategyAdvice, StrategyAdviceError> {
269        let mut parsed: serde_json::Value = serde_json::from_str(json)
270            .map_err(|e| StrategyAdviceError::ParseError(format!("JSON parse error: {}", e)))?;
271
272        // フィールド名の typo 修復
273        if let Some(obj) = parsed.as_object_mut() {
274            let corrections = repair_object_fields(obj, &STRATEGY_FIELDS, "$", &Default::default());
275            if !corrections.is_empty() {
276                tracing::debug!(
277                    corrections = ?corrections.iter().map(|c| format!("{} → {}", c.original, c.corrected)).collect::<Vec<_>>(),
278                    "Fuzzy repaired strategy field names"
279                );
280            }
281        }
282
283        // strategy フィールドをパース(fuzzy repair 対応)
284        let strategy_str = parsed["strategy"]
285            .as_str()
286            .ok_or_else(|| StrategyAdviceError::ParseError("Missing 'strategy' field".into()))?;
287
288        let recommended = parse_selection_kind_fuzzy(strategy_str).ok_or_else(|| {
289            StrategyAdviceError::ParseError(format!("Unknown strategy: {}", strategy_str))
290        })?;
291
292        let should_change = parsed["change"].as_bool().unwrap_or(false);
293        let confidence = parsed["confidence"].as_f64().unwrap_or(0.5).clamp(0.0, 1.0);
294        let reason = parsed["reason"]
295            .as_str()
296            .unwrap_or("No reason provided")
297            .to_string();
298
299        Ok(StrategyAdvice {
300            recommended,
301            should_change,
302            reason,
303            confidence,
304        })
305    }
306}
307
308// ============================================================================
309// LlmStrategyAdvisor - LLM ベースの実装
310// ============================================================================
311
312/// LLM ベースの戦略アドバイザー
313pub struct LlmStrategyAdvisor {
314    decider: Arc<dyn LlmDecider>,
315    runtime: tokio::runtime::Handle,
316    prompt_builder: StrategyPromptBuilder,
317    response_parser: StrategyResponseParser,
318    /// 信頼度閾値(これ以下のアドバイスは変更しない)
319    confidence_threshold: f64,
320    /// LoRA 設定(None の場合はベースモデルのみ)
321    lora: Option<LoraConfig>,
322}
323
324impl LlmStrategyAdvisor {
325    /// 新しい LlmStrategyAdvisor を作成
326    pub fn new(decider: Arc<dyn LlmDecider>, runtime: tokio::runtime::Handle) -> Self {
327        Self {
328            decider,
329            runtime,
330            prompt_builder: StrategyPromptBuilder::new(),
331            response_parser: StrategyResponseParser::new(),
332            confidence_threshold: 0.6,
333            lora: None,
334        }
335    }
336
337    /// 信頼度閾値を設定
338    pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
339        self.confidence_threshold = threshold.clamp(0.0, 1.0);
340        self
341    }
342
343    /// 信頼度閾値を取得
344    pub fn confidence_threshold(&self) -> f64 {
345        self.confidence_threshold
346    }
347
348    /// LoRA 設定を設定
349    pub fn with_lora(mut self, lora: LoraConfig) -> Self {
350        self.lora = Some(lora);
351        self
352    }
353
354    /// LoRA 設定を取得
355    pub fn lora(&self) -> Option<&LoraConfig> {
356        self.lora.as_ref()
357    }
358}
359
360impl StrategyAdvisor for LlmStrategyAdvisor {
361    fn advise(&self, context: &StrategyContext) -> Result<StrategyAdvice, StrategyAdviceError> {
362        // プロンプト生成
363        let prompt = self.prompt_builder.build(context);
364        tracing::debug!(prompt = %prompt, "Strategy advisor prompt");
365
366        // 同期ブロッキング呼び出し(~100ms 想定)
367        let lora = self.lora.as_ref();
368        let response = self
369            .runtime
370            .block_on(async { self.decider.call_raw(&prompt, lora).await })?;
371
372        tracing::debug!(response = %response, "Strategy advisor raw response");
373
374        // レスポンスパース
375        let mut advice = self.response_parser.parse(&response)?;
376
377        // 信頼度が閾値以下なら変更しない
378        if advice.confidence < self.confidence_threshold {
379            tracing::debug!(
380                confidence = advice.confidence,
381                threshold = self.confidence_threshold,
382                "Low confidence, not changing strategy"
383            );
384            advice.should_change = false;
385            advice.reason = format!(
386                "Low confidence ({:.2} < {:.2}): {}",
387                advice.confidence, self.confidence_threshold, advice.reason
388            );
389        }
390
391        tracing::info!(
392            recommended = %advice.recommended,
393            should_change = advice.should_change,
394            confidence = advice.confidence,
395            reason = %advice.reason,
396            "Strategy advice"
397        );
398
399        Ok(advice)
400    }
401
402    fn name(&self) -> &str {
403        "LlmStrategyAdvisor"
404    }
405}
406
407// ============================================================================
408// Tests
409// ============================================================================
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floating-point values differ by less than 0.001.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!(($actual - $expected).abs() < 0.001)
        };
    }

    /// Runs a fresh `StrategyResponseParser` over `response`.
    fn parse_response(response: &str) -> Result<StrategyAdvice, StrategyAdviceError> {
        StrategyResponseParser::new().parse(response)
    }

    // ------------------------------------------------------------------
    // SelectionKind
    // ------------------------------------------------------------------

    #[test]
    fn test_selection_kind_display() {
        let cases = [
            (SelectionKind::Fifo, "FIFO"),
            (SelectionKind::Ucb1, "UCB1"),
            (SelectionKind::Greedy, "Greedy"),
            (SelectionKind::Thompson, "Thompson"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn test_selection_kind_from_str_exact() {
        let cases = [
            ("FIFO", SelectionKind::Fifo),
            ("UCB1", SelectionKind::Ucb1),
            ("Greedy", SelectionKind::Greedy),
            ("Thompson", SelectionKind::Thompson),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_selection_kind_fuzzy(input), Some(expected));
        }
    }

    #[test]
    fn test_selection_kind_from_str_case_insensitive() {
        let cases = [
            ("fifo", SelectionKind::Fifo),
            ("ucb1", SelectionKind::Ucb1),
            ("GREEDY", SelectionKind::Greedy),
            ("THOMPSON", SelectionKind::Thompson),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_selection_kind_fuzzy(input), Some(expected));
        }
    }

    #[test]
    fn test_selection_kind_from_str_fuzzy() {
        // Misspelled names are repaired to the closest known strategy.
        let cases = [
            ("Thomspon", SelectionKind::Thompson),
            ("Gredy", SelectionKind::Greedy),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_selection_kind_fuzzy(input), Some(expected));
        }
    }

    #[test]
    fn test_selection_kind_from_str_invalid() {
        for input in ["Unknown", "Random"] {
            assert_eq!(parse_selection_kind_fuzzy(input), None);
        }
    }

    // ------------------------------------------------------------------
    // StrategyContext
    // ------------------------------------------------------------------

    #[test]
    fn test_strategy_context_new() {
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);

        assert_eq!(ctx.frontier_count, 15);
        assert_eq!(ctx.total_visits, 47);
        assert_close!(ctx.failure_rate, 0.23);
        assert_close!(ctx.success_rate, 0.77);
        assert_eq!(ctx.current_strategy, SelectionKind::Ucb1);
        assert!(ctx.avg_depth.is_none());
    }

    #[test]
    fn test_strategy_context_with_depth() {
        let base = StrategyContext::new(10, 100, 0.1, SelectionKind::Greedy);
        let ctx = base.with_avg_depth(3.5);

        assert_eq!(ctx.avg_depth, Some(3.5));
    }

    // ------------------------------------------------------------------
    // StrategyAdvice
    // ------------------------------------------------------------------

    #[test]
    fn test_strategy_advice_no_change() {
        let advice = StrategyAdvice::no_change(SelectionKind::Ucb1, "Exploration phase");

        assert_eq!(advice.recommended, SelectionKind::Ucb1);
        assert!(!advice.should_change);
        assert_eq!(advice.reason, "Exploration phase");
        assert_close!(advice.confidence, 1.0);
    }

    #[test]
    fn test_strategy_advice_change_to() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "Patterns established", 0.85);

        assert_eq!(advice.recommended, SelectionKind::Greedy);
        assert!(advice.should_change);
        assert_eq!(advice.reason, "Patterns established");
        assert_close!(advice.confidence, 0.85);
    }

    // ------------------------------------------------------------------
    // StrategyPromptBuilder
    // ------------------------------------------------------------------

    #[test]
    fn test_prompt_builder_basic() {
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
        let prompt = StrategyPromptBuilder::new().build(&ctx);

        let expected_fragments = [
            // Few-shot example format
            "Example interaction:",
            "Your JSON:",
            // Context values
            "frontier=15",
            "visits=47",
            "failure=23%",
            "current=UCB1",
            // Strategy names
            "FIFO",
            "Greedy",
            "Thompson",
        ];
        for fragment in expected_fragments {
            assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
        }
    }

    #[test]
    fn test_prompt_builder_with_depth() {
        let ctx = StrategyContext::new(10, 100, 0.1, SelectionKind::Greedy).with_avg_depth(3.5);
        let prompt = StrategyPromptBuilder::new().build(&ctx);

        assert!(prompt.contains("depth=3.5"));
    }

    // ------------------------------------------------------------------
    // StrategyResponseParser
    // ------------------------------------------------------------------

    #[test]
    fn test_parse_valid_json() {
        let advice = parse_response(
            r#"{"strategy": "Greedy", "change": true, "confidence": 0.85, "reason": "Low failure rate"}"#,
        )
        .unwrap();

        assert_eq!(advice.recommended, SelectionKind::Greedy);
        assert!(advice.should_change);
        assert_close!(advice.confidence, 0.85);
        assert_eq!(advice.reason, "Low failure rate");
    }

    #[test]
    fn test_parse_json_with_prefix() {
        let advice = parse_response(
            r#"Based on the analysis: {"strategy": "Thompson", "change": true, "confidence": 0.7, "reason": "High variance"}"#,
        )
        .unwrap();

        assert_eq!(advice.recommended, SelectionKind::Thompson);
    }

    #[test]
    fn test_parse_json_markdown_block() {
        // A fenced ```json block: opening fence, JSON body, closing fence.
        let response = [
            "```json",
            r#"{"strategy": "UCB1", "change": false, "confidence": 0.9, "reason": "Still exploring"}"#,
            "```",
        ]
        .join("\n");
        let advice = parse_response(&response).unwrap();

        assert_eq!(advice.recommended, SelectionKind::Ucb1);
        assert!(!advice.should_change);
    }

    #[test]
    fn test_parse_json_typo_repair() {
        // Misspelled keys ("straegy", "confidnce") are repaired before reading.
        let advice = parse_response(
            r#"{"straegy": "Greedy", "change": true, "confidnce": 0.8, "reason": "test"}"#,
        )
        .unwrap();

        assert_eq!(advice.recommended, SelectionKind::Greedy);
    }

    #[test]
    fn test_parse_json_strategy_typo() {
        // A misspelled strategy value ("Thomspon") is repaired to "Thompson".
        let advice = parse_response(
            r#"{"strategy": "Thomspon", "change": true, "confidence": 0.75, "reason": "variance"}"#,
        )
        .unwrap();

        assert_eq!(advice.recommended, SelectionKind::Thompson);
    }

    #[test]
    fn test_parse_json_defaults() {
        // `change` and `confidence` are absent, so the defaults apply.
        let advice = parse_response(r#"{"strategy": "FIFO", "reason": "simple"}"#).unwrap();

        assert_eq!(advice.recommended, SelectionKind::Fifo);
        assert!(!advice.should_change); // default: false
        assert_close!(advice.confidence, 0.5); // default: 0.5
    }

    #[test]
    fn test_parse_json_missing_strategy() {
        let result = parse_response(r#"{"change": true, "confidence": 0.8}"#);

        assert!(matches!(result, Err(StrategyAdviceError::ParseError(_))));
    }

    #[test]
    fn test_parse_no_json() {
        let result = parse_response("This is just plain text without any JSON.");

        assert!(matches!(result, Err(StrategyAdviceError::ParseError(_))));
    }

    #[test]
    fn test_parse_confidence_clamping() {
        // Values above 1.0 clamp to 1.0; values below 0.0 clamp to 0.0.
        let too_high = parse_response(
            r#"{"strategy": "Greedy", "change": true, "confidence": 1.5, "reason": "test"}"#,
        )
        .unwrap();
        assert_close!(too_high.confidence, 1.0);

        let too_low = parse_response(
            r#"{"strategy": "Greedy", "change": true, "confidence": -0.5, "reason": "test"}"#,
        )
        .unwrap();
        assert_close!(too_low.confidence, 0.0);
    }
}