Skip to main content

trustformers_core/patterns/
builder.rs

//! Standardized builder patterns for TrustformeRS Core
//!
//! This module provides a consistent builder pattern implementation
//! that can be used across all modules for configuration and object construction.
5
6#![allow(unused_variables)] // Builder pattern
7
8use crate::errors::Result;
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
13/// Standard builder trait that all builders should implement
14pub trait Builder<T> {
15    /// Build the final object
16    fn build(self) -> Result<T>;
17
18    /// Validate the current builder state
19    fn validate(&self) -> Result<()> {
20        Ok(())
21    }
22
23    /// Reset the builder to default state
24    fn reset(self) -> Self
25    where
26        T: Default;
27}
28
29/// Configuration builder trait for objects that have configuration
30pub trait ConfigBuilder<T, C>: Builder<T> {
31    /// Set configuration
32    fn config(self, config: C) -> Self;
33
34    /// Get current configuration (if any)
35    fn get_config(&self) -> Option<&C>;
36}
37
38/// Standard builder implementation using the builder pattern
39#[derive(Debug, Clone)]
40pub struct StandardBuilder<T, S = BuilderComplete> {
41    data: T,
42    _state: PhantomData<S>,
43}
44
/// Type-level marker for a builder that is still being configured.
///
/// Used as the state parameter of `StandardBuilder` for compile-time
/// validation. The marker carries no data, so it is freely copyable and
/// comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BuilderIncomplete;

