// rust_ai_core/facade.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! High-level unified API for the rust-ai ecosystem.
5//!
6//! The `RustAI` struct provides a facade over all rust-ai ecosystem crates,
7//! offering a simplified interface for common AI engineering tasks.
8//!
9//! ## Design Philosophy
10//!
11//! Rather than requiring users to understand each individual crate's API,
12//! `RustAI` provides high-level workflows that compose the right crates
13//! automatically based on the task at hand.
14//!
15//! ## Example
16//!
17//! ```rust,ignore
18//! use rust_ai_core::{RustAI, RustAIConfig};
19//!
20//! // Initialize with sensible defaults
21//! let ai = RustAI::new(RustAIConfig::default())?;
22//!
23//! // Fine-tune a model with LoRA
24//! let config = ai.finetune()
25//!     .model("meta-llama/Llama-2-7b")
26//!     .rank(64)
27//!     .build()?;
28//!
29//! // Quantize for deployment
30//! let quant_config = ai.quantize()
31//!     .method(QuantizeMethod::Nf4)
32//!     .bits(4)
33//!     .build();
34//! ```
35
36use crate::device::{get_device, DeviceConfig};
37use crate::ecosystem::EcosystemInfo;
38use crate::error::{CoreError, Result};
39use candle_core::Device;
40
41// =============================================================================
42// CONFIGURATION
43// =============================================================================
44
/// Configuration for the `RustAI` facade.
///
/// Centralizes all configuration options with sensible defaults for the
/// unified API. Construct via [`RustAIConfig::new`] (or `Default`), then
/// refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RustAIConfig {
    /// Device configuration (CUDA selection, CPU fallback)
    pub device: DeviceConfig,
    /// Enable verbose logging (emitted via `tracing` during `RustAI::new`)
    pub verbose: bool,
    /// Memory limit in bytes (0 = no limit)
    pub memory_limit: usize,
}
58
59impl Default for RustAIConfig {
60    fn default() -> Self {
61        Self {
62            device: DeviceConfig::from_env(),
63            verbose: false,
64            memory_limit: 0,
65        }
66    }
67}
68
69impl RustAIConfig {
70    /// Create a new configuration with defaults.
71    #[must_use]
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Set verbose mode.
77    #[must_use]
78    pub fn with_verbose(mut self, verbose: bool) -> Self {
79        self.verbose = verbose;
80        self
81    }
82
83    /// Set memory limit in bytes.
84    #[must_use]
85    pub fn with_memory_limit(mut self, limit: usize) -> Self {
86        self.memory_limit = limit;
87        self
88    }
89
90    /// Force CPU execution.
91    #[must_use]
92    pub fn with_cpu(mut self) -> Self {
93        self.device = self.device.with_force_cpu(true);
94        self
95    }
96
97    /// Select a specific CUDA device.
98    #[must_use]
99    pub fn with_cuda_device(mut self, ordinal: usize) -> Self {
100        self.device = self.device.with_cuda_device(ordinal);
101        self
102    }
103}
104
105// =============================================================================
106// RUST-AI FACADE
107// =============================================================================
108
109/// Unified facade for the rust-ai ecosystem.
110///
111/// Provides high-level APIs that orchestrate multiple ecosystem crates
112/// to accomplish common AI engineering tasks.
113///
114/// ## Capabilities
115///
116/// | Workflow | Description |
117/// |----------|-------------|
118/// | `finetune()` | `LoRA`, `DoRA`, `AdaLoRA` adapter creation |
119/// | `quantize()` | 4-bit (NF4/FP4) and 1.58-bit (`BitNet`) quantization |
120/// | `vsa()` | VSA-based operations and optimization |
121/// | `train()` | YAML-driven training pipelines |
122///
123/// ## Example
124///
125/// ```rust
126/// use rust_ai_core::{RustAI, RustAIConfig};
127///
128/// let config = RustAIConfig::new()
129///     .with_verbose(true);
130///
131/// let ai = RustAI::new(config).unwrap();
132/// println!("Device: {:?}", ai.device());
133/// println!("Ecosystem: {:?}", ai.ecosystem());
134/// ```
pub struct RustAI {
    // User-supplied configuration, retained for introspection via `config()`.
    config: RustAIConfig,
    // Compute device resolved once at construction (CUDA or CPU).
    device: Device,
    // Ecosystem crate information gathered at construction.
    ecosystem: EcosystemInfo,
}
140
141impl RustAI {
142    /// Create a new `RustAI` instance.
143    ///
144    /// Initializes the device and ecosystem information.
145    ///
146    /// # Arguments
147    ///
148    /// * `config` - Configuration options
149    ///
150    /// # Returns
151    ///
152    /// A configured `RustAI` instance.
153    ///
154    /// # Errors
155    ///
156    /// Returns `CoreError::DeviceNotAvailable` if device initialization fails.
157    ///
158    /// # Example
159    ///
160    /// ```rust
161    /// use rust_ai_core::{RustAI, RustAIConfig};
162    ///
163    /// let ai = RustAI::new(RustAIConfig::default())?;
164    /// # Ok::<(), rust_ai_core::CoreError>(())
165    /// ```
166    pub fn new(config: RustAIConfig) -> Result<Self> {
167        let device = get_device(&config.device)?;
168        let ecosystem = EcosystemInfo::new();
169
170        if config.verbose {
171            tracing::info!("RustAI initialized");
172            tracing::info!("Device: {:?}", device);
173            tracing::info!("Ecosystem crates: {:?}", EcosystemInfo::crate_names());
174        }
175
176        Ok(Self {
177            config,
178            device,
179            ecosystem,
180        })
181    }
182
183    /// Get the active device.
184    #[must_use]
185    pub fn device(&self) -> &Device {
186        &self.device
187    }
188
189    /// Get ecosystem information.
190    #[must_use]
191    pub fn ecosystem(&self) -> &EcosystemInfo {
192        &self.ecosystem
193    }
194
195    /// Get the configuration.
196    #[must_use]
197    pub fn config(&self) -> &RustAIConfig {
198        &self.config
199    }
200
201    /// Check if CUDA is available and active.
202    #[must_use]
203    pub fn is_cuda(&self) -> bool {
204        matches!(self.device, Device::Cuda(_))
205    }
206
207    /// Start a fine-tuning workflow.
208    ///
209    /// Returns a builder for configuring and running fine-tuning.
210    ///
211    /// # Example
212    ///
213    /// ```rust,ignore
214    /// let config = ai.finetune()
215    ///     .model("meta-llama/Llama-2-7b")
216    ///     .adapter(AdapterType::Lora)
217    ///     .rank(64)
218    ///     .build()?;
219    /// ```
220    #[must_use]
221    pub fn finetune(&self) -> FinetuneBuilder<'_> {
222        FinetuneBuilder::new(self)
223    }
224
225    /// Start a quantization workflow.
226    ///
227    /// Returns a builder for configuring and running quantization.
228    ///
229    /// # Example
230    ///
231    /// ```rust,ignore
232    /// let config = ai.quantize()
233    ///     .method(QuantizeMethod::Nf4)
234    ///     .bits(4)
235    ///     .build();
236    /// ```
237    #[must_use]
238    pub fn quantize(&self) -> QuantizeBuilder<'_> {
239        QuantizeBuilder::new(self)
240    }
241
242    /// Start a VSA workflow.
243    ///
244    /// Returns a builder for VSA-based operations.
245    ///
246    /// # Example
247    ///
248    /// ```rust,ignore
249    /// let config = ai.vsa()
250    ///     .dimension(10000)
251    ///     .build();
252    /// ```
253    #[must_use]
254    pub fn vsa(&self) -> VsaBuilder<'_> {
255        VsaBuilder::new(self)
256    }
257
258    /// Start an Axolotl training pipeline.
259    ///
260    /// Returns a builder for YAML-driven training configuration.
261    ///
262    /// # Example
263    ///
264    /// ```rust,ignore
265    /// let config = ai.train()
266    ///     .config_file("config.yaml")
267    ///     .build()?;
268    /// ```
269    #[must_use]
270    pub fn train(&self) -> TrainBuilder<'_> {
271        TrainBuilder::new(self)
272    }
273
274    /// Get information about the `RustAI` environment.
275    ///
276    /// Returns a struct containing version info, ecosystem, and device details.
277    #[must_use]
278    pub fn info(&self) -> RustAIInfo {
279        RustAIInfo {
280            version: crate::VERSION.to_string(),
281            device: format!("{:?}", self.device),
282            ecosystem_crates: EcosystemInfo::crate_names()
283                .iter()
284                .map(|s| (*s).to_string())
285                .collect(),
286            cuda_available: self.is_cuda(),
287            memory_limit: self.config.memory_limit,
288        }
289    }
290}
291
/// Snapshot of the `RustAI` environment, produced by [`RustAI::info`].
///
/// All fields are plain owned data, so the snapshot can outlive the
/// `RustAI` instance it was taken from.
#[derive(Debug, Clone)]
pub struct RustAIInfo {
    /// Crate version (`rust_ai_core::VERSION`)
    pub version: String,
    /// Debug-formatted description of the active device
    pub device: String,
    /// Names of the ecosystem crates
    pub ecosystem_crates: Vec<String>,
    /// Whether the active device is CUDA
    pub cuda_available: bool,
    /// Memory limit in bytes (0 = unlimited)
    pub memory_limit: usize,
}
306
307// =============================================================================
308// FINE-TUNING WORKFLOW
309// =============================================================================
310
/// Builder for fine-tuning workflows.
///
/// Created via [`RustAI::finetune`]; configure with the chained setters,
/// then call `build()` to obtain a [`FinetuneConfig`].
pub struct FinetuneBuilder<'a> {
    // Back-reference to the facade; not consulted yet, kept for future use.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Model path or identifier; required — `build()` errors when `None`.
    model_path: Option<String>,
    // PEFT adapter variant (LoRA by default).
    adapter_type: AdapterType,
    // LoRA rank (default 64).
    rank: usize,
    // LoRA alpha scaling factor (default 16.0).
    alpha: f32,
    // Dropout rate (default 0.1).
    dropout: f32,
    // Module names to adapt (defaults to q_proj / v_proj).
    target_modules: Vec<String>,
}
322
/// Type of PEFT (parameter-efficient fine-tuning) adapter.
#[derive(Debug, Clone, Copy, Default)]
pub enum AdapterType {
    /// Low-Rank Adaptation (the default)
    #[default]
    Lora,
    /// Weight-Decomposed Low-Rank Adaptation
    Dora,
    /// Adaptive Budget Low-Rank Adaptation
    AdaLora,
}
334
335impl<'a> FinetuneBuilder<'a> {
336    fn new(ai: &'a RustAI) -> Self {
337        Self {
338            ai,
339            model_path: None,
340            adapter_type: AdapterType::Lora,
341            rank: 64,
342            alpha: 16.0,
343            dropout: 0.1,
344            target_modules: vec!["q_proj".into(), "v_proj".into()],
345        }
346    }
347
348    /// Set the model path or identifier.
349    #[must_use]
350    pub fn model(mut self, path: impl Into<String>) -> Self {
351        self.model_path = Some(path.into());
352        self
353    }
354
355    /// Set the adapter type.
356    #[must_use]
357    pub fn adapter(mut self, adapter: AdapterType) -> Self {
358        self.adapter_type = adapter;
359        self
360    }
361
362    /// Set the `LoRA` rank.
363    #[must_use]
364    pub fn rank(mut self, rank: usize) -> Self {
365        self.rank = rank;
366        self
367    }
368
369    /// Set the `LoRA` alpha scaling factor.
370    #[must_use]
371    pub fn alpha(mut self, alpha: f32) -> Self {
372        self.alpha = alpha;
373        self
374    }
375
376    /// Set the dropout rate.
377    #[must_use]
378    pub fn dropout(mut self, dropout: f32) -> Self {
379        self.dropout = dropout;
380        self
381    }
382
383    /// Set the target module names to adapt.
384    #[must_use]
385    pub fn target_modules(mut self, modules: Vec<String>) -> Self {
386        self.target_modules = modules;
387        self
388    }
389
390    /// Build the fine-tuning configuration.
391    ///
392    /// # Errors
393    ///
394    /// Returns error if model path is not specified.
395    pub fn build(self) -> Result<FinetuneConfig> {
396        let model_path = self
397            .model_path
398            .ok_or_else(|| CoreError::invalid_config("model path is required for fine-tuning"))?;
399
400        Ok(FinetuneConfig {
401            model_path,
402            adapter_type: self.adapter_type,
403            rank: self.rank,
404            alpha: self.alpha,
405            dropout: self.dropout,
406            target_modules: self.target_modules,
407        })
408    }
409}
410
/// Configuration for fine-tuning, produced by [`FinetuneBuilder::build`].
#[derive(Debug, Clone)]
pub struct FinetuneConfig {
    /// Path or identifier of the base model
    pub model_path: String,
    /// Type of PEFT adapter to attach
    pub adapter_type: AdapterType,
    /// `LoRA` rank
    pub rank: usize,
    /// `LoRA` alpha scaling factor
    pub alpha: f32,
    /// Dropout rate
    pub dropout: f32,
    /// Names of the modules to adapt
    pub target_modules: Vec<String>,
}
427
428// =============================================================================
429// QUANTIZATION WORKFLOW
430// =============================================================================
431
/// Builder for quantization workflows.
///
/// Created via [`RustAI::quantize`]; every field has a default, so
/// `build()` is infallible.
pub struct QuantizeBuilder<'a> {
    // Back-reference to the facade; not consulted yet, kept for future use.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Quantization method (default NF4).
    method: QuantizeMethod,
    // Bit width (default 4); applies to non-BitNet methods.
    bits: u8,
    // Group size for group-wise quantization (default 64).
    group_size: usize,
}
440
/// Quantization method.
#[derive(Debug, Clone, Copy, Default)]
pub enum QuantizeMethod {
    /// NF4 (Normal Float 4-bit) - used in `QLoRA` (the default)
    #[default]
    Nf4,
    /// FP4 (Floating Point 4-bit)
    Fp4,
    /// `BitNet` 1.58-bit ternary quantization
    BitNet,
    /// Standard INT8 quantization
    Int8,
}
454
455impl<'a> QuantizeBuilder<'a> {
456    fn new(ai: &'a RustAI) -> Self {
457        Self {
458            ai,
459            method: QuantizeMethod::Nf4,
460            bits: 4,
461            group_size: 64,
462        }
463    }
464
465    /// Set the quantization method.
466    #[must_use]
467    pub fn method(mut self, method: QuantizeMethod) -> Self {
468        self.method = method;
469        self
470    }
471
472    /// Set the number of bits (for non-BitNet methods).
473    #[must_use]
474    pub fn bits(mut self, bits: u8) -> Self {
475        self.bits = bits;
476        self
477    }
478
479    /// Set the group size for group-wise quantization.
480    #[must_use]
481    pub fn group_size(mut self, size: usize) -> Self {
482        self.group_size = size;
483        self
484    }
485
486    /// Build the quantization configuration.
487    #[must_use]
488    pub fn build(self) -> QuantizeConfig {
489        QuantizeConfig {
490            method: self.method,
491            bits: self.bits,
492            group_size: self.group_size,
493        }
494    }
495}
496
/// Configuration for quantization, produced by [`QuantizeBuilder::build`].
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Quantization method
    pub method: QuantizeMethod,
    /// Number of bits
    pub bits: u8,
    /// Group size for group-wise quantization
    pub group_size: usize,
}
507
508// =============================================================================
509// VSA WORKFLOW
510// =============================================================================
511
/// Builder for VSA (vector symbolic architecture) workflows.
///
/// Created via [`RustAI::vsa`]; `build()` is infallible.
pub struct VsaBuilder<'a> {
    // Back-reference to the facade; not consulted yet, kept for future use.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Hypervector dimension (default 10000).
    dimension: usize,
}
518
519impl<'a> VsaBuilder<'a> {
520    fn new(ai: &'a RustAI) -> Self {
521        Self {
522            ai,
523            dimension: 10000,
524        }
525    }
526
527    /// Set the VSA dimension.
528    #[must_use]
529    pub fn dimension(mut self, dim: usize) -> Self {
530        self.dimension = dim;
531        self
532    }
533
534    /// Build the VSA configuration.
535    #[must_use]
536    pub fn build(self) -> VsaConfig {
537        VsaConfig {
538            dimension: self.dimension,
539        }
540    }
541}
542
/// Configuration for VSA operations, produced by [`VsaBuilder::build`].
#[derive(Debug, Clone)]
pub struct VsaConfig {
    /// VSA hypervector dimension
    pub dimension: usize,
}
549
550// =============================================================================
551// TRAINING WORKFLOW
552// =============================================================================
553
/// Builder for Axolotl training pipelines.
///
/// Created via [`RustAI::train`]; a YAML config path is required before
/// `build()` succeeds.
pub struct TrainBuilder<'a> {
    // Back-reference to the facade; not consulted yet, kept for future use.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // YAML config path; required — `build()` errors when `None`.
    config_path: Option<String>,
}
560
561impl<'a> TrainBuilder<'a> {
562    fn new(ai: &'a RustAI) -> Self {
563        Self {
564            ai,
565            config_path: None,
566        }
567    }
568
569    /// Set the YAML configuration file path.
570    #[must_use]
571    pub fn config_file(mut self, path: impl Into<String>) -> Self {
572        self.config_path = Some(path.into());
573        self
574    }
575
576    /// Build the training configuration.
577    ///
578    /// # Errors
579    ///
580    /// Returns error if config file is not specified.
581    pub fn build(self) -> Result<TrainConfig> {
582        let config_path = self
583            .config_path
584            .ok_or_else(|| CoreError::invalid_config("config file path is required"))?;
585
586        Ok(TrainConfig { config_path })
587    }
588}
589
/// Configuration for Axolotl training, produced by [`TrainBuilder::build`].
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Path to the YAML configuration file
    pub config_path: String,
}
596
597// =============================================================================
598// TESTS
599// =============================================================================
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a CPU-only `RustAI` instance so tests never require CUDA.
    fn cpu_ai() -> RustAI {
        RustAI::new(RustAIConfig::new().with_cpu()).unwrap()
    }

    #[test]
    fn test_rustai_config_default() {
        let cfg = RustAIConfig::default();
        assert!(!cfg.verbose);
        assert_eq!(cfg.memory_limit, 0);
    }

    #[test]
    fn test_rustai_config_builder() {
        let one_gib = 1024 * 1024 * 1024;
        let cfg = RustAIConfig::new()
            .with_verbose(true)
            .with_memory_limit(one_gib)
            .with_cpu();

        assert!(cfg.verbose);
        assert_eq!(cfg.memory_limit, one_gib);
        assert!(cfg.device.force_cpu);
    }

    #[test]
    fn test_rustai_new() {
        let ai = cpu_ai();
        assert!(!ai.is_cuda());
        assert_eq!(EcosystemInfo::crate_names().len(), 8);
    }

    #[test]
    fn test_rustai_info() {
        let info = cpu_ai().info();
        assert!(!info.version.is_empty());
        assert!(!info.cuda_available);
        assert_eq!(info.ecosystem_crates.len(), 8);
    }

    #[test]
    fn test_finetune_builder() {
        let ai = cpu_ai();
        let cfg = ai
            .finetune()
            .model("test-model")
            .rank(32)
            .alpha(8.0)
            .build()
            .unwrap();

        assert_eq!(cfg.model_path, "test-model");
        assert_eq!(cfg.rank, 32);
        assert!((cfg.alpha - 8.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quantize_builder() {
        let ai = cpu_ai();
        let cfg = ai
            .quantize()
            .method(QuantizeMethod::BitNet)
            .bits(2)
            .group_size(128)
            .build();

        assert!(matches!(cfg.method, QuantizeMethod::BitNet));
        assert_eq!(cfg.bits, 2);
        assert_eq!(cfg.group_size, 128);
    }

    #[test]
    fn test_vsa_builder() {
        let ai = cpu_ai();
        let cfg = ai.vsa().dimension(8192).build();
        assert_eq!(cfg.dimension, 8192);
    }

    #[test]
    fn test_train_builder() {
        let ai = cpu_ai();
        let cfg = ai.train().config_file("train.yaml").build().unwrap();
        assert_eq!(cfg.config_path, "train.yaml");
    }
}