spec_ai_core/agent/
builder.rs

1//! Agent Builder
2//!
3//! Provides a fluent API for constructing agent instances.
4
5use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25/// Builder for constructing AgentCore instances
26pub struct AgentBuilder {
27    profile: Option<AgentProfile>,
28    provider: Option<Arc<dyn ModelProvider>>,
29    embeddings_client: Option<EmbeddingsClient>,
30    persistence: Option<Persistence>,
31    session_id: Option<String>,
32    config: Option<AppConfig>,
33    tool_registry: Option<Arc<ToolRegistry>>,
34    policy_engine: Option<Arc<PolicyEngine>>,
35    agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39    /// Create a new agent builder
40    pub fn new() -> Self {
41        Self {
42            profile: None,
43            provider: None,
44            embeddings_client: None,
45            persistence: None,
46            session_id: None,
47            config: None,
48            tool_registry: None,
49            policy_engine: None,
50            agent_name: None,
51        }
52    }
53
54    /// Create an agent from the registry with the active profile
55    /// This is a convenience method for CLI use
56    pub fn new_with_registry(
57        registry: &AgentRegistry,
58        config: &AppConfig,
59        session_id: Option<String>,
60    ) -> Result<AgentCore> {
61        create_agent_from_registry(registry, config, session_id)
62    }
63
64    /// Set the agent profile
65    pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66        self.profile = Some(profile);
67        self
68    }
69
70    /// Set the model provider
71    pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72        self.provider = Some(provider);
73        self
74    }
75
76    /// Set a custom embeddings client
77    pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78        self.embeddings_client = Some(embeddings_client);
79        self
80    }
81
82    /// Set the persistence layer
83    pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84        self.persistence = Some(persistence);
85        self
86    }
87
88    /// Set the session ID
89    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90        self.session_id = Some(session_id.into());
91        self
92    }
93
94    /// Set the application configuration (used to derive defaults)
95    pub fn with_config(mut self, config: AppConfig) -> Self {
96        self.config = Some(config);
97        self
98    }
99
100    /// Set the tool registry
101    pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102        self.tool_registry = Some(tool_registry);
103        self
104    }
105
106    /// Set the policy engine
107    pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108        self.policy_engine = Some(policy_engine);
109        self
110    }
111
112    /// Set the logical agent name (used for telemetry/logging)
113    pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114        self.agent_name = Some(agent_name.into());
115        self
116    }
117
118    /// Build the agent, validating all required fields
119    pub fn build(self) -> Result<AgentCore> {
120        // Get profile (required)
121        let profile = self
122            .profile
123            .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125        // Get or create persistence (needed for tool registry)
126        let persistence = if let Some(persistence) = self.persistence {
127            persistence
128        } else if let Some(ref config) = self.config {
129            Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130        } else {
131            return Err(anyhow!(
132                "Either persistence or config must be provided to build agent"
133            ));
134        };
135
136        // Get or create embeddings client
137        let embeddings_client = if let Some(client) = self.embeddings_client {
138            Some(client)
139        } else if let Some(ref config) = self.config {
140            create_embeddings_client_from_config(config)?
141        } else {
142            None
143        };
144
145        // Get or create tool registry (defaults to built-in tools)
146        // Create this before the provider so OpenAI can be configured with tools
147        let tool_registry = if let Some(registry) = self.tool_registry {
148            registry
149        } else {
150            let persistence_arc = Arc::new(persistence.clone());
151            let mut registry =
152                ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
153            info!(
154                "Created tool registry with {} builtin tools",
155                registry.len()
156            );
157            for tool_name in registry.list() {
158                tracing::debug!("  - Registered tool: {}", tool_name);
159            }
160
161            // Load plugins if enabled
162            if let Some(ref config) = self.config {
163                if config.plugins.enabled {
164                    match registry.load_plugins(
165                        &config.plugins.custom_tools_dir,
166                        config.plugins.allow_override_builtin,
167                    ) {
168                        Ok(stats) => {
169                            if stats.loaded > 0 {
170                                info!(
171                                    "Loaded {} plugins with {} tools from {}",
172                                    stats.loaded,
173                                    stats.tools_loaded,
174                                    config.plugins.custom_tools_dir.display()
175                                );
176                            }
177                            if stats.failed > 0 {
178                                warn!(
179                                    "{} plugins failed to load",
180                                    stats.failed
181                                );
182                            }
183                        }
184                        Err(e) => {
185                            if config.plugins.continue_on_error {
186                                warn!("Plugin loading failed (continuing): {}", e);
187                            } else {
188                                // Re-wrap as anyhow error to propagate
189                                return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
190                            }
191                        }
192                    }
193                }
194            }
195
196            Arc::new(registry)
197        };
198
199        // Get or create provider with tools configured (for OpenAI-compatible providers)
200        let provider = if let Some(provider) = self.provider {
201            provider
202        } else if let Some(ref config) = self.config {
203            let mut base_provider =
204                create_provider(&config.model).context("Failed to create provider from config")?;
205
206            // Configure OpenAI provider with tools for native function calling
207            #[cfg(feature = "openai")]
208            {
209                if base_provider.kind() == ProviderKind::OpenAI {
210                    let tools = tool_registry.to_openai_tools();
211                    if !tools.is_empty() {
212                        info!(
213                            "Configuring OpenAI provider with {} tools for native function calling",
214                            tools.len()
215                        );
216
217                        // Recreate OpenAI provider with tools
218                        let api_key = if let Some(source) = &config.model.api_key_source {
219                            resolve_api_key(source)?
220                        } else {
221                            // Default to OPENAI_API_KEY environment variable
222                            std::env::var("OPENAI_API_KEY")
223                                .context("OPENAI_API_KEY environment variable not set")?
224                        };
225
226                        let mut openai_provider = OpenAIProvider::with_api_key(api_key);
227
228                        // Set model if specified in config
229                        if let Some(model_name) = &config.model.model_name {
230                            openai_provider = openai_provider.with_model(model_name.clone());
231                        }
232
233                        // Configure with tools and cast to trait object
234                        base_provider = Arc::new(openai_provider.with_tools(tools));
235                    }
236                }
237            }
238
239            // Configure MLX provider with tools for native function calling (OpenAI-compatible API)
240            #[cfg(feature = "mlx")]
241            {
242                if base_provider.kind() == ProviderKind::MLX {
243                    let tools = tool_registry.to_openai_tools();
244                    if !tools.is_empty() {
245                        info!(
246                            "Configuring MLX provider with {} tools for native function calling",
247                            tools.len()
248                        );
249
250                        // MLX requires a model name; mirror create_provider's behavior
251                        let model_name = config
252                            .model
253                            .model_name
254                            .as_ref()
255                            .ok_or_else(|| {
256                                anyhow!("MLX provider requires a model_name to be specified")
257                            })?
258                            .clone();
259
260                        let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
261                            MLXProvider::with_endpoint(endpoint, model_name)
262                        } else {
263                            MLXProvider::new(model_name)
264                        };
265
266                        base_provider = Arc::new(mlx_provider.with_tools(tools));
267                    }
268                }
269            }
270
271            #[cfg(feature = "lmstudio")]
272            {
273                if base_provider.kind() == ProviderKind::LMStudio {
274                    let tools = tool_registry.to_openai_tools();
275                    if !tools.is_empty() {
276                        info!(
277                            "Configuring LM Studio provider with {} tools for native function calling",
278                            tools.len()
279                        );
280
281                        let model_name = config
282                            .model
283                            .model_name
284                            .as_ref()
285                            .ok_or_else(|| {
286                                anyhow!("LM Studio provider requires a model_name to be specified")
287                            })?
288                            .clone();
289
290                        let lmstudio_provider =
291                            if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
292                                LMStudioProvider::with_endpoint(endpoint, model_name)
293                            } else {
294                                LMStudioProvider::new(model_name)
295                            };
296
297                        base_provider = Arc::new(lmstudio_provider.with_tools(tools));
298                    }
299                }
300            }
301
302            base_provider
303        } else {
304            return Err(anyhow!(
305                "Either provider or config must be provided to build agent"
306            ));
307        };
308
309        // Get or generate session ID
310        let session_id = self
311            .session_id
312            .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
313
314        // Get or create policy engine (defaults to empty policy engine, or load from persistence)
315        let policy_engine = if let Some(engine) = self.policy_engine {
316            engine
317        } else {
318            // Try to load from persistence, or create empty engine with default allow rule
319            let mut engine = PolicyEngine::load_from_persistence(&persistence)
320                .unwrap_or_else(|_| PolicyEngine::new());
321
322            // If the policy engine has no rules at all, add a default allow-all for tools
323            if engine.rule_count() == 0 {
324                tracing::debug!(
325                    "Empty policy engine detected, adding default allow-all rule for tools"
326                );
327                engine.add_rule(crate::policy::PolicyRule {
328                    agent: "*".to_string(),
329                    action: "tool_call".to_string(),
330                    resource: "*".to_string(),
331                    effect: crate::policy::PolicyEffect::Allow,
332                });
333            }
334
335            Arc::new(engine)
336        };
337
338        let fast_provider = if profile.fast_reasoning {
339            match (&profile.fast_model_provider, &profile.fast_model_name) {
340                (Some(provider_name), Some(model_name)) => {
341                    let fast_config = ModelConfig {
342                        provider: provider_name.clone(),
343                        model_name: Some(model_name.clone()),
344                        embeddings_model: None,
345                        api_key_source: None,
346                        temperature: profile.fast_model_temperature,
347                    };
348                    match create_provider(&fast_config) {
349                        Ok(provider) => Some(provider),
350                        Err(err) => {
351                            warn!(
352                                "Failed to create fast provider {}:{} - {}",
353                                provider_name, model_name, err
354                            );
355                            None
356                        }
357                    }
358                }
359                _ => None,
360            }
361        } else {
362            None
363        };
364
365        let mut agent = AgentCore::new(
366            profile,
367            provider,
368            embeddings_client,
369            persistence,
370            session_id,
371            self.agent_name,
372            tool_registry,
373            policy_engine,
374        );
375
376        if let Some(fast_provider) = fast_provider {
377            agent = agent.with_fast_provider(fast_provider);
378        }
379
380        Ok(agent)
381    }
382}
383
384impl Default for AgentBuilder {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
390/// Create an agent from the active profile in the registry
391pub fn create_agent_from_registry(
392    registry: &AgentRegistry,
393    config: &AppConfig,
394    session_id: Option<String>,
395) -> Result<AgentCore> {
396    let (agent_name, profile) = registry
397        .active()
398        .context("No active agent profile in registry")?
399        .ok_or_else(|| anyhow!("No active agent set in registry"))?;
400
401    let mut builder = AgentBuilder::new()
402        .with_profile(profile)
403        .with_config(config.clone())
404        .with_persistence(registry.persistence().clone())
405        .with_agent_name(agent_name.clone());
406
407    if let Some(sid) = session_id {
408        builder = builder.with_session_id(sid);
409    }
410
411    builder.build()
412}
413
414fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
415    let model = &config.model;
416    let Some(model_name) = &model.embeddings_model else {
417        return Ok(None);
418    };
419
420    #[cfg(feature = "mlx")]
421    {
422        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
423            return Ok(Some(build_mlx_embeddings_client(model_name)));
424        }
425    }
426
427    #[cfg(feature = "lmstudio")]
428    {
429        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
430            return Ok(Some(build_lmstudio_embeddings_client(model_name)));
431        }
432    }
433
434    let client = if let Some(source) = &model.api_key_source {
435        let api_key = resolve_api_key(source)?;
436        EmbeddingsClient::with_api_key(model_name.clone(), api_key)
437    } else {
438        EmbeddingsClient::new(model_name.clone())
439    };
440
441    Ok(Some(client))
442}
443
/// Turn an OpenAI-compatible server endpoint into an API base ending in `/v1`.
///
/// Trailing slashes are stripped first. This way `http://host:1234/` and
/// `http://host:1234/v1/` both normalize to `http://host:1234/v1` instead of producing
/// a doubled `//v1` segment.
#[cfg(any(feature = "mlx", feature = "lmstudio", test))]
fn openai_compatible_api_base(endpoint: &str) -> String {
    let endpoint = endpoint.trim_end_matches('/');
    if endpoint.ends_with("/v1") {
        endpoint.to_string()
    } else {
        format!("{}/v1", endpoint)
    }
}

