1use crate::agent::{Agent, EndStrategy, InstrumentationSettings, RegisteredTool, ToolExecutor};
46use crate::context::{RunContext, UsageLimits};
47use crate::errors::OutputValidationError;
48use crate::history::HistoryProcessor;
49use crate::instructions::{
50 AsyncInstructionFn, AsyncSystemPromptFn, InstructionFn, SyncInstructionFn, SyncSystemPromptFn,
51 SystemPromptFn,
52};
53use crate::output::{
54 DefaultOutputSchema, JsonOutputSchema, OutputSchema, OutputValidator, SyncValidator,
55 ToolOutputSchema,
56};
57use serde::de::DeserializeOwned;
58use serde_json::Value as JsonValue;
59use serdes_ai_core::ModelSettings;
60use serdes_ai_models::{Model, ModelError};
61use serdes_ai_tools::{ToolDefinition, ToolError, ToolReturn};
62use std::future::Future;
63use std::marker::PhantomData;
64use std::sync::Arc;
65use std::time::Duration;
66
/// Describes how to construct a model: a spec string plus optional
/// connection overrides.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone)]
pub struct ModelConfig {
    /// Model spec, either `"provider:model"` (e.g. `"openai:gpt-4o"`) or a bare
    /// model name.
    pub spec: String,
    /// Optional API key override. Redacted in `Debug` output.
    pub api_key: Option<String>,
    /// Optional base URL override.
    pub base_url: Option<String>,
    /// Optional timeout override.
    pub timeout: Option<Duration>,
}

