spec_ai_core/agent/
builder.rs

1//! Agent Builder
2//!
3//! Provides a fluent API for constructing agent instances.
4
5use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25/// Builder for constructing AgentCore instances
26pub struct AgentBuilder {
27    profile: Option<AgentProfile>,
28    provider: Option<Arc<dyn ModelProvider>>,
29    embeddings_client: Option<EmbeddingsClient>,
30    persistence: Option<Persistence>,
31    session_id: Option<String>,
32    config: Option<AppConfig>,
33    tool_registry: Option<Arc<ToolRegistry>>,
34    policy_engine: Option<Arc<PolicyEngine>>,
35    agent_name: Option<String>,
36    speak_responses: bool,
37}
38
39impl AgentBuilder {
40    /// Create a new agent builder
41    pub fn new() -> Self {
42        Self {
43            profile: None,
44            provider: None,
45            embeddings_client: None,
46            persistence: None,
47            session_id: None,
48            config: None,
49            tool_registry: None,
50            policy_engine: None,
51            agent_name: None,
52            speak_responses: false,
53        }
54    }
55
56    /// Create an agent from the registry with the active profile
57    /// This is a convenience method for CLI use
58    pub fn new_with_registry(
59        registry: &AgentRegistry,
60        config: &AppConfig,
61        session_id: Option<String>,
62    ) -> Result<AgentCore> {
63        create_agent_from_registry(registry, config, session_id)
64    }
65
66    /// Set the agent profile
67    pub fn with_profile(mut self, profile: AgentProfile) -> Self {
68        self.profile = Some(profile);
69        self
70    }
71
72    /// Set the model provider
73    pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
74        self.provider = Some(provider);
75        self
76    }
77
78    /// Set a custom embeddings client
79    pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
80        self.embeddings_client = Some(embeddings_client);
81        self
82    }
83
84    /// Set the persistence layer
85    pub fn with_persistence(mut self, persistence: Persistence) -> Self {
86        self.persistence = Some(persistence);
87        self
88    }
89
90    /// Set the session ID
91    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
92        self.session_id = Some(session_id.into());
93        self
94    }
95
96    /// Set the application configuration (used to derive defaults)
97    pub fn with_config(mut self, config: AppConfig) -> Self {
98        self.config = Some(config);
99        self
100    }
101
102    /// Set the tool registry
103    pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
104        self.tool_registry = Some(tool_registry);
105        self
106    }
107
108    /// Set the policy engine
109    pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
110        self.policy_engine = Some(policy_engine);
111        self
112    }
113
114    /// Set the logical agent name (used for telemetry/logging)
115    pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
116        self.agent_name = Some(agent_name.into());
117        self
118    }
119
120    /// Enable speech-friendly responses (macOS `say`)
121    pub fn with_speak_responses(mut self, enabled: bool) -> Self {
122        self.speak_responses = enabled;
123        self
124    }
125
126    /// Build the agent, validating all required fields
127    pub fn build(self) -> Result<AgentCore> {
128        // Get profile (required)
129        let profile = self
130            .profile
131            .as_ref()
132            .cloned()
133            .ok_or_else(|| anyhow!("Agent profile is required"))?;
134
135        // Precompute session and speech settings before moving fields
136        let session_id = self
137            .session_id
138            .clone()
139            .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
140        let speak_preference = self.resolve_speech_preference();
141        let agent_name = self.agent_name.clone();
142
143        // Get or create persistence (needed for tool registry)
144        let persistence = if let Some(persistence) = self.persistence {
145            persistence
146        } else if let Some(ref config) = self.config {
147            Persistence::new(&config.database.path).context("Failed to create persistence layer")?
148        } else {
149            return Err(anyhow!(
150                "Either persistence or config must be provided to build agent"
151            ));
152        };
153
154        // Get or create embeddings client
155        let embeddings_client = if let Some(client) = self.embeddings_client {
156            Some(client)
157        } else if let Some(ref config) = self.config {
158            create_embeddings_client_from_config(config)?
159        } else {
160            None
161        };
162
163        // Get or create tool registry (defaults to built-in tools)
164        // Create this before the provider so OpenAI can be configured with tools
165        let tool_registry = if let Some(registry) = self.tool_registry {
166            registry
167        } else {
168            let persistence_arc = Arc::new(persistence.clone());
169            let mut registry =
170                ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
171            info!(
172                "Created tool registry with {} builtin tools",
173                registry.len()
174            );
175            for tool_name in registry.list() {
176                tracing::debug!("  - Registered tool: {}", tool_name);
177            }
178
179            // Load plugins if enabled
180            if let Some(ref config) = self.config {
181                if config.plugins.enabled {
182                    match registry.load_plugins(
183                        &config.plugins.custom_tools_dir,
184                        config.plugins.allow_override_builtin,
185                    ) {
186                        Ok(stats) => {
187                            if stats.loaded > 0 {
188                                info!(
189                                    "Loaded {} plugins with {} tools from {}",
190                                    stats.loaded,
191                                    stats.tools_loaded,
192                                    config.plugins.custom_tools_dir.display()
193                                );
194                            }
195                            if stats.failed > 0 {
196                                warn!("{} plugins failed to load", stats.failed);
197                            }
198                        }
199                        Err(e) => {
200                            if config.plugins.continue_on_error {
201                                warn!("Plugin loading failed (continuing): {}", e);
202                            } else {
203                                // Re-wrap as anyhow error to propagate
204                                return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
205                            }
206                        }
207                    }
208                }
209            }
210
211            Arc::new(registry)
212        };
213
214        // Get or create provider with tools configured (for OpenAI-compatible providers)
215        let provider = if let Some(provider) = self.provider {
216            provider
217        } else if let Some(ref config) = self.config {
218            let mut base_provider =
219                create_provider(&config.model).context("Failed to create provider from config")?;
220
221            // Configure OpenAI provider with tools for native function calling
222            #[cfg(feature = "openai")]
223            {
224                if base_provider.kind() == ProviderKind::OpenAI {
225                    let tools = tool_registry.to_openai_tools();
226                    if !tools.is_empty() {
227                        info!(
228                            "Configuring OpenAI provider with {} tools for native function calling",
229                            tools.len()
230                        );
231
232                        // Recreate OpenAI provider with tools
233                        let api_key = if let Some(source) = &config.model.api_key_source {
234                            resolve_api_key(source)?
235                        } else {
236                            // Default to OPENAI_API_KEY environment variable
237                            std::env::var("OPENAI_API_KEY")
238                                .context("OPENAI_API_KEY environment variable not set")?
239                        };
240
241                        let mut openai_provider = OpenAIProvider::with_api_key(api_key);
242
243                        // Set model if specified in config
244                        if let Some(model_name) = &config.model.model_name {
245                            openai_provider = openai_provider.with_model(model_name.clone());
246                        }
247
248                        // Configure with tools and cast to trait object
249                        base_provider = Arc::new(openai_provider.with_tools(tools));
250                    }
251                }
252            }
253
254            // Configure MLX provider with tools for native function calling (OpenAI-compatible API)
255            #[cfg(feature = "mlx")]
256            {
257                if base_provider.kind() == ProviderKind::MLX {
258                    let tools = tool_registry.to_openai_tools();
259                    if !tools.is_empty() {
260                        info!(
261                            "Configuring MLX provider with {} tools for native function calling",
262                            tools.len()
263                        );
264
265                        // MLX requires a model name; mirror create_provider's behavior
266                        let model_name = config
267                            .model
268                            .model_name
269                            .as_ref()
270                            .ok_or_else(|| {
271                                anyhow!("MLX provider requires a model_name to be specified")
272                            })?
273                            .clone();
274
275                        let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
276                            MLXProvider::with_endpoint(endpoint, model_name)
277                        } else {
278                            MLXProvider::new(model_name)
279                        };
280
281                        base_provider = Arc::new(mlx_provider.with_tools(tools));
282                    }
283                }
284            }
285
286            #[cfg(feature = "lmstudio")]
287            {
288                if base_provider.kind() == ProviderKind::LMStudio {
289                    let tools = tool_registry.to_openai_tools();
290                    if !tools.is_empty() {
291                        info!(
292                            "Configuring LM Studio provider with {} tools for native function calling",
293                            tools.len()
294                        );
295
296                        let model_name = config
297                            .model
298                            .model_name
299                            .as_ref()
300                            .ok_or_else(|| {
301                                anyhow!("LM Studio provider requires a model_name to be specified")
302                            })?
303                            .clone();
304
305                        let lmstudio_provider =
306                            if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
307                                LMStudioProvider::with_endpoint(endpoint, model_name)
308                            } else {
309                                LMStudioProvider::new(model_name)
310                            };
311
312                        base_provider = Arc::new(lmstudio_provider.with_tools(tools));
313                    }
314                }
315            }
316
317            base_provider
318        } else {
319            return Err(anyhow!(
320                "Either provider or config must be provided to build agent"
321            ));
322        };
323
324        // Get or create policy engine (defaults to empty policy engine, or load from persistence)
325        let policy_engine = if let Some(engine) = self.policy_engine {
326            engine
327        } else {
328            // Try to load from persistence, or create empty engine with default allow rule
329            let mut engine = PolicyEngine::load_from_persistence(&persistence)
330                .unwrap_or_else(|_| PolicyEngine::new());
331
332            // If the policy engine has no rules at all, add a default allow-all for tools
333            if engine.rule_count() == 0 {
334                tracing::debug!(
335                    "Empty policy engine detected, adding default allow-all rule for tools"
336                );
337                engine.add_rule(crate::policy::PolicyRule {
338                    agent: "*".to_string(),
339                    action: "tool_call".to_string(),
340                    resource: "*".to_string(),
341                    effect: crate::policy::PolicyEffect::Allow,
342                });
343            }
344
345            Arc::new(engine)
346        };
347
348        let fast_provider = if profile.fast_reasoning {
349            match (&profile.fast_model_provider, &profile.fast_model_name) {
350                (Some(provider_name), Some(model_name)) => {
351                    let fast_config = ModelConfig {
352                        provider: provider_name.clone(),
353                        model_name: Some(model_name.clone()),
354                        embeddings_model: None,
355                        api_key_source: None,
356                        temperature: profile.fast_model_temperature,
357                    };
358                    match create_provider(&fast_config) {
359                        Ok(provider) => Some(provider),
360                        Err(err) => {
361                            warn!(
362                                "Failed to create fast provider {}:{} - {}",
363                                provider_name, model_name, err
364                            );
365                            None
366                        }
367                    }
368                }
369                _ => None,
370            }
371        } else {
372            None
373        };
374
375        let mut agent = AgentCore::new(
376            profile,
377            provider,
378            embeddings_client,
379            persistence,
380            session_id,
381            agent_name,
382            tool_registry,
383            policy_engine,
384            speak_preference,
385        );
386
387        if let Some(fast_provider) = fast_provider {
388            agent = agent.with_fast_provider(fast_provider);
389        }
390
391        Ok(agent)
392    }
393
394    fn resolve_speech_preference(&self) -> bool {
395        let config_pref = self
396            .config
397            .as_ref()
398            .map(|c| c.audio.speak_responses)
399            .unwrap_or(false);
400        let requested = self.speak_responses || config_pref;
401
402        #[cfg(target_os = "macos")]
403        {
404            requested
405        }
406
407        #[cfg(not(target_os = "macos"))]
408        {
409            let _ = requested; // silence unused warning when speech is unavailable
410            false
411        }
412    }
413}
414
415impl Default for AgentBuilder {
416    fn default() -> Self {
417        Self::new()
418    }
419}
420
421/// Create an agent from the active profile in the registry
422pub fn create_agent_from_registry(
423    registry: &AgentRegistry,
424    config: &AppConfig,
425    session_id: Option<String>,
426) -> Result<AgentCore> {
427    let (agent_name, profile) = registry
428        .active()
429        .context("No active agent profile in registry")?
430        .ok_or_else(|| anyhow!("No active agent set in registry"))?;
431
432    let mut builder = AgentBuilder::new()
433        .with_profile(profile)
434        .with_config(config.clone())
435        .with_persistence(registry.persistence().clone())
436        .with_agent_name(agent_name.clone());
437
438    if let Some(sid) = session_id {
439        builder = builder.with_session_id(sid);
440    }
441
442    builder.build()
443}
444
445fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
446    let model = &config.model;
447    let Some(model_name) = &model.embeddings_model else {
448        return Ok(None);
449    };
450
451    #[cfg(feature = "mlx")]
452    {
453        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
454            return Ok(Some(build_mlx_embeddings_client(model_name)));
455        }
456    }
457
458    #[cfg(feature = "lmstudio")]
459    {
460        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
461            return Ok(Some(build_lmstudio_embeddings_client(model_name)));
462        }
463    }
464
465    let client = if let Some(source) = &model.api_key_source {
466        let api_key = resolve_api_key(source)?;
467        EmbeddingsClient::with_api_key(model_name.clone(), api_key)
468    } else {
469        EmbeddingsClient::new(model_name.clone())
470    };
471
472    Ok(Some(client))
473}
474
/// Normalize an OpenAI-compatible server endpoint into an API base ending in `/v1`.
///
/// Trailing slashes are stripped first, so `http://host:1234/` and
/// `http://host:1234/v1/` both become `http://host:1234/v1`. Without this step
/// they would produce a doubled `//v1` path.
// Only the feature-gated embeddings builders below call this; it is kept
// compiled in every configuration, without a dead-code warning.
#[cfg_attr(not(any(feature = "mlx", feature = "lmstudio")), allow(dead_code))]
fn openai_compatible_api_base(endpoint: &str) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    if trimmed.ends_with("/v1") {
        trimmed.to_string()
    } else {
        format!("{}/v1", trimmed)
    }
}

