spec_ai_core/agent/
builder.rs

1//! Agent Builder
2//!
3//! Provides a fluent API for constructing agent instances.
4
5use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25/// Builder for constructing AgentCore instances
26pub struct AgentBuilder {
27    profile: Option<AgentProfile>,
28    provider: Option<Arc<dyn ModelProvider>>,
29    embeddings_client: Option<EmbeddingsClient>,
30    persistence: Option<Persistence>,
31    session_id: Option<String>,
32    config: Option<AppConfig>,
33    tool_registry: Option<Arc<ToolRegistry>>,
34    policy_engine: Option<Arc<PolicyEngine>>,
35    agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39    /// Create a new agent builder
40    pub fn new() -> Self {
41        Self {
42            profile: None,
43            provider: None,
44            embeddings_client: None,
45            persistence: None,
46            session_id: None,
47            config: None,
48            tool_registry: None,
49            policy_engine: None,
50            agent_name: None,
51        }
52    }
53
54    /// Create an agent from the registry with the active profile
55    /// This is a convenience method for CLI use
56    pub fn new_with_registry(
57        registry: &AgentRegistry,
58        config: &AppConfig,
59        session_id: Option<String>,
60    ) -> Result<AgentCore> {
61        create_agent_from_registry(registry, config, session_id)
62    }
63
64    /// Set the agent profile
65    pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66        self.profile = Some(profile);
67        self
68    }
69
70    /// Set the model provider
71    pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72        self.provider = Some(provider);
73        self
74    }
75
76    /// Set a custom embeddings client
77    pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78        self.embeddings_client = Some(embeddings_client);
79        self
80    }
81
82    /// Set the persistence layer
83    pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84        self.persistence = Some(persistence);
85        self
86    }
87
88    /// Set the session ID
89    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90        self.session_id = Some(session_id.into());
91        self
92    }
93
94    /// Set the application configuration (used to derive defaults)
95    pub fn with_config(mut self, config: AppConfig) -> Self {
96        self.config = Some(config);
97        self
98    }
99
100    /// Set the tool registry
101    pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102        self.tool_registry = Some(tool_registry);
103        self
104    }
105
106    /// Set the policy engine
107    pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108        self.policy_engine = Some(policy_engine);
109        self
110    }
111
112    /// Set the logical agent name (used for telemetry/logging)
113    pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114        self.agent_name = Some(agent_name.into());
115        self
116    }
117
118    /// Build the agent, validating all required fields
119    pub fn build(self) -> Result<AgentCore> {
120        // Get profile (required)
121        let profile = self
122            .profile
123            .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125        // Get or create persistence (needed for tool registry)
126        let persistence = if let Some(persistence) = self.persistence {
127            persistence
128        } else if let Some(ref config) = self.config {
129            Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130        } else {
131            return Err(anyhow!(
132                "Either persistence or config must be provided to build agent"
133            ));
134        };
135
136        // Get or create embeddings client
137        let embeddings_client = if let Some(client) = self.embeddings_client {
138            Some(client)
139        } else if let Some(ref config) = self.config {
140            create_embeddings_client_from_config(config)?
141        } else {
142            None
143        };
144
145        // Get or create tool registry (defaults to built-in tools)
146        // Create this before the provider so OpenAI can be configured with tools
147        let tool_registry = if let Some(registry) = self.tool_registry {
148            registry
149        } else {
150            let persistence_arc = Arc::new(persistence.clone());
151            let mut registry =
152                ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
153            info!(
154                "Created tool registry with {} builtin tools",
155                registry.len()
156            );
157            for tool_name in registry.list() {
158                tracing::debug!("  - Registered tool: {}", tool_name);
159            }
160
161            // Load plugins if enabled
162            if let Some(ref config) = self.config {
163                if config.plugins.enabled {
164                    match registry.load_plugins(
165                        &config.plugins.custom_tools_dir,
166                        config.plugins.allow_override_builtin,
167                    ) {
168                        Ok(stats) => {
169                            if stats.loaded > 0 {
170                                info!(
171                                    "Loaded {} plugins with {} tools from {}",
172                                    stats.loaded,
173                                    stats.tools_loaded,
174                                    config.plugins.custom_tools_dir.display()
175                                );
176                            }
177                            if stats.failed > 0 {
178                                warn!("{} plugins failed to load", stats.failed);
179                            }
180                        }
181                        Err(e) => {
182                            if config.plugins.continue_on_error {
183                                warn!("Plugin loading failed (continuing): {}", e);
184                            } else {
185                                // Re-wrap as anyhow error to propagate
186                                return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
187                            }
188                        }
189                    }
190                }
191            }
192
193            Arc::new(registry)
194        };
195
196        // Get or create provider with tools configured (for OpenAI-compatible providers)
197        let provider = if let Some(provider) = self.provider {
198            provider
199        } else if let Some(ref config) = self.config {
200            let mut base_provider =
201                create_provider(&config.model).context("Failed to create provider from config")?;
202
203            // Configure OpenAI provider with tools for native function calling
204            #[cfg(feature = "openai")]
205            {
206                if base_provider.kind() == ProviderKind::OpenAI {
207                    let tools = tool_registry.to_openai_tools();
208                    if !tools.is_empty() {
209                        info!(
210                            "Configuring OpenAI provider with {} tools for native function calling",
211                            tools.len()
212                        );
213
214                        // Recreate OpenAI provider with tools
215                        let api_key = if let Some(source) = &config.model.api_key_source {
216                            resolve_api_key(source)?
217                        } else {
218                            // Default to OPENAI_API_KEY environment variable
219                            std::env::var("OPENAI_API_KEY")
220                                .context("OPENAI_API_KEY environment variable not set")?
221                        };
222
223                        let mut openai_provider = OpenAIProvider::with_api_key(api_key);
224
225                        // Set model if specified in config
226                        if let Some(model_name) = &config.model.model_name {
227                            openai_provider = openai_provider.with_model(model_name.clone());
228                        }
229
230                        // Configure with tools and cast to trait object
231                        base_provider = Arc::new(openai_provider.with_tools(tools));
232                    }
233                }
234            }
235
236            // Configure MLX provider with tools for native function calling (OpenAI-compatible API)
237            #[cfg(feature = "mlx")]
238            {
239                if base_provider.kind() == ProviderKind::MLX {
240                    let tools = tool_registry.to_openai_tools();
241                    if !tools.is_empty() {
242                        info!(
243                            "Configuring MLX provider with {} tools for native function calling",
244                            tools.len()
245                        );
246
247                        // MLX requires a model name; mirror create_provider's behavior
248                        let model_name = config
249                            .model
250                            .model_name
251                            .as_ref()
252                            .ok_or_else(|| {
253                                anyhow!("MLX provider requires a model_name to be specified")
254                            })?
255                            .clone();
256
257                        let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
258                            MLXProvider::with_endpoint(endpoint, model_name)
259                        } else {
260                            MLXProvider::new(model_name)
261                        };
262
263                        base_provider = Arc::new(mlx_provider.with_tools(tools));
264                    }
265                }
266            }
267
268            #[cfg(feature = "lmstudio")]
269            {
270                if base_provider.kind() == ProviderKind::LMStudio {
271                    let tools = tool_registry.to_openai_tools();
272                    if !tools.is_empty() {
273                        info!(
274                            "Configuring LM Studio provider with {} tools for native function calling",
275                            tools.len()
276                        );
277
278                        let model_name = config
279                            .model
280                            .model_name
281                            .as_ref()
282                            .ok_or_else(|| {
283                                anyhow!("LM Studio provider requires a model_name to be specified")
284                            })?
285                            .clone();
286
287                        let lmstudio_provider =
288                            if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
289                                LMStudioProvider::with_endpoint(endpoint, model_name)
290                            } else {
291                                LMStudioProvider::new(model_name)
292                            };
293
294                        base_provider = Arc::new(lmstudio_provider.with_tools(tools));
295                    }
296                }
297            }
298
299            base_provider
300        } else {
301            return Err(anyhow!(
302                "Either provider or config must be provided to build agent"
303            ));
304        };
305
306        // Get or generate session ID
307        let session_id = self
308            .session_id
309            .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
310
311        // Get or create policy engine (defaults to empty policy engine, or load from persistence)
312        let policy_engine = if let Some(engine) = self.policy_engine {
313            engine
314        } else {
315            // Try to load from persistence, or create empty engine with default allow rule
316            let mut engine = PolicyEngine::load_from_persistence(&persistence)
317                .unwrap_or_else(|_| PolicyEngine::new());
318
319            // If the policy engine has no rules at all, add a default allow-all for tools
320            if engine.rule_count() == 0 {
321                tracing::debug!(
322                    "Empty policy engine detected, adding default allow-all rule for tools"
323                );
324                engine.add_rule(crate::policy::PolicyRule {
325                    agent: "*".to_string(),
326                    action: "tool_call".to_string(),
327                    resource: "*".to_string(),
328                    effect: crate::policy::PolicyEffect::Allow,
329                });
330            }
331
332            Arc::new(engine)
333        };
334
335        let fast_provider = if profile.fast_reasoning {
336            match (&profile.fast_model_provider, &profile.fast_model_name) {
337                (Some(provider_name), Some(model_name)) => {
338                    let fast_config = ModelConfig {
339                        provider: provider_name.clone(),
340                        model_name: Some(model_name.clone()),
341                        embeddings_model: None,
342                        api_key_source: None,
343                        temperature: profile.fast_model_temperature,
344                    };
345                    match create_provider(&fast_config) {
346                        Ok(provider) => Some(provider),
347                        Err(err) => {
348                            warn!(
349                                "Failed to create fast provider {}:{} - {}",
350                                provider_name, model_name, err
351                            );
352                            None
353                        }
354                    }
355                }
356                _ => None,
357            }
358        } else {
359            None
360        };
361
362        let mut agent = AgentCore::new(
363            profile,
364            provider,
365            embeddings_client,
366            persistence,
367            session_id,
368            self.agent_name,
369            tool_registry,
370            policy_engine,
371        );
372
373        if let Some(fast_provider) = fast_provider {
374            agent = agent.with_fast_provider(fast_provider);
375        }
376
377        Ok(agent)
378    }
379}
380
381impl Default for AgentBuilder {
382    fn default() -> Self {
383        Self::new()
384    }
385}
386
387/// Create an agent from the active profile in the registry
388pub fn create_agent_from_registry(
389    registry: &AgentRegistry,
390    config: &AppConfig,
391    session_id: Option<String>,
392) -> Result<AgentCore> {
393    let (agent_name, profile) = registry
394        .active()
395        .context("No active agent profile in registry")?
396        .ok_or_else(|| anyhow!("No active agent set in registry"))?;
397
398    let mut builder = AgentBuilder::new()
399        .with_profile(profile)
400        .with_config(config.clone())
401        .with_persistence(registry.persistence().clone())
402        .with_agent_name(agent_name.clone());
403
404    if let Some(sid) = session_id {
405        builder = builder.with_session_id(sid);
406    }
407
408    builder.build()
409}
410
411fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
412    let model = &config.model;
413    let Some(model_name) = &model.embeddings_model else {
414        return Ok(None);
415    };
416
417    #[cfg(feature = "mlx")]
418    {
419        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
420            return Ok(Some(build_mlx_embeddings_client(model_name)));
421        }
422    }
423
424    #[cfg(feature = "lmstudio")]
425    {
426        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
427            return Ok(Some(build_lmstudio_embeddings_client(model_name)));
428        }
429    }
430
431    let client = if let Some(source) = &model.api_key_source {
432        let api_key = resolve_api_key(source)?;
433        EmbeddingsClient::with_api_key(model_name.clone(), api_key)
434    } else {
435        EmbeddingsClient::new(model_name.clone())
436    };
437
438    Ok(Some(client))
439}
440
/// Normalise a server endpoint into an OpenAI-compatible API base ending in `/v1`.
///
/// Surrounding whitespace and trailing slashes are removed first, so
/// `http://host/` and `http://host/v1/` both become `http://host/v1` instead
/// of producing `//v1` or a doubled `/v1//v1` suffix.
// Always compiled so it can be unit-tested; only the feature-gated builders
// below call it, hence the conditional dead-code allowance.
#[cfg_attr(not(any(feature = "mlx", feature = "lmstudio")), allow(dead_code))]
fn openai_compatible_api_base(endpoint: &str) -> String {
    let trimmed = endpoint.trim().trim_end_matches('/');
    if trimmed.ends_with("/v1") {
        trimmed.to_string()
    } else {
        format!("{trimmed}/v1")
    }
}