/// Type-level marker for a builder that has been marked ready to build.
///
/// Used as the state parameter of `StandardBuilder` for compile-time
/// validation. The marker carries no data, so it is freely copyable and
/// comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BuilderComplete;
51
52/// Trait for objects that can be built with a standard builder
53pub trait Buildable: Sized + Default {
54    type Builder: Builder<Self>;
55
56    /// Create a new builder for this type
57    fn builder() -> Self::Builder;
58}
59
60/// Standard configuration trait
61pub trait StandardConfig: Debug + Clone + Default + Serialize + for<'de> Deserialize<'de> {
62    /// Validate the configuration
63    fn validate(&self) -> Result<()> {
64        Ok(())
65    }
66
67    /// Merge with another configuration (self takes precedence)
68    fn merge(self, other: Self) -> Self {
69        self
70    }
71}
72
73impl<T> Default for StandardBuilder<T, BuilderIncomplete>
74where
75    T: Default,
76{
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl<T> StandardBuilder<T, BuilderIncomplete>
83where
84    T: Default,
85{
86    /// Create a new builder
87    pub fn new() -> Self {
88        Self {
89            data: T::default(),
90            _state: PhantomData,
91        }
92    }
93
94    /// Create a builder from existing data
95    pub fn from(data: T) -> Self {
96        Self {
97            data,
98            _state: PhantomData,
99        }
100    }
101}
102
103impl<T> StandardBuilder<T, BuilderIncomplete>
104where
105    T: Clone,
106{
107    /// Get a mutable reference to the data
108    pub fn data_mut(&mut self) -> &mut T {
109        &mut self.data
110    }
111
112    /// Mark builder as complete (ready to build)
113    pub fn complete(self) -> StandardBuilder<T, BuilderComplete> {
114        StandardBuilder {
115            data: self.data,
116            _state: PhantomData,
117        }
118    }
119}
120
121impl<T> StandardBuilder<T, BuilderComplete>
122where
123    T: Clone,
124{
125    /// Get a reference to the data
126    pub fn data(&self) -> &T {
127        &self.data
128    }
129
130    /// Get a mutable reference to the data
131    pub fn data_mut(&mut self) -> &mut T {
132        &mut self.data
133    }
134}
135
136impl<T> Builder<T> for StandardBuilder<T, BuilderComplete>
137where
138    T: Clone + Default,
139{
140    fn build(self) -> Result<T> {
141        self.validate()?;
142        Ok(self.data)
143    }
144
145    fn reset(self) -> Self {
146        Self {
147            data: T::default(),
148            _state: PhantomData,
149        }
150    }
151}
152
/// Generates fluent setter methods on a builder type.
///
/// Each `method: Type = field` entry expands to
/// `pub fn method(mut self, value: Type) -> Self` which assigns
/// `self.data.field = value` and returns the builder. The builder type must
/// therefore have a `data` field exposing the named fields. The second
/// argument (the target type) is accepted but not used by the expansion.
#[macro_export]
macro_rules! builder_methods {
    (
        $builder:ty,
        $target:ty,
        {
            $(
                $setter:ident : $value_ty:ty = $field:ident
            ),* $(,)?
        }
    ) => {
        impl $builder {
            $(
                #[doc = concat!("Set ", stringify!($field))]
                pub fn $setter(mut self, value: $value_ty) -> Self {
                    self.data.$field = value;
                    self
                }
            )*
        }
    };
}
176
177/// Type alias for validation function
178pub type ValidationFn<T> = Box<dyn Fn(&T) -> Result<()> + Send + Sync>;
179
180/// Validation builder for objects that need validation
181pub struct ValidatedBuilder<T> {
182    data: T,
183    validators: Vec<ValidationFn<T>>,
184}
185
186impl<T> Default for ValidatedBuilder<T>
187where
188    T: Default,
189{
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195impl<T> ValidatedBuilder<T>
196where
197    T: Default,
198{
199    /// Create a new validated builder
200    pub fn new() -> Self {
201        Self {
202            data: T::default(),
203            validators: Vec::new(),
204        }
205    }
206
207    /// Add a validation function
208    pub fn add_validator<F>(mut self, validator: F) -> Self
209    where
210        F: Fn(&T) -> Result<()> + Send + Sync + 'static,
211    {
212        self.validators.push(Box::new(validator));
213        self
214    }
215
216    /// Get a reference to the data
217    pub fn data(&self) -> &T {
218        &self.data
219    }
220
221    /// Get a mutable reference to the data
222    pub fn data_mut(&mut self) -> &mut T {
223        &mut self.data
224    }
225}
226
227impl<T> Builder<T> for ValidatedBuilder<T>
228where
229    T: Clone,
230{
231    fn build(self) -> Result<T> {
232        self.validate()?;
233        Ok(self.data)
234    }
235
236    fn validate(&self) -> Result<()> {
237        for validator in &self.validators {
238            validator(&self.data)?;
239        }
240        Ok(())
241    }
242
243    fn reset(mut self) -> Self
244    where
245        T: Default,
246    {
247        self.data = T::default();
248        self
249    }
250}
251
/// Builder that gathers an optional target, an optional configuration and
/// descriptive metadata (name, description, tags).
#[derive(Debug, Clone)]
pub struct ConfigBuilderImpl<T, C> {
    /// Target object, if one has been supplied.
    target: Option<T>,
    /// Configuration, if one has been supplied.
    config: Option<C>,
    /// Optional name.
    name: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Tags in insertion order.
    tags: Vec<String>,
}
261
262impl<T, C> Default for ConfigBuilderImpl<T, C>
263where
264    C: StandardConfig,
265{
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl<T, C> ConfigBuilderImpl<T, C>
272where
273    C: StandardConfig,
274{
275    /// Create a new config builder
276    pub fn new() -> Self {
277        Self {
278            target: None,
279            config: None,
280            name: None,
281            description: None,
282            tags: Vec::new(),
283        }
284    }
285
286    /// Set name
287    pub fn name(mut self, name: impl Into<String>) -> Self {
288        self.name = Some(name.into());
289        self
290    }
291
292    /// Set description
293    pub fn description(mut self, description: impl Into<String>) -> Self {
294        self.description = Some(description.into());
295        self
296    }
297
298    /// Add tag
299    pub fn tag(mut self, tag: impl Into<String>) -> Self {
300        self.tags.push(tag.into());
301        self
302    }
303
304    /// Add multiple tags
305    pub fn tags(mut self, tags: Vec<String>) -> Self {
306        self.tags.extend(tags);
307        self
308    }
309
310    /// Set target object
311    pub fn target(mut self, target: T) -> Self {
312        self.target = Some(target);
313        self
314    }
315}
316
317impl<T, C> ConfigBuilder<T, C> for ConfigBuilderImpl<T, C>
318where
319    C: StandardConfig,
320    T: Default,
321{
322    fn config(mut self, config: C) -> Self {
323        self.config = Some(config);
324        self
325    }
326
327    fn get_config(&self) -> Option<&C> {
328        self.config.as_ref()
329    }
330}
331
332impl<T, C> Builder<T> for ConfigBuilderImpl<T, C>
333where
334    T: Default,
335    C: StandardConfig,
336{
337    fn build(self) -> Result<T> {
338        self.validate()?;
339        Ok(self.target.unwrap_or_default())
340    }
341
342    fn validate(&self) -> Result<()> {
343        if let Some(config) = &self.config {
344            config.validate()?;
345        }
346        Ok(())
347    }
348
349    fn reset(self) -> Self {
350        Self::new()
351    }
352}
353
/// Quick builder macro for simple cases.
///
/// `quick_builder!(Name for Target { field: Type, ... })` generates:
///
/// * a `pub struct Name` holding every listed field as `Option<Type>`
///   (deriving `Debug`, `Clone` and `Default`),
/// * `Name::new()` plus one by-value setter per field, and
/// * an `impl Builder<Target> for Name`.
///
/// The generated `build` is a template: it ignores the values that were set
/// and returns `<Target>::default()`, so `Target` must implement `Default`.
///
/// The `Builder` trait must be in scope where the macro is invoked. The
/// result type is named through `$crate`, so the call site does not need its
/// own `Result` alias.
#[macro_export]
macro_rules! quick_builder {
    ($name:ident for $target:ty {
        $(
            $field:ident: $field_type:ty
        ),* $(,)?
    }) => {
        #[derive(Debug, Clone, Default)]
        pub struct $name {
            $(
                $field: Option<$field_type>,
            )*
        }

        impl $name {
            pub fn new() -> Self {
                Self::default()
            }

            $(
                pub fn $field(mut self, value: $field_type) -> Self {
                    self.$field = Some(value);
                    self
                }
            )*
        }

        impl Builder<$target> for $name {
            fn build(self) -> $crate::errors::Result<$target> {
                // NOTE: This is a template implementation. Real builders should
                // implement custom logic to construct the target type from the
                // builder fields. Example for a struct with the same fields:
                // Ok($target {
                //     $(
                //         $field: self.$field.ok_or_else(|| {
                //             $crate::errors::TrustformersError::invalid_input(
                //                 format!("Missing required field: {}", stringify!($field)),
                //             )
                //         })?,
                //     )*
                // })

                // For now, return a default instance of the target.
                Ok(<$target>::default())
            }

            fn reset(self) -> Self {
                Self::default()
            }
        }
    };
}
408
409/// Builder validation error
410#[derive(Debug, thiserror::Error)]
411pub enum BuilderError {
412    #[error("Required field missing: {field}")]
413    MissingField { field: String },
414    #[error("Invalid value for field {field}: {reason}")]
415    InvalidValue { field: String, reason: String },
416    #[error("Builder validation failed: {reason}")]
417    ValidationFailed { reason: String },
418    #[error("Configuration error: {0}")]
419    ConfigError(String),
420}
421
422/// Result type for builder operations
423pub type BuilderResult<T> = std::result::Result<T, BuilderError>;
424
425/// Trait for objects that can be serialized as configuration
426pub trait ConfigSerializable {
427    /// Serialize to JSON string
428    fn to_json(&self) -> Result<String>;
429
430    /// Deserialize from JSON string
431    fn from_json(json: &str) -> Result<Self>
432    where
433        Self: Sized;
434
435    /// Save to file
436    fn save_to_file(&self, path: &std::path::Path) -> Result<()> {
437        let json = self.to_json()?;
438        std::fs::write(path, json)?;
439        Ok(())
440    }
441
442    /// Load from file
443    fn load_from_file(path: &std::path::Path) -> Result<Self>
444    where
445        Self: Sized,
446    {
447        let json = std::fs::read_to_string(path)?;
448        Self::from_json(&json)
449    }
450}
451
452/// Default implementation for types that implement Serialize + DeserializeOwned
453impl<T> ConfigSerializable for T
454where
455    T: Serialize + for<'de> Deserialize<'de>,
456{
457    fn to_json(&self) -> Result<String> {
458        Ok(serde_json::to_string_pretty(self)?)
459    }
460
461    fn from_json(json: &str) -> Result<Self> {
462        Ok(serde_json::from_str(json)?)
463    }
464}
465
466/// Example concrete builder implementations demonstrating best practices
467///
468/// Example model configuration that can be built
469#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
470pub struct ModelConfig {
471    pub name: String,
472    pub model_type: String,
473    pub max_length: usize,
474    pub batch_size: usize,
475    pub temperature: f32,
476    pub top_p: f32,
477}
478
479impl StandardConfig for ModelConfig {
480    fn validate(&self) -> Result<()> {
481        if self.name.is_empty() {
482            return Err(crate::errors::TrustformersError::invalid_input(
483                "Model name cannot be empty".to_string(),
484            ));
485        }
486        if self.max_length == 0 {
487            return Err(crate::errors::TrustformersError::invalid_input(
488                "Max length must be greater than 0".to_string(),
489            ));
490        }
491        if self.temperature < 0.0 || self.temperature > 2.0 {
492            return Err(crate::errors::TrustformersError::invalid_input(
493                "Temperature must be between 0.0 and 2.0".to_string(),
494            ));
495        }
496        if self.top_p < 0.0 || self.top_p > 1.0 {
497            return Err(crate::errors::TrustformersError::invalid_input(
498                "Top-p must be between 0.0 and 1.0".to_string(),
499            ));
500        }
501        Ok(())
502    }
503}
504
/// Builder for `ModelConfig`.
///
/// Every field starts as `None`, meaning "not set by the caller".
#[derive(Debug, Clone, Default)]
pub struct ModelConfigBuilder {
    /// Model name, if set.
    name: Option<String>,
    /// Model type, if set.
    model_type: Option<String>,
    /// Maximum length, if set.
    max_length: Option<usize>,
    /// Batch size, if set.
    batch_size: Option<usize>,
    /// Temperature, if set.
    temperature: Option<f32>,
    /// Top-p, if set.
    top_p: Option<f32>,
}
515
516impl ModelConfigBuilder {
517    pub fn new() -> Self {
518        Self::default()
519    }
520
521    pub fn name(mut self, name: impl Into<String>) -> Self {
522        self.name = Some(name.into());
523        self
524    }
525
526    pub fn model_type(mut self, model_type: impl Into<String>) -> Self {
527        self.model_type = Some(model_type.into());
528        self
529    }
530
531    pub fn max_length(mut self, max_length: usize) -> Self {
532        self.max_length = Some(max_length);
533        self
534    }
535
536    pub fn batch_size(mut self, batch_size: usize) -> Self {
537        self.batch_size = Some(batch_size);
538        self
539    }
540
541    pub fn temperature(mut self, temperature: f32) -> Self {
542        self.temperature = Some(temperature);
543        self
544    }
545
546    pub fn top_p(mut self, top_p: f32) -> Self {
547        self.top_p = Some(top_p);
548        self
549    }
550}
551
552impl Builder<ModelConfig> for ModelConfigBuilder {
553    fn build(self) -> Result<ModelConfig> {
554        let config = ModelConfig {
555            name: self.name.unwrap_or_default(),
556            model_type: self.model_type.unwrap_or_else(|| "transformer".to_string()),
557            max_length: self.max_length.unwrap_or(2048),
558            batch_size: self.batch_size.unwrap_or(1),
559            temperature: self.temperature.unwrap_or(1.0),
560            top_p: self.top_p.unwrap_or(1.0),
561        };
562
563        // Validate the built configuration
564        config.validate()?;
565        Ok(config)
566    }
567
568    fn reset(self) -> Self {
569        Self::default()
570    }
571}
572
573impl Buildable for ModelConfig {
574    type Builder = ModelConfigBuilder;
575
576    fn builder() -> Self::Builder {
577        ModelConfigBuilder::new()
578    }
579}
580
581/// Example training configuration
582#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
583pub struct TrainingConfig {
584    pub learning_rate: f64,
585    pub epochs: usize,
586    pub warmup_steps: usize,
587    pub weight_decay: f64,
588    pub gradient_clipping: f64,
589}
590
591impl StandardConfig for TrainingConfig {
592    fn validate(&self) -> Result<()> {
593        if self.learning_rate <= 0.0 {
594            return Err(crate::errors::TrustformersError::invalid_input(
595                "Learning rate must be positive".to_string(),
596            ));
597        }
598        if self.epochs == 0 {
599            return Err(crate::errors::TrustformersError::invalid_input(
600                "Epochs must be greater than 0".to_string(),
601            ));
602        }
603        Ok(())
604    }
605}
606
/// Hand-written builder for `TrainingConfig` (no macro involved).
///
/// Every field starts as `None`, meaning "not set by the caller".
#[derive(Debug, Clone, Default)]
pub struct TrainingConfigBuilder {
    /// Learning rate, if set.
    learning_rate: Option<f64>,
    /// Number of epochs, if set.
    epochs: Option<usize>,
    /// Warmup steps, if set.
    warmup_steps: Option<usize>,
    /// Weight decay, if set.
    weight_decay: Option<f64>,
    /// Gradient clipping value, if set.
    gradient_clipping: Option<f64>,
}
616
617impl TrainingConfigBuilder {
618    pub fn new() -> Self {
619        Self::default()
620    }
621
622    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
623        self.learning_rate = Some(learning_rate);
624        self
625    }
626
627    pub fn epochs(mut self, epochs: usize) -> Self {
628        self.epochs = Some(epochs);
629        self
630    }
631
632    pub fn warmup_steps(mut self, warmup_steps: usize) -> Self {
633        self.warmup_steps = Some(warmup_steps);
634        self
635    }
636
637    pub fn weight_decay(mut self, weight_decay: f64) -> Self {
638        self.weight_decay = Some(weight_decay);
639        self
640    }
641
642    pub fn gradient_clipping(mut self, gradient_clipping: f64) -> Self {
643        self.gradient_clipping = Some(gradient_clipping);
644        self
645    }
646}
647
648// Implement proper build method for TrainingConfigBuilder
649impl Builder<TrainingConfig> for TrainingConfigBuilder {
650    fn build(self) -> Result<TrainingConfig> {
651        let config = TrainingConfig {
652            learning_rate: self.learning_rate.unwrap_or(1e-4),
653            epochs: self.epochs.unwrap_or(10),
654            warmup_steps: self.warmup_steps.unwrap_or(1000),
655            weight_decay: self.weight_decay.unwrap_or(0.01),
656            gradient_clipping: self.gradient_clipping.unwrap_or(1.0),
657        };
658
659        config.validate()?;
660        Ok(config)
661    }
662
663    fn reset(self) -> Self {
664        Self::default()
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain target type used by the generic builder tests.
    #[derive(Debug, Clone, Default, PartialEq)]
    struct TestObject {
        name: String,
        value: i32,
        enabled: bool,
    }

    /// Minimal configuration relying on the default `StandardConfig` methods.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    struct TestConfig {
        timeout: u64,
        retries: u32,
    }

    impl StandardConfig for TestConfig {}

    /// Validator shared by the `ValidatedBuilder` test: the name must be set.
    fn require_name(obj: &TestObject) -> Result<()> {
        if obj.name.is_empty() {
            return Err(anyhow::anyhow!("Name cannot be empty").into());
        }
        Ok(())
    }

    #[test]
    fn test_standard_builder() {
        let mut builder = StandardBuilder::<TestObject, BuilderIncomplete>::new();
        {
            let data = builder.data_mut();
            data.name = "test".to_string();
            data.value = 42;
            data.enabled = true;
        }

        let obj = builder.complete().build().expect("operation failed in test");
        assert_eq!(obj.name, "test");
        assert_eq!(obj.value, 42);
        assert!(obj.enabled);
    }

    #[test]
    fn test_validated_builder() {
        // An empty name must fail validation.
        let unnamed = ValidatedBuilder::<TestObject>::new().add_validator(require_name);
        assert!(unnamed.build().is_err());

        // Once a name is set, the same validator passes.
        let mut named = ValidatedBuilder::<TestObject>::new().add_validator(require_name);
        named.data_mut().name = "test".to_string();
        assert!(named.build().is_ok());
    }

    #[test]
    fn test_config_builder() {
        let config = TestConfig {
            timeout: 5000,
            retries: 3,
        };

        let outcome = ConfigBuilderImpl::new()
            .config(config)
            .name("test_config")
            .description("A test configuration")
            .tag("test")
            .target(TestObject::default())
            .build();

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_config_serialization() {
        let original = TestConfig {
            timeout: 5000,
            retries: 3,
        };

        let json = original.to_json().expect("operation failed in test");
        let restored = TestConfig::from_json(&json).expect("operation failed in test");

        assert_eq!(original.timeout, restored.timeout);
        assert_eq!(original.retries, restored.retries);
    }

    // Example of using the quick_builder macro
    quick_builder!(TestObjectBuilder for TestObject {
        name: String,
        value: i32,
        enabled: bool
    });

    #[test]
    fn test_quick_builder_creation() {
        let builder = TestObjectBuilder::new()
            .name("test".to_string())
            .value(42)
            .enabled(true);

        // Only the setters are exercised here; the generated `build` is a
        // template and is not checked.
        assert!(builder.name.is_some());
        assert!(builder.value.is_some());
        assert!(builder.enabled.is_some());
    }

    #[test]
    fn test_model_config_builder() {
        let built = ModelConfig::builder()
            .name("test-model")
            .model_type("gpt")
            .max_length(1024)
            .batch_size(4)
            .temperature(0.7)
            .top_p(0.9)
            .build();
        let config = built.expect("operation failed in test");

        assert_eq!(config.name, "test-model");
        assert_eq!(config.model_type, "gpt");
        assert_eq!(config.max_length, 1024);
        assert_eq!(config.batch_size, 4);
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.top_p, 0.9);
    }

    #[test]
    fn test_model_config_builder_validation() {
        // Temperature above 2.0 is rejected.
        let too_hot = ModelConfig::builder().name("test").temperature(3.0).build();
        assert!(too_hot.is_err());

        // Top-p above 1.0 is rejected.
        let bad_top_p = ModelConfig::builder().name("test").top_p(1.5).build();
        assert!(bad_top_p.is_err());

        // In-range values are accepted.
        let valid = ModelConfig::builder()
            .name("test")
            .temperature(0.8)
            .top_p(0.9)
            .build();
        assert!(valid.is_ok());
    }

    #[test]
    fn test_training_config_builder() {
        let built = TrainingConfigBuilder::new()
            .learning_rate(1e-3)
            .epochs(5)
            .warmup_steps(500)
            .weight_decay(0.001)
            .gradient_clipping(0.5)
            .build();
        let config = built.expect("operation failed in test");

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.warmup_steps, 500);
        assert_eq!(config.weight_decay, 0.001);
        assert_eq!(config.gradient_clipping, 0.5);
    }

    #[test]
    fn test_training_config_builder_defaults() {
        let config = TrainingConfigBuilder::new()
            .build()
            .expect("operation failed in test");

        assert_eq!(config.learning_rate, 1e-4);
        assert_eq!(config.epochs, 10);
        assert_eq!(config.warmup_steps, 1000);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.gradient_clipping, 1.0);
    }

    #[test]
    fn test_training_config_validation() {
        // A negative learning rate is rejected.
        let negative_lr = TrainingConfigBuilder::new().learning_rate(-0.1).build();
        assert!(negative_lr.is_err());

        // Zero epochs is rejected.
        let zero_epochs = TrainingConfigBuilder::new().epochs(0).build();
        assert!(zero_epochs.is_err());
    }
}