/// Build an embeddings client for a local MLX server.
///
/// The endpoint comes from `MLX_ENDPOINT`, defaulting to `http://localhost:10240`.
/// The API key is a fixed placeholder (`"mlx-key"`); presumably the local server ignores
/// it — TODO confirm.
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("MLX_ENDPOINT").unwrap_or_else(|_| "http://localhost:10240".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("mlx-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Build an embeddings client for a local LM Studio server.
///
/// The endpoint comes from `LMSTUDIO_ENDPOINT`, defaulting to `http://localhost:1234`.
/// The API key is a fixed placeholder (`"lmstudio-key"`); presumably the local server
/// ignores it — TODO confirm.
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("LMSTUDIO_ENDPOINT").unwrap_or_else(|_| "http://localhost:1234".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("lmstudio-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Build an `AppConfig` whose database file lives in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the config because dropping it deletes
    /// the directory. Callers must bind it (e.g. `_dir`, not `_`) for as long as the
    /// config's database path may be opened.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // Should auto-generate session ID with timestamp
        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        let Err(err) = result else {
            panic!("building without a profile should fail");
        };
        assert!(err.to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        let Err(err) = result else {
            panic!("building without a provider or config should fail");
        };
        assert!(err.to_string().contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // Should auto-generate session ID with timestamp
        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        let Err(err) = result else {
            panic!("creating an agent with no active profile should fail");
        };
        let err_msg = err.to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}