impl std::fmt::Debug for ModelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value, so configs can
        // be logged safely.
        f.debug_struct("ModelConfig")
            .field("spec", &self.spec)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .finish()
    }
}
104
105impl ModelConfig {
106 #[must_use]
116 pub fn new(spec: impl Into<String>) -> Self {
117 Self {
118 spec: spec.into(),
119 api_key: None,
120 base_url: None,
121 timeout: None,
122 }
123 }
124
125 #[must_use]
127 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
128 self.api_key = Some(api_key.into());
129 self
130 }
131
132 #[must_use]
134 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
135 self.base_url = Some(base_url.into());
136 self
137 }
138
139 #[must_use]
141 pub fn with_timeout(mut self, timeout: Duration) -> Self {
142 self.timeout = Some(timeout);
143 self
144 }
145
146 fn parse_spec(&self) -> (&str, &str) {
148 if self.spec.contains(':') {
149 let parts: Vec<&str> = self.spec.splitn(2, ':').collect();
150 (parts[0], parts[1])
151 } else {
152 ("openai", self.spec.as_str())
153 }
154 }
155
156 pub fn build_model(&self) -> Result<Arc<dyn Model>, ModelError> {
175 if self.api_key.is_none() && self.base_url.is_none() && self.timeout.is_none() {
177 return serdes_ai_models::infer_model(&self.spec);
178 }
179
180 let (provider, model_name) = self.parse_spec();
182
183 self.build_model_with_config(provider, model_name)
187 }
188
189 fn build_model_with_config(
190 &self,
191 provider: &str,
192 model_name: &str,
193 ) -> Result<Arc<dyn Model>, ModelError> {
194 serdes_ai_models::build_model_with_config(
196 provider,
197 model_name,
198 self.api_key.as_deref(),
199 self.base_url.as_deref(),
200 self.timeout,
201 )
202 }
203}
204
/// Fluent builder for [`Agent`].
///
/// Create one with [`AgentBuilder::new`], [`AgentBuilder::from_arc`],
/// [`AgentBuilder::from_model`], [`AgentBuilder::from_config`], or the
/// [`agent`] / [`agent_with_deps`] helpers. Chain configuration methods, then
/// call `build`.
///
/// `Deps` is the dependency type handed to tools and dynamic prompt functions
/// through [`RunContext`]. `Output` is the agent's result type: `String` by
/// default, switchable via `output_type`, `output_type_with_schema`, or
/// `output_tool`.
pub struct AgentBuilder<Deps = (), Output = String> {
    /// Model the built agent will use.
    model: Arc<dyn Model>,
    /// Optional agent name.
    name: Option<String>,
    /// Model settings (temperature, max tokens, top-p, ...).
    model_settings: ModelSettings,
    /// Static instruction strings. `build` appends them after `system_prompts`
    /// in the static system prompt.
    instructions: Vec<String>,
    /// Dynamic instruction producers, evaluated against a `RunContext`.
    instruction_fns: Vec<Box<dyn InstructionFn<Deps>>>,
    /// Static system prompt strings. `build` places them before `instructions`.
    system_prompts: Vec<String>,
    /// Dynamic system prompt producers, evaluated against a `RunContext`.
    system_prompt_fns: Vec<Box<dyn SystemPromptFn<Deps>>>,
    /// Registered tools. Each one carries the `max_tool_retries` value that was
    /// current when it was registered.
    tools: Vec<RegisteredTool<Deps>>,
    /// Explicit output schema. `None` makes `build` use `DefaultOutputSchema`.
    output_schema: Option<Box<dyn OutputSchema<Output>>>,
    /// Validators applied to the agent's output.
    output_validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
    /// End strategy (defaults to `EndStrategy::Early`).
    end_strategy: EndStrategy,
    /// Output retry budget (defaults to 3).
    max_output_retries: u32,
    /// Tool retry budget (defaults to 3). It is copied into each tool at
    /// registration and also forwarded to the agent.
    max_tool_retries: u32,
    /// Optional usage limits.
    usage_limits: Option<UsageLimits>,
    /// History processors, in registration order.
    history_processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
    /// Optional instrumentation settings.
    instrument: Option<InstrumentationSettings>,
    /// Whether parallel tool calls are enabled (defaults to `true`).
    parallel_tool_calls: bool,
    /// Optional concurrency cap for tools (defaults to `None`; presumably
    /// "no cap", confirm against the agent runtime).
    max_concurrent_tools: Option<usize>,
    /// Binds the otherwise-unused `Deps` / `Output` type parameters.
    _phantom: PhantomData<(Deps, Output)>,
}
227
228impl<Deps, Output> AgentBuilder<Deps, Output>
229where
230 Deps: Send + Sync + 'static,
231 Output: Send + Sync + 'static,
232{
233 pub fn new<M: Model + 'static>(model: M) -> Self {
250 Self::from_arc(Arc::new(model))
251 }
252
253 pub fn from_arc(model: Arc<dyn Model>) -> Self {
270 Self {
271 model,
272 name: None,
273 model_settings: ModelSettings::default(),
274 instructions: Vec::new(),
275 instruction_fns: Vec::new(),
276 system_prompts: Vec::new(),
277 system_prompt_fns: Vec::new(),
278 tools: Vec::new(),
279 output_schema: None,
280 output_validators: Vec::new(),
281 end_strategy: EndStrategy::Early,
282 max_output_retries: 3,
283 max_tool_retries: 3,
284 usage_limits: None,
285 history_processors: Vec::new(),
286 instrument: None,
287 parallel_tool_calls: true,
288 max_concurrent_tools: None,
289 _phantom: PhantomData,
290 }
291 }
292
293 pub fn from_model(spec: impl Into<String>) -> Result<Self, ModelError> {
323 let config = ModelConfig::new(spec);
324 Self::from_config(config)
325 }
326
327 pub fn from_config(config: ModelConfig) -> Result<Self, ModelError> {
350 let model = config.build_model()?;
351 Ok(Self::from_arc(model))
352 }
353
354 #[must_use]
356 pub fn name(mut self, name: impl Into<String>) -> Self {
357 self.name = Some(name.into());
358 self
359 }
360
361 #[must_use]
363 pub fn model_settings(mut self, settings: ModelSettings) -> Self {
364 self.model_settings = settings;
365 self
366 }
367
368 #[must_use]
370 pub fn temperature(mut self, temp: f64) -> Self {
371 self.model_settings = self.model_settings.temperature(temp);
372 self
373 }
374
375 #[must_use]
377 pub fn max_tokens(mut self, tokens: u64) -> Self {
378 self.model_settings = self.model_settings.max_tokens(tokens);
379 self
380 }
381
382 #[must_use]
384 pub fn top_p(mut self, p: f64) -> Self {
385 self.model_settings = self.model_settings.top_p(p);
386 self
387 }
388
389 #[must_use]
391 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
392 self.instructions.push(instructions.into());
393 self
394 }
395
396 #[must_use]
398 pub fn instructions_fn<F, Fut>(mut self, f: F) -> Self
399 where
400 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync + 'static,
401 Fut: Future<Output = Option<String>> + Send + 'static,
402 {
403 self.instruction_fns
404 .push(Box::new(AsyncInstructionFn::new(f)));
405 self
406 }
407
408 #[must_use]
410 pub fn instructions_fn_sync<F>(mut self, f: F) -> Self
411 where
412 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync + 'static,
413 {
414 self.instruction_fns
415 .push(Box::new(SyncInstructionFn::new(f)));
416 self
417 }
418
419 #[must_use]
421 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
422 self.system_prompts.push(prompt.into());
423 self
424 }
425
426 #[must_use]
428 pub fn system_prompt_fn<F, Fut>(mut self, f: F) -> Self
429 where
430 F: Fn(&RunContext<Deps>) -> Fut + Send + Sync + 'static,
431 Fut: Future<Output = Option<String>> + Send + 'static,
432 {
433 self.system_prompt_fns
434 .push(Box::new(AsyncSystemPromptFn::new(f)));
435 self
436 }
437
438 #[must_use]
440 pub fn system_prompt_fn_sync<F>(mut self, f: F) -> Self
441 where
442 F: Fn(&RunContext<Deps>) -> Option<String> + Send + Sync + 'static,
443 {
444 self.system_prompt_fns
445 .push(Box::new(SyncSystemPromptFn::new(f)));
446 self
447 }
448
449 #[must_use]
451 pub fn tool_with_executor<E>(mut self, definition: ToolDefinition, executor: E) -> Self
452 where
453 E: ToolExecutor<Deps> + 'static,
454 {
455 self.tools.push(RegisteredTool {
456 definition,
457 executor: Arc::new(executor),
458 max_retries: self.max_tool_retries,
459 });
460 self
461 }
462
463 #[must_use]
465 pub fn tool_fn<F, Args>(
466 mut self,
467 name: impl Into<String>,
468 description: impl Into<String>,
469 f: F,
470 ) -> Self
471 where
472 F: Fn(&RunContext<Deps>, Args) -> Result<ToolReturn, ToolError> + Send + Sync + 'static,
473 Args: DeserializeOwned + Send + 'static,
474 {
475 let tool_name = name.into();
476 let definition = ToolDefinition::new(tool_name.clone(), description.into());
477
478 let executor = SyncFnExecutor {
479 func: Arc::new(move |ctx, args: JsonValue| {
480 let parsed: Args = serde_json::from_value(args)
481 .map_err(|e| ToolError::invalid_arguments(tool_name.clone(), e.to_string()))?;
482 f(ctx, parsed)
483 }),
484 _phantom: PhantomData,
485 };
486
487 self.tools.push(RegisteredTool {
488 definition,
489 executor: Arc::new(executor),
490 max_retries: self.max_tool_retries,
491 });
492 self
493 }
494
495 #[must_use]
497 pub fn tool_fn_async<F, Fut, Args>(
498 mut self,
499 name: impl Into<String>,
500 description: impl Into<String>,
501 f: F,
502 ) -> Self
503 where
504 F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync + 'static,
505 Fut: Future<Output = Result<ToolReturn, ToolError>> + Send + Sync + 'static,
506 Args: DeserializeOwned + Send + Sync + 'static,
507 {
508 let tool_name = name.into();
509 let definition = ToolDefinition::new(tool_name.clone(), description.into());
510
511 let executor = AsyncFnExecutor {
512 func: Arc::new(f),
513 tool_name,
514 _phantom: PhantomData,
515 };
516
517 self.tools.push(RegisteredTool {
518 definition,
519 executor: Arc::new(executor),
520 max_retries: self.max_tool_retries,
521 });
522 self
523 }
524
525 #[must_use]
527 pub fn output_schema<S: OutputSchema<Output> + 'static>(mut self, schema: S) -> Self {
528 self.output_schema = Some(Box::new(schema));
529 self
530 }
531
532 #[must_use]
534 pub fn output_validator<V: OutputValidator<Output, Deps> + 'static>(
535 mut self,
536 validator: V,
537 ) -> Self {
538 self.output_validators.push(Box::new(validator));
539 self
540 }
541
542 #[must_use]
544 pub fn output_validator_fn<F>(mut self, f: F) -> Self
545 where
546 F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError>
547 + Send
548 + Sync
549 + 'static,
550 {
551 self.output_validators.push(Box::new(SyncValidator::new(f)));
552 self
553 }
554
555 #[must_use]
557 pub fn end_strategy(mut self, strategy: EndStrategy) -> Self {
558 self.end_strategy = strategy;
559 self
560 }
561
562 #[must_use]
564 pub fn max_output_retries(mut self, retries: u32) -> Self {
565 self.max_output_retries = retries;
566 self
567 }
568
569 #[must_use]
571 pub fn max_tool_retries(mut self, retries: u32) -> Self {
572 self.max_tool_retries = retries;
573 self
574 }
575
576 #[must_use]
578 pub fn usage_limits(mut self, limits: UsageLimits) -> Self {
579 self.usage_limits = Some(limits);
580 self
581 }
582
583 #[must_use]
585 pub fn history_processor<P: HistoryProcessor<Deps> + 'static>(mut self, processor: P) -> Self {
586 self.history_processors.push(Box::new(processor));
587 self
588 }
589
590 #[must_use]
592 pub fn instrument(mut self, settings: InstrumentationSettings) -> Self {
593 self.instrument = Some(settings);
594 self
595 }
596
597 #[must_use]
604 pub fn parallel_tool_calls(mut self, enabled: bool) -> Self {
605 self.parallel_tool_calls = enabled;
606 self
607 }
608
609 #[must_use]
616 pub fn max_concurrent_tools(mut self, max: usize) -> Self {
617 self.max_concurrent_tools = Some(max);
618 self
619 }
620
621 pub fn build(self) -> Agent<Deps, Output>
623 where
624 Output: serde::de::DeserializeOwned,
625 {
626 let output_schema = self
627 .output_schema
628 .unwrap_or_else(|| Box::new(DefaultOutputSchema::<Output>::new()));
629
630 let static_system_prompt = {
633 let mut parts = Vec::new();
634
635 for prompt in &self.system_prompts {
637 if !prompt.is_empty() {
638 parts.push(prompt.as_str());
639 }
640 }
641
642 for instruction in &self.instructions {
644 if !instruction.is_empty() {
645 parts.push(instruction.as_str());
646 }
647 }
648
649 Arc::from(parts.join("\n\n"))
650 };
651
652 let cached_tool_defs = Arc::new(
655 self.tools
656 .iter()
657 .map(|t| t.definition.clone())
658 .collect::<Vec<_>>(),
659 );
660
661 Agent {
662 model: self.model,
663 name: self.name,
664 model_settings: self.model_settings,
665 static_system_prompt,
666 instruction_fns: self.instruction_fns,
667 system_prompt_fns: self.system_prompt_fns,
668 tools: self.tools,
669 cached_tool_defs,
670 output_schema,
671 output_validators: self.output_validators,
672 end_strategy: self.end_strategy,
673 max_output_retries: self.max_output_retries,
674 max_tool_retries: self.max_tool_retries,
675 usage_limits: self.usage_limits,
676 history_processors: self.history_processors,
677 instrument: self.instrument,
678 parallel_tool_calls: self.parallel_tool_calls,
679 max_concurrent_tools: self.max_concurrent_tools,
680 _phantom: PhantomData,
681 }
682 }
683}
684
685impl<Deps: Send + Sync + 'static> AgentBuilder<Deps, String> {
688 #[must_use]
690 pub fn output_type<T: DeserializeOwned + Send + Sync + 'static>(self) -> AgentBuilder<Deps, T> {
691 AgentBuilder {
692 model: self.model,
693 name: self.name,
694 model_settings: self.model_settings,
695 instructions: self.instructions,
696 instruction_fns: self.instruction_fns,
697 system_prompts: self.system_prompts,
698 system_prompt_fns: self.system_prompt_fns,
699 tools: self.tools,
700 output_schema: Some(Box::new(JsonOutputSchema::<T>::new())),
701 output_validators: Vec::new(),
702 end_strategy: self.end_strategy,
703 max_output_retries: self.max_output_retries,
704 max_tool_retries: self.max_tool_retries,
705 usage_limits: self.usage_limits,
706 history_processors: self.history_processors,
707 instrument: self.instrument,
708 parallel_tool_calls: self.parallel_tool_calls,
709 max_concurrent_tools: self.max_concurrent_tools,
710 _phantom: PhantomData,
711 }
712 }
713
714 #[must_use]
716 pub fn output_type_with_schema<T: DeserializeOwned + Send + Sync + 'static>(
717 self,
718 schema: JsonValue,
719 ) -> AgentBuilder<Deps, T> {
720 AgentBuilder {
721 model: self.model,
722 name: self.name,
723 model_settings: self.model_settings,
724 instructions: self.instructions,
725 instruction_fns: self.instruction_fns,
726 system_prompts: self.system_prompts,
727 system_prompt_fns: self.system_prompt_fns,
728 tools: self.tools,
729 output_schema: Some(Box::new(JsonOutputSchema::<T>::new().with_schema(schema))),
730 output_validators: Vec::new(),
731 end_strategy: self.end_strategy,
732 max_output_retries: self.max_output_retries,
733 max_tool_retries: self.max_tool_retries,
734 usage_limits: self.usage_limits,
735 history_processors: self.history_processors,
736 instrument: self.instrument,
737 parallel_tool_calls: self.parallel_tool_calls,
738 max_concurrent_tools: self.max_concurrent_tools,
739 _phantom: PhantomData,
740 }
741 }
742
743 #[must_use]
745 pub fn output_tool<T: DeserializeOwned + Send + Sync + 'static>(
746 self,
747 tool_name: impl Into<String>,
748 schema: JsonValue,
749 ) -> AgentBuilder<Deps, T> {
750 AgentBuilder {
751 model: self.model,
752 name: self.name,
753 model_settings: self.model_settings,
754 instructions: self.instructions,
755 instruction_fns: self.instruction_fns,
756 system_prompts: self.system_prompts,
757 system_prompt_fns: self.system_prompt_fns,
758 tools: self.tools,
759 output_schema: Some(Box::new(
760 ToolOutputSchema::<T>::new(tool_name).with_schema(schema),
761 )),
762 output_validators: Vec::new(),
763 end_strategy: self.end_strategy,
764 max_output_retries: self.max_output_retries,
765 max_tool_retries: self.max_tool_retries,
766 usage_limits: self.usage_limits,
767 history_processors: self.history_processors,
768 instrument: self.instrument,
769 parallel_tool_calls: self.parallel_tool_calls,
770 max_concurrent_tools: self.max_concurrent_tools,
771 _phantom: PhantomData,
772 }
773 }
774}
775
776#[allow(clippy::type_complexity)]
782struct SyncFnExecutor<Deps> {
783 func: Arc<dyn Fn(&RunContext<Deps>, JsonValue) -> Result<ToolReturn, ToolError> + Send + Sync>,
784 _phantom: PhantomData<Deps>,
785}
786
787#[async_trait::async_trait]
788impl<Deps: Send + Sync> ToolExecutor<Deps> for SyncFnExecutor<Deps> {
789 async fn execute(
790 &self,
791 args: JsonValue,
792 ctx: &RunContext<Deps>,
793 ) -> Result<ToolReturn, ToolError> {
794 (self.func)(ctx, args)
795 }
796}
797
/// [`ToolExecutor`] adapter for async tool closures registered through
/// `AgentBuilder::tool_fn_async`.
///
/// JSON arguments are deserialized into `Args` before `func` is called. See the
/// `ToolExecutor` impl below.
struct AsyncFnExecutor<F, Deps, Args, Fut>
where
    F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync,
    Fut: Future<Output = Result<ToolReturn, ToolError>> + Send,
    Args: DeserializeOwned + Send,
{
    // The user's async tool closure.
    func: Arc<F>,
    // Tool name reported in `ToolError::invalid_arguments` when argument parsing fails.
    tool_name: String,
    // Binds the otherwise-unused type parameters. Because it holds
    // `(Deps, Args, Fut)` by value, this struct is only `Send`/`Sync` when
    // those types are; presumably that is why the `ToolExecutor` impl needs
    // `Fut: Sync` and `Args: Sync`.
    // NOTE(review): `PhantomData<fn() -> (Deps, Args, Fut)>` would lift that
    // requirement if the impl and builder bounds were relaxed together.
    _phantom: PhantomData<(Deps, Args, Fut)>,
}
809
810#[async_trait::async_trait]
811impl<F, Deps, Args, Fut> ToolExecutor<Deps> for AsyncFnExecutor<F, Deps, Args, Fut>
812where
813 F: Fn(&RunContext<Deps>, Args) -> Fut + Send + Sync,
814 Fut: Future<Output = Result<ToolReturn, ToolError>> + Send + Sync,
815 Args: DeserializeOwned + Send + Sync,
816 Deps: Send + Sync,
817{
818 async fn execute(
819 &self,
820 args: JsonValue,
821 ctx: &RunContext<Deps>,
822 ) -> Result<ToolReturn, ToolError> {
823 let parsed: Args = serde_json::from_value(args)
824 .map_err(|e| ToolError::invalid_arguments(self.tool_name.clone(), e.to_string()))?;
825 (self.func)(ctx, parsed).await
826 }
827}
828
829pub fn agent<M: Model + 'static>(model: M) -> AgentBuilder<(), String> {
831 AgentBuilder::new(model)
832}
833
834pub fn agent_with_deps<Deps: Send + Sync + 'static, M: Model + 'static>(
836 model: M,
837) -> AgentBuilder<Deps, String> {
838 AgentBuilder::new(model)
839}
840
#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_models::MockModel;

    fn create_mock_model() -> MockModel {
        MockModel::new("test-model")
    }

    #[test]
    fn test_builder_basic() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .name("test-agent")
            .temperature(0.7)
            .build();

        assert_eq!(agent.name(), Some("test-agent"));
        assert_eq!(agent.model_settings().temperature, Some(0.7));
    }

    #[test]
    fn test_builder_with_instructions() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .system_prompt("You are helpful.")
            .instructions("Be concise.")
            .build();

        assert!(agent.static_system_prompt.contains("You are helpful."));
        assert!(agent.static_system_prompt.contains("Be concise."));
    }

    #[test]
    fn test_builder_prompt_order_skips_empty_parts() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .instructions("Be concise.")
            .system_prompt("")
            .system_prompt("You are helpful.")
            .instructions("")
            .build();

        // System prompts precede instructions regardless of call order, empty
        // parts are dropped, and parts are separated by a blank line.
        assert_eq!(
            &*agent.static_system_prompt,
            "You are helpful.\n\nBe concise."
        );
    }

    #[test]
    fn test_builder_with_tool() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .tool_fn(
                "greet",
                "Greet someone",
                |_ctx: &RunContext<()>, args: serde_json::Value| {
                    let name = args["name"].as_str().unwrap_or("World");
                    Ok(ToolReturn::text(format!("Hello, {}!", name)))
                },
            )
            .build();

        assert_eq!(agent.tools.len(), 1);
        assert_eq!(agent.tools[0].definition.name, "greet");
    }

    #[test]
    fn test_builder_tool_retries_captured_at_registration() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .tool_fn(
                "first",
                "Registered with the default retry budget",
                |_ctx: &RunContext<()>, _args: serde_json::Value| {
                    Ok(ToolReturn::text("first".to_string()))
                },
            )
            .max_tool_retries(7)
            .tool_fn(
                "second",
                "Registered after raising the retry budget",
                |_ctx: &RunContext<()>, _args: serde_json::Value| {
                    Ok(ToolReturn::text("second".to_string()))
                },
            )
            .build();

        assert_eq!(agent.tools[0].max_retries, 3);
        assert_eq!(agent.tools[1].max_retries, 7);
    }

    #[test]
    fn test_builder_usage_limits() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .usage_limits(UsageLimits::new().total_tokens(1000).requests(10))
            .build();

        let limits = agent.usage_limits().unwrap();
        assert_eq!(limits.max_total_tokens, Some(1000));
        assert_eq!(limits.max_requests, Some(10));
    }

    #[test]
    fn test_builder_end_strategy() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .end_strategy(EndStrategy::Exhaustive)
            .build();

        assert_eq!(agent.end_strategy, EndStrategy::Exhaustive);
    }

    #[test]
    fn test_agent_convenience() {
        let model = create_mock_model();
        let agent = agent(model).name("quick-agent").build();

        assert_eq!(agent.name(), Some("quick-agent"));
    }

    #[test]
    fn test_builder_parallel_tool_calls_default() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model).build();

        assert!(agent.parallel_tool_calls());
        assert!(agent.max_concurrent_tools().is_none());
    }

    #[test]
    fn test_builder_parallel_tool_calls_disabled() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .parallel_tool_calls(false)
            .build();

        assert!(!agent.parallel_tool_calls());
    }

    #[test]
    fn test_builder_max_concurrent_tools() {
        let model = create_mock_model();
        let agent = AgentBuilder::<(), String>::new(model)
            .max_concurrent_tools(4)
            .build();

        assert!(agent.parallel_tool_calls());
        assert_eq!(agent.max_concurrent_tools(), Some(4));
    }

    #[test]
    fn test_builder_parallel_config_preserved_on_output_type() {
        let model = create_mock_model();
        let agent: Agent<(), serde_json::Value> = AgentBuilder::<(), String>::new(model)
            .parallel_tool_calls(false)
            .max_concurrent_tools(2)
            .output_type()
            .build();

        assert!(!agent.parallel_tool_calls());
        assert_eq!(agent.max_concurrent_tools(), Some(2));
    }

    #[test]
    fn test_builder_from_arc() {
        let model = create_mock_model();
        let arc_model: Arc<dyn Model> = Arc::new(model);
        let agent = AgentBuilder::<(), String>::from_arc(arc_model)
            .name("arc-agent")
            .build();

        assert_eq!(agent.name(), Some("arc-agent"));
    }

    #[test]
    fn test_model_config_basic() {
        let config = ModelConfig::new("openai:gpt-4o");
        assert_eq!(config.spec, "openai:gpt-4o");
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
        assert!(config.timeout.is_none());
    }

    #[test]
    fn test_model_config_with_options() {
        let config = ModelConfig::new("anthropic:claude-3-5-sonnet-20241022")
            .with_api_key("sk-test-key")
            .with_base_url("https://custom.api.com")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.spec, "anthropic:claude-3-5-sonnet-20241022");
        assert_eq!(config.api_key, Some("sk-test-key".to_string()));
        assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_model_config_parse_spec_with_provider() {
        let config = ModelConfig::new("openai:gpt-4o");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn test_model_config_parse_spec_without_provider() {
        let config = ModelConfig::new("gpt-4o");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn test_model_config_parse_spec_anthropic() {
        let config = ModelConfig::new("anthropic:claude-3-5-sonnet-20241022");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "anthropic");
        assert_eq!(model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_model_config_parse_spec_keeps_later_colons() {
        // Only the first colon separates the provider; later ones belong to
        // the model name (e.g. tagged local models).
        let config = ModelConfig::new("ollama:llama3:8b");
        let (provider, model) = config.parse_spec();
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:8b");
    }

    #[test]
    fn test_model_config_unknown_provider() {
        let config = ModelConfig::new("unknown:some-model");
        match config.build_model() {
            Ok(_) => panic!("Expected error for unknown provider"),
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("Unknown") || msg.contains("unsupported"),
                    "Expected error about unknown provider, got: {}",
                    msg
                );
            }
        }
    }
}