/// Build an embeddings client for a local OpenAI-compatible server.
///
/// The endpoint is read from the `endpoint_var` environment variable, falling
/// back to `default_endpoint`. `placeholder_api_key` is passed as the API key.
#[cfg(any(feature = "mlx", feature = "lmstudio"))]
fn build_local_embeddings_client(
    model_name: &str,
    endpoint_var: &str,
    default_endpoint: &str,
    placeholder_api_key: &str,
) -> EmbeddingsClient {
    let endpoint = std::env::var(endpoint_var).unwrap_or_else(|_| default_endpoint.to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key(placeholder_api_key);

    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Build an embeddings client for the MLX server at `MLX_ENDPOINT`
/// (default `http://localhost:10240`).
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    build_local_embeddings_client(
        model_name,
        "MLX_ENDPOINT",
        "http://localhost:10240",
        "mlx-key",
    )
}

/// Build an embeddings client for the LM Studio server at `LMSTUDIO_ENDPOINT`
/// (default `http://localhost:1234`).
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    build_local_embeddings_client(
        model_name,
        "LMSTUDIO_ENDPOINT",
        "http://localhost:1234",
        "lmstudio-key",
    )
}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Open a persistence layer in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned because dropping it deletes the
    /// directory; callers must keep it alive for as long as they use the
    /// persistence handle.
    fn create_test_persistence() -> (Persistence, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();
        (persistence, dir)
    }

    /// Build a mock-provider configuration whose database path lives in a
    /// fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the config: dropping it
    /// inside this function would delete the directory before the configured
    /// database path is ever opened.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    /// Build a minimal profile with fast reasoning and graph features off.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let (persistence, _dir) = create_test_persistence();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // Should auto-generate session ID with the "session-" prefix
        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let (persistence, _dir) = create_test_persistence();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // A session ID must always be present, even when none was supplied
        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let (persistence, _db_dir) = create_test_persistence();
        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile);

        let registry = AgentRegistry::new(agents, persistence);
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let (persistence, _db_dir) = create_test_persistence();
        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}