1use std::sync::{Arc, OnceLock};
12
13use super::{
14 CaptureProfile, ConfidenceRetryStrategy, ExtractionSchema, HtmlDiffMode, ModelEndpoint,
15 ModelPolicy, PlanningModeConfig, PromptUrlGate, ReasoningEffort, RemoteMultimodalConfig,
16 RetryPolicy, SelfHealingConfig, SynthesisConfig, ToolCallingMode, VisionRouteMode,
17};
18
/// Default value of `RemoteMultimodalConfigs::chrome_ai_max_user_chars`.
const DEFAULT_CHROME_AI_MAX_USER_CHARS: usize = 6_000;

/// Serde / `Default` helper for `chrome_ai_max_user_chars`.
///
/// Returns [`DEFAULT_CHROME_AI_MAX_USER_CHARS`] (6000).
fn default_chrome_ai_max_user_chars() -> usize {
    DEFAULT_CHROME_AI_MAX_USER_CHARS
}
23
/// Top-level configuration for remote multimodal model calls.
///
/// Wraps the per-request tuning knobs in [`RemoteMultimodalConfig`] (`cfg`) together
/// with endpoint identity, optional dual-model routing, a model pool, Chrome AI
/// toggles, and shared runtime state (semaphore, counters, caches) that is skipped
/// by serde.
///
/// Because of the container-level `#[serde(default)]`, any field missing from the
/// input is taken from this type's `Default` impl during deserialization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RemoteMultimodalConfigs {
    /// Primary API endpoint URL; also the fallback URL for routed endpoints that
    /// do not set their own (see `resolve_model_for_round`).
    pub api_url: String,
    /// Primary API key; also the fallback key for routed endpoints without one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Primary model name; used for vision-capability detection and as the
    /// fallback model when no routed endpoint applies.
    pub model_name: String,
    /// Optional system prompt. NOTE(review): how it is combined with any built-in
    /// prompt is decided outside this file — confirm at the call site.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Optional extra text for the system prompt (consumed outside this file).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt_extra: Option<String>,
    /// Optional extra text for the user message (consumed outside this file).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_message_extra: Option<String>,
    /// Per-request model/runtime settings; most `with_*` helpers write into it.
    pub cfg: RemoteMultimodalConfig,
    /// Optional URL gate for prompts (semantics defined by `PromptUrlGate`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_url_gate: Option<PromptUrlGate>,
    /// Maximum concurrent requests; `None` or `Some(0)` means no semaphore is
    /// created by `get_or_init_semaphore`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub concurrency_limit: Option<usize>,
    /// Dedicated endpoint for vision rounds; enables dual-model routing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vision_model: Option<ModelEndpoint>,
    /// Dedicated endpoint for text-only rounds; enables dual-model routing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text_model: Option<ModelEndpoint>,
    /// How rounds are split between vision and text models when routing is on.
    #[serde(default)]
    pub vision_route_mode: VisionRouteMode,
    /// Additional model endpoints; omitted from JSON when empty. NOTE(review): the
    /// selection logic over this pool lives outside this file.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub model_pool: Vec<ModelEndpoint>,
    /// Explicitly request the Chrome built-in AI path (see `should_use_chrome_ai`).
    #[serde(default)]
    pub use_chrome_ai: bool,
    /// Character budget associated with Chrome AI user messages (default 6000).
    /// NOTE(review): presumably a truncation limit — confirm at the consumer.
    #[serde(default = "default_chrome_ai_max_user_chars")]
    pub chrome_ai_max_user_chars: usize,
    /// Skill registry; never serialized. `Default` populates it with the built-in
    /// web-challenge skills.
    #[cfg(feature = "skills")]
    #[serde(skip)]
    pub skill_registry: Option<super::skills::SkillRegistry>,
    /// Optional S3 location to load additional skills from.
    #[cfg(feature = "skills_s3")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub s3_skill_source: Option<super::skills::S3SkillSource>,
    /// Lazily created request semaphore (see `get_or_init_semaphore`); not serialized.
    #[serde(skip, default = "RemoteMultimodalConfigs::default_semaphore")]
    pub semaphore: OnceLock<Arc<tokio::sync::Semaphore>>,
    /// Shared atomic counter, shared between clones via `Arc`; not serialized.
    /// NOTE(review): its accounting rules live outside this file.
    #[serde(skip)]
    pub relevance_credits: Arc<std::sync::atomic::AtomicU32>,
    /// Shared URL → bool cache, shared between clones via `Arc`; not serialized.
    /// NOTE(review): presumably memoizes URL-prefilter decisions — confirm.
    #[serde(skip)]
    pub url_prefilter_cache: Arc<dashmap::DashMap<String, bool>>,
    /// Optional list of proxy addresses (format defined by the consumer).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub proxies: Option<Vec<String>>,
}
151
152impl PartialEq for RemoteMultimodalConfigs {
153 fn eq(&self, other: &Self) -> bool {
154 self.api_url == other.api_url
155 && self.api_key == other.api_key
156 && self.model_name == other.model_name
157 && self.system_prompt == other.system_prompt
158 && self.system_prompt_extra == other.system_prompt_extra
159 && self.user_message_extra == other.user_message_extra
160 && self.cfg == other.cfg
161 && self.prompt_url_gate == other.prompt_url_gate
162 && self.concurrency_limit == other.concurrency_limit
163 && self.vision_model == other.vision_model
164 && self.text_model == other.text_model
165 && self.vision_route_mode == other.vision_route_mode
166 && self.model_pool == other.model_pool
167 && self.use_chrome_ai == other.use_chrome_ai
168 && self.chrome_ai_max_user_chars == other.chrome_ai_max_user_chars
169 && self.proxies == other.proxies
170 }
172}
173
174impl Eq for RemoteMultimodalConfigs {}
175
176impl Default for RemoteMultimodalConfigs {
177 fn default() -> Self {
178 Self {
179 api_url: String::new(),
180 api_key: None,
181 model_name: String::new(),
182 system_prompt: None,
183 system_prompt_extra: None,
184 user_message_extra: None,
185 cfg: RemoteMultimodalConfig::default(),
186 prompt_url_gate: None,
187 concurrency_limit: None,
188 vision_model: None,
189 text_model: None,
190 vision_route_mode: VisionRouteMode::default(),
191 model_pool: Vec::new(),
192 use_chrome_ai: false,
193 chrome_ai_max_user_chars: default_chrome_ai_max_user_chars(),
194 #[cfg(feature = "skills")]
195 skill_registry: Some(super::skills::builtin_web_challenges()),
196 #[cfg(feature = "skills_s3")]
197 s3_skill_source: None,
198 semaphore: Self::default_semaphore(),
199 relevance_credits: Arc::new(std::sync::atomic::AtomicU32::new(0)),
200 url_prefilter_cache: Arc::new(dashmap::DashMap::new()),
201 proxies: None,
202 }
203 }
204}
205
206impl RemoteMultimodalConfigs {
207 pub fn new(api_url: impl Into<String>, model_name: impl Into<String>) -> Self {
225 Self {
226 api_url: api_url.into(),
227 model_name: model_name.into(),
228 ..Default::default()
229 }
230 }
231
232 fn default_semaphore() -> OnceLock<Arc<tokio::sync::Semaphore>> {
234 OnceLock::new()
235 }
236
237 pub fn get_or_init_semaphore(&self) -> Option<Arc<tokio::sync::Semaphore>> {
240 let n = self.concurrency_limit?;
241 if n == 0 {
242 return None;
243 }
244 Some(
245 self.semaphore
246 .get_or_init(|| Arc::new(tokio::sync::Semaphore::new(n)))
247 .clone(),
248 )
249 }
250
251 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
256 self.api_key = Some(key.into());
257 self
258 }
259
260 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
265 self.system_prompt = Some(prompt.into());
266 self
267 }
268
269 pub fn with_system_prompt_extra(mut self, extra: impl Into<String>) -> Self {
274 self.system_prompt_extra = Some(extra.into());
275 self
276 }
277
278 pub fn with_user_message_extra(mut self, extra: impl Into<String>) -> Self {
282 self.user_message_extra = Some(extra.into());
283 self
284 }
285
286 pub fn with_cfg(mut self, cfg: RemoteMultimodalConfig) -> Self {
288 self.cfg = cfg;
289 self
290 }
291
292 pub fn with_prompt_url_gate(mut self, gate: PromptUrlGate) -> Self {
294 self.prompt_url_gate = Some(gate);
295 self
296 }
297
298 pub fn with_concurrency_limit(mut self, limit: usize) -> Self {
300 self.concurrency_limit = Some(limit);
301 self
302 }
303
304 pub fn with_proxies(mut self, proxies: Option<Vec<String>>) -> Self {
320 self.proxies = proxies;
321 self
322 }
323
324 pub fn with_extra_ai_data(mut self, enabled: bool) -> Self {
326 self.cfg.extra_ai_data = enabled;
327 self
328 }
329
330 pub fn with_extraction_prompt(mut self, prompt: impl Into<String>) -> Self {
332 self.cfg.extraction_prompt = Some(prompt.into());
333 self
334 }
335
336 pub fn with_screenshot(mut self, enabled: bool) -> Self {
338 self.cfg.screenshot = enabled;
339 self
340 }
341
342 pub fn with_extraction_schema(mut self, schema: ExtractionSchema) -> Self {
344 self.cfg.extraction_schema = Some(schema);
345 self
346 }
347
348 pub fn model_supports_vision(&self) -> bool {
352 super::supports_vision(&self.model_name)
353 }
354
355 pub fn should_include_screenshot(&self) -> bool {
362 match self.cfg.include_screenshot {
363 Some(explicit) => explicit,
364 None => self.model_supports_vision(),
365 }
366 }
367
368 pub fn filter_screenshot<'a>(&self, screenshot: Option<&'a str>) -> Option<&'a str> {
373 if self.should_include_screenshot() {
374 screenshot
375 } else {
376 None
377 }
378 }
379
380 pub fn with_vision_model(mut self, endpoint: ModelEndpoint) -> Self {
384 self.vision_model = Some(endpoint);
385 self
386 }
387
388 pub fn with_text_model(mut self, endpoint: ModelEndpoint) -> Self {
390 self.text_model = Some(endpoint);
391 self
392 }
393
394 pub fn with_vision_route_mode(mut self, mode: VisionRouteMode) -> Self {
396 self.vision_route_mode = mode;
397 self
398 }
399
400 pub fn with_dual_models(mut self, vision: ModelEndpoint, text: ModelEndpoint) -> Self {
402 self.vision_model = Some(vision);
403 self.text_model = Some(text);
404 self
405 }
406
407 pub fn with_model_pool(mut self, pool: Vec<ModelEndpoint>) -> Self {
413 self.model_pool = pool;
414 self
415 }
416
417 #[cfg(feature = "skills_s3")]
421 pub fn with_s3_skill_source(mut self, source: super::skills::S3SkillSource) -> Self {
422 self.s3_skill_source = Some(source);
423 self
424 }
425
426 pub fn with_relevance_gate(mut self, prompt: Option<String>) -> Self {
428 self.cfg.relevance_gate = true;
429 self.cfg.relevance_prompt = prompt;
430 self
431 }
432
433 pub fn with_url_prefilter(mut self, batch_size: Option<usize>) -> Self {
436 self.cfg.url_prefilter = true;
437 if let Some(bs) = batch_size {
438 self.cfg.url_prefilter_batch_size = bs;
439 }
440 self
441 }
442
443 pub fn with_chrome_ai(mut self, enabled: bool) -> Self {
452 self.use_chrome_ai = enabled;
453 self
454 }
455
456 pub fn with_chrome_ai_max_user_chars(mut self, chars: usize) -> Self {
458 self.chrome_ai_max_user_chars = chars;
459 self
460 }
461
462 pub fn should_use_chrome_ai(&self) -> bool {
467 self.use_chrome_ai || (self.api_url.is_empty() && self.api_key.is_none())
468 }
469
470 pub fn with_automation_timeout_ms(mut self, ms: u64) -> Self {
475 self.cfg.automation_timeout_ms = Some(ms);
476 self
477 }
478
479 pub fn with_api_url(mut self, url: impl Into<String>) -> Self {
483 self.api_url = url.into();
484 self
485 }
486
487 pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
489 self.model_name = name.into();
490 self
491 }
492
493 #[cfg(feature = "skills")]
495 pub fn with_skill_registry(mut self, registry: super::skills::SkillRegistry) -> Self {
496 self.skill_registry = Some(registry);
497 self
498 }
499
500 pub fn with_include_html(mut self, include: bool) -> Self {
508 self.cfg.include_html = include;
509 self
510 }
511
512 pub fn with_html_max_bytes(mut self, bytes: usize) -> Self {
514 self.cfg.html_max_bytes = bytes;
515 self
516 }
517
518 pub fn with_include_url(mut self, include: bool) -> Self {
520 self.cfg.include_url = include;
521 self
522 }
523
524 pub fn with_include_title(mut self, include: bool) -> Self {
526 self.cfg.include_title = include;
527 self
528 }
529
530 pub fn with_include_screenshot(mut self, include: Option<bool>) -> Self {
536 self.cfg.include_screenshot = include;
537 self
538 }
539
540 pub fn with_temperature(mut self, temp: f32) -> Self {
542 self.cfg.temperature = temp;
543 self
544 }
545
546 pub fn with_max_tokens(mut self, tokens: u16) -> Self {
548 self.cfg.max_tokens = tokens;
549 self
550 }
551
552 pub fn with_request_json_object(mut self, enabled: bool) -> Self {
554 self.cfg.request_json_object = enabled;
555 self
556 }
557
558 pub fn with_best_effort_json_extract(mut self, enabled: bool) -> Self {
560 self.cfg.best_effort_json_extract = enabled;
561 self
562 }
563
564 pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
566 self.cfg.reasoning_effort = Some(effort);
567 self
568 }
569
570 pub fn with_thinking_budget(mut self, budget: u32) -> Self {
572 self.cfg.thinking_budget = Some(budget);
573 self
574 }
575
576 pub fn with_max_rounds(mut self, rounds: usize) -> Self {
578 self.cfg.max_rounds = rounds;
579 self
580 }
581
582 pub fn with_retry(mut self, retry: RetryPolicy) -> Self {
584 self.cfg.retry = retry;
585 self
586 }
587
588 pub fn with_capture_profile(mut self, profile: CaptureProfile) -> Self {
590 self.cfg.capture_profiles.push(profile);
591 self
592 }
593
594 pub fn with_model_policy(mut self, policy: ModelPolicy) -> Self {
596 self.cfg.model_policy = policy;
597 self
598 }
599
600 pub fn with_post_plan_wait_ms(mut self, ms: u64) -> Self {
602 self.cfg.post_plan_wait_ms = ms;
603 self
604 }
605
606 pub fn with_max_inflight_requests(mut self, max: usize) -> Self {
608 self.cfg.max_inflight_requests = Some(max);
609 self
610 }
611
612 pub fn with_tool_calling_mode(mut self, mode: ToolCallingMode) -> Self {
614 self.cfg.tool_calling_mode = mode;
615 self
616 }
617
618 pub fn with_html_diff_mode(mut self, mode: HtmlDiffMode) -> Self {
620 self.cfg.html_diff_mode = mode;
621 self
622 }
623
624 pub fn with_planning_mode(mut self, config: PlanningModeConfig) -> Self {
626 self.cfg.planning_mode = Some(config);
627 self
628 }
629
630 pub fn with_synthesis_config(mut self, config: SynthesisConfig) -> Self {
632 self.cfg.synthesis_config = Some(config);
633 self
634 }
635
636 pub fn with_confidence_strategy(mut self, strategy: ConfidenceRetryStrategy) -> Self {
638 self.cfg.confidence_strategy = Some(strategy);
639 self
640 }
641
642 pub fn with_self_healing(mut self, config: SelfHealingConfig) -> Self {
644 self.cfg.self_healing = Some(config);
645 self
646 }
647
648 pub fn with_concurrent_execution(mut self, enabled: bool) -> Self {
650 self.cfg.concurrent_execution = enabled;
651 self
652 }
653
654 pub fn with_max_skills_per_round(mut self, max: usize) -> Self {
656 self.cfg.max_skills_per_round = max;
657 self
658 }
659
660 pub fn with_max_skill_context_chars(mut self, max: usize) -> Self {
662 self.cfg.max_skill_context_chars = max;
663 self
664 }
665
666 pub fn automation_timeout(&self) -> Option<std::time::Duration> {
668 self.cfg
669 .automation_timeout_ms
670 .map(std::time::Duration::from_millis)
671 }
672
673 pub fn has_dual_model_routing(&self) -> bool {
676 self.vision_model.is_some() || self.text_model.is_some()
677 }
678
679 pub fn resolve_model_for_round(&self, use_vision: bool) -> (&str, &str, Option<&str>) {
687 let endpoint = if use_vision {
688 self.vision_model.as_ref()
689 } else {
690 self.text_model.as_ref()
691 };
692
693 match endpoint {
694 Some(ep) => {
695 let url = ep.api_url.as_deref().unwrap_or(&self.api_url);
696 let key = ep.api_key.as_deref().or(self.api_key.as_deref());
697 (url, &ep.model_name, key)
698 }
699 None => (&self.api_url, &self.model_name, self.api_key.as_deref()),
700 }
701 }
702
703 pub fn should_use_vision_this_round(
708 &self,
709 round_idx: usize,
710 stagnated: bool,
711 action_stuck_rounds: usize,
712 force_vision: bool,
713 ) -> bool {
714 if !self.has_dual_model_routing() {
715 return true; }
717 if force_vision {
718 return true;
719 }
720 match self.vision_route_mode {
721 VisionRouteMode::AlwaysPrimary => true,
722 VisionRouteMode::TextFirst => round_idx == 0 || stagnated || action_stuck_rounds >= 3,
723 VisionRouteMode::VisionFirst => round_idx < 2 || stagnated || action_stuck_rounds >= 3,
724 VisionRouteMode::AgentDriven => false,
725 }
726 }
727}
728
/// Unit tests for `RemoteMultimodalConfigs`: builders, screenshot gating,
/// dual-model routing and resolution, serde round trips, built-in skills
/// (feature `skills`), and integration with the sibling `router` module.
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` sets only URL and model; optional fields stay unset.
    #[test]
    fn test_remote_multimodal_configs_new() {
        let configs = RemoteMultimodalConfigs::new(
            "http://localhost:11434/v1/chat/completions",
            "qwen2.5-vl",
        );

        assert_eq!(
            configs.api_url,
            "http://localhost:11434/v1/chat/completions"
        );
        assert_eq!(configs.model_name, "qwen2.5-vl");
        assert!(configs.api_key.is_none());
        assert!(configs.system_prompt.is_none());
    }

    /// Chained builders write through to top-level fields and to `cfg`.
    #[test]
    fn test_remote_multimodal_configs_builder() {
        let configs =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o")
                .with_api_key("sk-test")
                .with_system_prompt("You are a helpful assistant.")
                .with_concurrency_limit(5)
                .with_screenshot(true);

        assert_eq!(configs.api_key, Some("sk-test".to_string()));
        assert_eq!(
            configs.system_prompt,
            Some("You are a helpful assistant.".to_string())
        );
        assert_eq!(configs.concurrency_limit, Some(5));
        assert!(configs.cfg.screenshot);
    }

    /// Screenshot inclusion follows model detection unless explicitly overridden.
    #[test]
    fn test_remote_multimodal_configs_vision_detection() {
        // Vision-capable model: screenshots included by default.
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert!(cfg.model_supports_vision());
        assert!(cfg.should_include_screenshot());

        // Text-only model: screenshots excluded by default.
        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        assert!(!cfg.model_supports_vision());
        assert!(!cfg.should_include_screenshot());

        // Explicit `Some(true)` overrides detection for a text-only model.
        let mut cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        cfg.cfg.include_screenshot = Some(true);
        assert!(cfg.should_include_screenshot());

        // Explicit `Some(false)` overrides detection for a vision model.
        let mut cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        cfg.cfg.include_screenshot = Some(false);
        assert!(!cfg.should_include_screenshot());
    }

    /// `filter_screenshot` passes data through only when screenshots are included.
    #[test]
    fn test_filter_screenshot() {
        let screenshot = "base64data...";

        // Vision model: screenshot passes through.
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert_eq!(cfg.filter_screenshot(Some(screenshot)), Some(screenshot));

        // Text-only model: screenshot dropped.
        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        assert_eq!(cfg.filter_screenshot(Some(screenshot)), None);

        // Nothing to filter.
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert_eq!(cfg.filter_screenshot(None), None);
    }

    /// Either routed endpoint alone enables dual-model routing.
    #[test]
    fn test_has_dual_model_routing() {
        // No dedicated endpoints: single-model mode.
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        assert!(!cfg.has_dual_model_routing());

        // Vision endpoint only.
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_vision_model(ModelEndpoint::new("gpt-4o"));
        assert!(cfg.has_dual_model_routing());

        // Text endpoint only.
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_text_model(ModelEndpoint::new("gpt-4o-mini"));
        assert!(cfg.has_dual_model_routing());

        // Both endpoints.
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            );
        assert!(cfg.has_dual_model_routing());
    }

    /// Without routed endpoints both round kinds resolve to the primary settings.
    #[test]
    fn test_resolve_model_for_round_no_routing() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-parent");

        // Vision round.
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(url, "https://api.example.com");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-parent"));

        // Text round.
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(url, "https://api.example.com");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-parent"));
    }

    /// Routed endpoints without their own URL/key inherit the primary ones.
    #[test]
    fn test_resolve_model_for_round_dual() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-parent")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            );

        // Vision round uses the vision endpoint's model.
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(key, Some("sk-parent"));

        // Text round uses the text endpoint's model.
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(key, Some("sk-parent"));
    }

    /// Routed endpoints can point at a different provider with their own key.
    #[test]
    fn test_resolve_model_cross_provider() {
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o")
                .with_api_key("sk-openai")
                .with_vision_model(ModelEndpoint::new("gpt-4o"))
                .with_text_model(
                    ModelEndpoint::new("llama-3.3-70b-versatile")
                        .with_api_url("https://api.groq.com/openai/v1/chat/completions")
                        .with_api_key("gsk-groq"),
                );

        // Vision round: inherits the primary OpenAI URL and key.
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(url, "https://api.openai.com/v1/chat/completions");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-openai"));

        // Text round: uses the endpoint's own Groq URL and key.
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(url, "https://api.groq.com/openai/v1/chat/completions");
        assert_eq!(model, "llama-3.3-70b-versatile");
        assert_eq!(key, Some("gsk-groq"));
    }

    /// `AlwaysPrimary` picks vision for every round.
    #[test]
    fn test_vision_route_mode_always_primary() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::AlwaysPrimary);

        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 0, false));
        assert!(cfg.should_use_vision_this_round(10, false, 0, false));
    }

    /// `TextFirst`: vision on round 0, then text unless escalated.
    #[test]
    fn test_vision_route_mode_text_first() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);

        // Round 0 uses vision.
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        // Later rounds use text...
        assert!(!cfg.should_use_vision_this_round(1, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(5, false, 0, false));
        // ...unless stagnated,
        assert!(cfg.should_use_vision_this_round(3, true, 0, false));
        // ...stuck for 3 rounds,
        assert!(cfg.should_use_vision_this_round(5, false, 3, false));
        // ...or vision is forced.
        assert!(cfg.should_use_vision_this_round(5, false, 0, true));
    }

    /// `VisionFirst`: vision on rounds 0 and 1, then text unless escalated.
    #[test]
    fn test_vision_route_mode_vision_first() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::VisionFirst);

        // First two rounds use vision.
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(1, false, 0, false));
        // Then text...
        assert!(!cfg.should_use_vision_this_round(2, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(5, false, 0, false));
        // ...unless stagnated
        assert!(cfg.should_use_vision_this_round(5, true, 0, false));
        // ...or stuck for 3 rounds.
        assert!(cfg.should_use_vision_this_round(5, false, 3, false));
    }

    /// Single-model configs always report a vision round.
    #[test]
    fn test_no_dual_routing_always_returns_true() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        assert!(!cfg.has_dual_model_routing());
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 0, false));
        assert!(cfg.should_use_vision_this_round(99, false, 0, false));
    }

    /// `with_dual_models` stores both endpoints; route mode is kept.
    #[test]
    fn test_with_dual_models_builder() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "primary")
            .with_dual_models(
                ModelEndpoint::new("vision-model"),
                ModelEndpoint::new("text-model"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);

        assert!(cfg.has_dual_model_routing());
        assert_eq!(
            cfg.vision_model.as_ref().unwrap().model_name,
            "vision-model"
        );
        assert_eq!(cfg.text_model.as_ref().unwrap().model_name, "text-model");
        assert_eq!(cfg.vision_route_mode, VisionRouteMode::TextFirst);
    }

    /// Dual-model routing survives a JSON round trip.
    #[test]
    fn test_configs_serde_with_dual_models() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-test")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini")
                    .with_api_url("https://other.api.com")
                    .with_api_key("sk-other"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);

        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.model_name, "gpt-4o");
        assert!(deserialized.has_dual_model_routing());
        assert_eq!(
            deserialized.vision_model.as_ref().unwrap().model_name,
            "gpt-4o"
        );
        assert_eq!(
            deserialized.text_model.as_ref().unwrap().model_name,
            "gpt-4o-mini"
        );
        assert_eq!(
            deserialized.text_model.as_ref().unwrap().api_url.as_deref(),
            Some("https://other.api.com")
        );
        assert_eq!(deserialized.vision_route_mode, VisionRouteMode::TextFirst);
    }

    /// `Default` loads the built-in web-challenge skills.
    #[cfg(feature = "skills")]
    #[test]
    fn test_default_configs_auto_load_builtin_skills() {
        let cfg = RemoteMultimodalConfigs::default();
        let registry = cfg
            .skill_registry
            .as_ref()
            .expect("default config should auto-load built-in skills");
        assert!(
            registry.get("image-grid-selection").is_some(),
            "expected image-grid-selection built-in skill"
        );
        assert!(
            registry.get("tic-tac-toe").is_some(),
            "expected tic-tac-toe built-in skill"
        );
    }

    /// `new` inherits the built-in skills through `..Default::default()`.
    #[cfg(feature = "skills")]
    #[test]
    fn test_new_configs_auto_load_builtin_skills() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "model");
        let registry = cfg
            .skill_registry
            .as_ref()
            .expect("new config should auto-load built-in skills");
        assert!(
            registry.get("word-search").is_some(),
            "expected word-search built-in skill"
        );
    }

    /// Models picked by the router's `ModelSelector` plug into dual routing.
    #[test]
    fn test_selector_to_dual_model_config() {
        use super::super::router::{ModelRequirements, ModelSelector, SelectionStrategy};

        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);

        // Pick a vision-capable model.
        let vision_reqs = ModelRequirements::default().with_vision();
        let vision_pick = selector
            .select(&vision_reqs)
            .expect("should find a vision model");

        // Pick the cheapest model with no special requirements.
        selector.set_strategy(SelectionStrategy::CheapestFirst);
        let text_reqs = ModelRequirements::default();
        let text_pick = selector
            .select(&text_reqs)
            .expect("should find a text model");

        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", &vision_pick.name)
            .with_dual_models(
                ModelEndpoint::new(&vision_pick.name),
                ModelEndpoint::new(&text_pick.name),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);

        let (_, model, _) = cfg.resolve_model_for_round(true);
        assert_eq!(
            model, vision_pick.name,
            "vision round should use vision pick"
        );

        let (_, model, _) = cfg.resolve_model_for_round(false);
        assert_eq!(model, text_pick.name, "text round should use text pick");
    }

    /// An `auto_policy` large/small split survives a JSON round trip.
    #[test]
    fn test_auto_policy_to_configs_round_trip() {
        use super::super::router::auto_policy;

        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);

        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", &policy.large)
            .with_dual_models(
                ModelEndpoint::new(&policy.large),
                ModelEndpoint::new(&policy.small),
            );

        // Round trip through JSON.
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();

        // Vision rounds use the large model, text rounds the small one.
        let (_, vision_model, _) = deserialized.resolve_model_for_round(true);
        let (_, text_model, _) = deserialized.resolve_model_for_round(false);
        assert_eq!(vision_model, policy.large);
        assert_eq!(text_model, policy.small);
    }

    /// The model chosen for each round kind matches real capability data.
    #[test]
    fn test_vision_routing_with_real_capabilities() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),        // vision-capable
                ModelEndpoint::new("gpt-3.5-turbo"), // text-only
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);

        // Round 0 is a vision round and resolves to a vision-capable model.
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        let (_, model, _) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert!(
            llm_models_spider::supports_vision(model),
            "vision-round model should support vision"
        );

        // A later, non-escalated round is a text round with a text-only model.
        assert!(!cfg.should_use_vision_this_round(3, false, 0, false));
        let (_, model, _) = cfg.resolve_model_for_round(false);
        assert_eq!(model, "gpt-3.5-turbo");
        assert!(
            !llm_models_spider::supports_vision(model),
            "text-round model should NOT support vision"
        );
    }

    /// A single-model policy yields identical resolution for both round kinds.
    #[test]
    fn test_single_model_config_e2e() {
        use super::super::router::auto_policy;

        let policy = auto_policy(&["gpt-4o"]);

        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            &policy.large,
        )
        .with_api_key("sk-test");

        // No routed endpoints were configured.
        assert!(!cfg.has_dual_model_routing());

        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-test"));

        let (url2, model2, key2) = cfg.resolve_model_for_round(false);
        assert_eq!(url, url2, "single model: same URL for both modes");
        assert_eq!(model, model2, "single model: same model for both modes");
        assert_eq!(key, key2, "single model: same key for both modes");

        // Still single-model after a JSON round trip.
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        assert!(!deserialized.has_dual_model_routing());
        let (_, m, _) = deserialized.resolve_model_for_round(true);
        assert_eq!(m, "gpt-4o");
    }

    /// Resolution is deterministic across repeated calls.
    #[test]
    fn test_model_resolution_consistency() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-test")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini")
                    .with_api_url("https://other.api.com")
                    .with_api_key("sk-other"),
            );

        for _ in 0..100 {
            let (url, model, key) = cfg.resolve_model_for_round(true);
            assert_eq!(url, "https://api.example.com");
            assert_eq!(model, "gpt-4o");
            assert_eq!(key, Some("sk-test"));

            let (url, model, key) = cfg.resolve_model_for_round(false);
            assert_eq!(url, "https://other.api.com");
            assert_eq!(model, "gpt-4o-mini");
            assert_eq!(key, Some("sk-other"));
        }
    }

    /// The model pool starts empty.
    #[test]
    fn test_model_pool_default_empty() {
        let cfg = RemoteMultimodalConfigs::default();
        assert!(cfg.model_pool.is_empty());
    }

    /// `with_model_pool` stores endpoints in order, including per-endpoint URLs.
    #[test]
    fn test_model_pool_builder() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
                ModelEndpoint::new("deepseek-chat")
                    .with_api_url("https://api.deepseek.com/v1/chat/completions")
                    .with_api_key("sk-ds"),
            ]);
        assert_eq!(cfg.model_pool.len(), 3);
        assert_eq!(cfg.model_pool[2].model_name, "deepseek-chat");
        assert_eq!(
            cfg.model_pool[2].api_url.as_deref(),
            Some("https://api.deepseek.com/v1/chat/completions")
        );
    }

    /// A non-empty pool is serialized and restored in order.
    #[test]
    fn test_model_pool_serde_round_trip() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
                ModelEndpoint::new("deepseek-chat"),
            ]);

        let json = serde_json::to_string(&cfg).unwrap();
        assert!(json.contains("model_pool"));
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.model_pool.len(), 3);
        assert_eq!(deserialized.model_pool[0].model_name, "gpt-4o");
        assert_eq!(deserialized.model_pool[1].model_name, "gpt-4o-mini");
        assert_eq!(deserialized.model_pool[2].model_name, "deepseek-chat");
    }

    /// `skip_serializing_if = "Vec::is_empty"` keeps an empty pool out of JSON.
    #[test]
    fn test_model_pool_empty_omitted_from_json() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        let json = serde_json::to_string(&cfg).unwrap();
        assert!(
            !json.contains("model_pool"),
            "empty model_pool should be omitted from JSON"
        );
    }

    /// Convenience builders write the matching top-level or `cfg` fields.
    #[test]
    fn test_cfg_convenience_builders() {
        use super::super::{ReasoningEffort, ToolCallingMode};

        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_url("https://other.api.com")
            .with_model_name("gpt-4o-mini")
            .with_include_html(true)
            .with_html_max_bytes(10_000)
            .with_include_url(true)
            .with_include_title(true)
            .with_include_screenshot(Some(false))
            .with_temperature(0.5)
            .with_max_tokens(2048)
            .with_request_json_object(true)
            .with_best_effort_json_extract(true)
            .with_reasoning_effort(ReasoningEffort::High)
            .with_thinking_budget(4096)
            .with_max_rounds(10)
            .with_post_plan_wait_ms(500)
            .with_max_inflight_requests(8)
            .with_tool_calling_mode(ToolCallingMode::Auto)
            .with_concurrent_execution(true)
            .with_max_skills_per_round(5)
            .with_max_skill_context_chars(8000);

        // Top-level fields.
        assert_eq!(cfg.api_url, "https://other.api.com");
        assert_eq!(cfg.model_name, "gpt-4o-mini");

        // Fields inside `cfg`.
        assert!(cfg.cfg.include_html);
        assert_eq!(cfg.cfg.html_max_bytes, 10_000);
        assert!(cfg.cfg.include_url);
        assert!(cfg.cfg.include_title);
        assert_eq!(cfg.cfg.include_screenshot, Some(false));
        assert!((cfg.cfg.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(cfg.cfg.max_tokens, 2048);
        assert!(cfg.cfg.request_json_object);
        assert!(cfg.cfg.best_effort_json_extract);
        assert_eq!(cfg.cfg.reasoning_effort, Some(ReasoningEffort::High));
        assert_eq!(cfg.cfg.thinking_budget, Some(4096));
        assert_eq!(cfg.cfg.max_rounds, 10);
        assert_eq!(cfg.cfg.post_plan_wait_ms, 500);
        assert_eq!(cfg.cfg.max_inflight_requests, Some(8));
        assert_eq!(cfg.cfg.tool_calling_mode, ToolCallingMode::Auto);
        assert!(cfg.cfg.concurrent_execution);
        assert_eq!(cfg.cfg.max_skills_per_round, 5);
        assert_eq!(cfg.cfg.max_skill_context_chars, 8000);
    }

    /// `model_pool` participates in `PartialEq`.
    #[test]
    fn test_model_pool_equality() {
        let a = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o")]);
        let b = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o")]);
        let c = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o-mini")]);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }
}