/// Build an embeddings client for an MLX server's OpenAI-compatible API.
///
/// The endpoint comes from `MLX_ENDPOINT` and defaults to
/// `http://localhost:10240`. The fixed key `mlx-key` is sent as the API key
/// (presumably ignored by the local server; TODO confirm).
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("MLX_ENDPOINT").unwrap_or_else(|_| "http://localhost:10240".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("mlx-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Build an embeddings client for an LM Studio server's OpenAI-compatible API.
///
/// The endpoint comes from `LMSTUDIO_ENDPOINT` and defaults to
/// `http://localhost:1234`. The fixed key `lmstudio-key` is sent as the API key
/// (presumably ignored by the local server; TODO confirm).
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("LMSTUDIO_ENDPOINT").unwrap_or_else(|_| "http://localhost:1234".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("lmstudio-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        SyncConfig, UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Build a test config whose database lives in a fresh temporary directory.
    ///
    /// The returned [`TempDir`] guard owns that directory. Callers must keep it
    /// alive by binding it to a named variable such as `_dir` (not `_`), or the
    /// directory is deleted before the agent opens the database.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            sync: SyncConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    /// A minimal profile with fast reasoning and graph features disabled.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // Should auto-generate session ID with timestamp
        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // The generated ID has the form `session-<unix millis>`.
        let suffix = agent
            .session_id()
            .strip_prefix("session-")
            .expect("auto-generated session id should start with `session-`");
        assert!(
            suffix.parse::<i64>().is_ok(),
            "expected a millisecond timestamp suffix, got {:?}",
            suffix
        );
    }

    #[test]
    fn test_builder_fast_reasoning_without_fast_model_still_builds() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let mut profile = create_test_profile();
        profile.fast_reasoning = true;

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(Arc::new(MockProvider::default()))
            .with_persistence(persistence)
            .build();

        // A missing fast model is logged, not an error.
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}