Skip to main content

rusty_commit/providers/
registry.rs

//! Provider Registry - Central registry for all AI providers
//!
//! This module provides an extensible registry pattern for AI providers.
//! New providers can be added by implementing the `ProviderBuilder` trait
//! and registering them with the `ProviderRegistry`.
6
7use crate::config::Config;
8use anyhow::{Context, Result};
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12/// Lock error type for registry operations
13#[derive(thiserror::Error, Debug)]
14#[error("Registry lock error")]
15pub struct LockError;
16
17macro_rules! read_lock {
18    ($lock:expr, $field:ident) => {
19        $lock.read().map_err(|_| {
20            tracing::error!("{} lock is poisoned", stringify!($field));
21            LockError
22        })
23    };
24}
25
/// Acquire a write guard on `$lock`, converting a poisoned lock into [`LockError`].
///
/// `$field` is only stringified to name the lock in the error log.
macro_rules! write_lock {
    ($lock:expr, $field:ident) => {
        $lock.write().map_err(|_poisoned| {
            tracing::error!("{} lock is poisoned", stringify!($field));
            LockError
        })
    };
}
34
35/// Trait for building AI provider instances
36pub trait ProviderBuilder: Send + Sync {
37    /// The provider name/identifier
38    fn name(&self) -> &'static str;
39
40    /// Alternative names for this provider (aliases)
41    fn aliases(&self) -> Vec<&'static str> {
42        vec![]
43    }
44
45    /// Provider category for documentation
46    fn category(&self) -> ProviderCategory {
47        ProviderCategory::Standard
48    }
49
50    /// Create a provider instance from config
51    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>>;
52
53    /// Whether this provider requires an API key
54    fn requires_api_key(&self) -> bool {
55        true
56    }
57
58    /// Default model for this provider (if applicable)
59    fn default_model(&self) -> Option<&'static str> {
60        None
61    }
62}
63
/// Broad grouping of providers, used to organize provider listings.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
// Some variants may not be produced by any builder in this crate yet.
#[allow(dead_code)]
pub enum ProviderCategory {
    /// Direct API providers (OpenAI, Anthropic, etc.)
    Standard,
    /// OpenAI-compatible API providers
    OpenAICompatible,
    /// Self-hosted/local providers
    Local,
    /// Cloud marketplace providers
    Cloud,
}
77
78/// Registry entry for a provider (metadata only, no builder)
79#[derive(Clone)]
80pub struct ProviderEntry {
81    pub name: &'static str,
82    pub aliases: Vec<&'static str>,
83    pub category: ProviderCategory,
84    #[allow(dead_code)]
85    pub requires_api_key: bool,
86    #[allow(dead_code)]
87    pub default_model: Option<&'static str>,
88}
89
90impl ProviderEntry {
91    pub fn from_builder(builder: &dyn ProviderBuilder) -> Self {
92        Self {
93            name: builder.name(),
94            aliases: builder.aliases(),
95            category: builder.category(),
96            requires_api_key: builder.requires_api_key(),
97            default_model: builder.default_model(),
98        }
99    }
100
101    /// Check if this entry matches a provider name
102    #[allow(dead_code)]
103    pub fn matches(&self, provider: &str) -> bool {
104        let lower = provider.to_lowercase();
105        self.name.eq_ignore_ascii_case(&lower)
106            || self.aliases.iter().any(|&a| a.eq_ignore_ascii_case(&lower))
107    }
108}
109
110/// The provider registry - a thread-safe registry of all available providers
111pub struct ProviderRegistry {
112    entries: RwLock<HashMap<&'static str, ProviderEntry>>,
113    builders: RwLock<HashMap<&'static str, Box<dyn ProviderBuilder>>>,
114    by_alias: RwLock<HashMap<&'static str, &'static str>>,
115}
116
117impl ProviderRegistry {
118    /// Create a new empty registry
119    pub fn new() -> Self {
120        Self {
121            entries: RwLock::new(HashMap::new()),
122            builders: RwLock::new(HashMap::new()),
123            by_alias: RwLock::new(HashMap::new()),
124        }
125    }
126
127    /// Register a provider builder
128    pub fn register(&self, builder: Box<dyn ProviderBuilder>) -> Result<()> {
129        let name = builder.name();
130        let entry = ProviderEntry::from_builder(&*builder);
131
132        // Register primary name
133        write_lock!(self.entries, entries)?.insert(name, entry.clone());
134        write_lock!(self.builders, builders)?.insert(name, builder);
135
136        // Register aliases
137        for &alias in &entry.aliases {
138            write_lock!(self.by_alias, by_alias)?.insert(alias, name);
139        }
140
141        Ok(())
142    }
143
144    /// Get a provider entry by name or alias
145    #[allow(dead_code)]
146    pub fn get(&self, provider: &str) -> Option<ProviderEntry> {
147        let lower = provider.to_lowercase();
148
149        // Try direct lookup
150        let entries = read_lock!(self.entries, entries).ok()?;
151        if let Some(entry) = entries.get(lower.as_str()) {
152            return Some(entry.clone());
153        }
154
155        // Try alias lookup
156        let by_alias = read_lock!(self.by_alias, by_alias).ok()?;
157        if let Some(&primary) = by_alias.get(lower.as_str()) {
158            return entries.get(primary).cloned();
159        }
160
161        None
162    }
163
164    /// Get all registered providers
165    pub fn all(&self) -> Option<Vec<ProviderEntry>> {
166        let entries = read_lock!(self.entries, entries).ok()?;
167        Some(entries.values().cloned().collect())
168    }
169
170    /// Get providers by category
171    pub fn by_category(&self, category: ProviderCategory) -> Option<Vec<ProviderEntry>> {
172        let entries = read_lock!(self.entries, entries).ok()?;
173        Some(
174            entries
175                .values()
176                .filter(|e| e.category == category)
177                .cloned()
178                .collect(),
179        )
180    }
181
182    /// Check if any providers are registered
183    #[allow(dead_code)]
184    pub fn is_empty(&self) -> bool {
185        match read_lock!(self.entries, entries) {
186            Ok(entries) => entries.is_empty(),
187            Err(_) => true,
188        }
189    }
190
191    /// Get count of registered providers
192    #[allow(dead_code)]
193    pub fn len(&self) -> usize {
194        match read_lock!(self.entries, entries) {
195            Ok(entries) => entries.len(),
196            Err(_) => 0,
197        }
198    }
199
200    /// Create a provider instance
201    pub fn create(
202        &self,
203        name: &str,
204        config: &Config,
205    ) -> Result<Option<Box<dyn super::AIProvider>>> {
206        let lower = name.to_lowercase();
207
208        let builders = read_lock!(self.builders, builders).context("Failed to read builders")?;
209        let by_alias = read_lock!(self.by_alias, by_alias).context("Failed to read aliases")?;
210
211        // Try direct lookup first
212        if let Some(builder) = builders.get(lower.as_str()) {
213            return Ok(Some(builder.create(config)?));
214        }
215
216        // Try alias lookup
217        if let Some(&primary) = by_alias.get(lower.as_str()) {
218            if let Some(builder) = builders.get(primary) {
219                return Ok(Some(builder.create(config)?));
220            }
221        }
222
223        Ok(None)
224    }
225}
226
227impl Default for ProviderRegistry {
228    fn default() -> Self {
229        Self::new()
230    }
231}