1#![allow(unused_variables)] use crate::errors::Result;
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
/// A consuming builder that produces a value of type `T`.
///
/// Implementors accumulate state through their own setters and turn it into a `T`
/// with [`Builder::build`].
pub trait Builder<T> {
    /// Consumes the builder and produces the built value.
    ///
    /// # Errors
    /// Implementation-defined; typically a validation failure.
    fn build(self) -> Result<T>;

    /// Checks the builder's current state.
    ///
    /// The default implementation accepts every state and returns `Ok(())`.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Consumes the builder and returns one reset to a fresh state.
    ///
    /// Only callable when `T: Default`. What (if anything) survives a reset is
    /// implementation-defined.
    fn reset(self) -> Self
    where
        T: Default;
}
28
/// A [`Builder`] that can additionally be given a configuration value of type `C`.
pub trait ConfigBuilder<T, C>: Builder<T> {
    /// Supplies the configuration and returns the builder for chaining.
    fn config(self, config: C) -> Self;

    /// Returns the configuration supplied so far, or `None` if there is none.
    fn get_config(&self) -> Option<&C>;
}
37
/// A builder that wraps a value of type `T` and tracks its state at the type level.
///
/// `S` is a state marker stored only in [`PhantomData`], so it adds no runtime data.
/// The impls in this module use [`BuilderIncomplete`] and [`BuilderComplete`]. When
/// `S` is omitted it defaults to [`BuilderComplete`].
#[derive(Debug, Clone)]
pub struct StandardBuilder<T, S = BuilderComplete> {
    /// The value under construction.
    data: T,
    /// Type-level state marker; carries no data.
    _state: PhantomData<S>,
}
44
/// Type-state marker for a [`StandardBuilder`] that has not been marked complete yet.
#[derive(Debug, Clone)]
pub struct BuilderIncomplete;

/// Type-state marker for a [`StandardBuilder`] that has been marked complete.
///
/// In this module, [`Builder`] is implemented only for builders in this state.
#[derive(Debug, Clone)]
pub struct BuilderComplete;
51
/// A type that can hand out its own builder.
pub trait Buildable: Sized + Default {
    /// The builder type returned by [`Buildable::builder`].
    type Builder: Builder<Self>;

    /// Returns a new builder for `Self`.
    fn builder() -> Self::Builder;
}
59
/// Common behaviour required of configuration types.
///
/// Implementors must be `Debug`, `Clone`, `Default` and (de)serializable with serde.
pub trait StandardConfig: Debug + Clone + Default + Serialize + for<'de> Deserialize<'de> {
    /// Checks the configuration's invariants.
    ///
    /// The default implementation accepts every value and returns `Ok(())`.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Combines `self` with `other`.
    ///
    /// The default implementation returns `self` unchanged and discards `other`.
    /// NOTE(review): implementors that need a real merge must override this.
    fn merge(self, other: Self) -> Self {
        self
    }
}
72
73impl<T> Default for StandardBuilder<T, BuilderIncomplete>
74where
75 T: Default,
76{
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl<T> StandardBuilder<T, BuilderIncomplete>
83where
84 T: Default,
85{
86 pub fn new() -> Self {
88 Self {
89 data: T::default(),
90 _state: PhantomData,
91 }
92 }
93
94 pub fn from(data: T) -> Self {
96 Self {
97 data,
98 _state: PhantomData,
99 }
100 }
101}
102
103impl<T> StandardBuilder<T, BuilderIncomplete>
104where
105 T: Clone,
106{
107 pub fn data_mut(&mut self) -> &mut T {
109 &mut self.data
110 }
111
112 pub fn complete(self) -> StandardBuilder<T, BuilderComplete> {
114 StandardBuilder {
115 data: self.data,
116 _state: PhantomData,
117 }
118 }
119}
120
121impl<T> StandardBuilder<T, BuilderComplete>
122where
123 T: Clone,
124{
125 pub fn data(&self) -> &T {
127 &self.data
128 }
129
130 pub fn data_mut(&mut self) -> &mut T {
132 &mut self.data
133 }
134}
135
136impl<T> Builder<T> for StandardBuilder<T, BuilderComplete>
137where
138 T: Clone + Default,
139{
140 fn build(self) -> Result<T> {
141 self.validate()?;
142 Ok(self.data)
143 }
144
145 fn reset(self) -> Self {
146 Self {
147 data: T::default(),
148 _state: PhantomData,
149 }
150 }
151}
152
/// Generates chained setters on a builder that stores its target in a `data` field.
///
/// Usage: `builder_methods!(BuilderType, TargetType, { setter: FieldType = field, ... })`.
/// Each entry becomes `pub fn setter(self, value: FieldType) -> Self`, which writes
/// `value` into `self.data.field` and returns the builder. The `TargetType` argument
/// is accepted but not used by the expansion.
#[macro_export]
macro_rules! builder_methods {
    (
        $builder:ty,
        $target:ty,
        {
            $( $setter:ident : $value_ty:ty = $field:ident ),* $(,)?
        }
    ) => {
        impl $builder {
            $(
                #[doc = concat!("Set ", stringify!($field))]
                pub fn $setter(self, value: $value_ty) -> Self {
                    let mut this = self;
                    this.data.$field = value;
                    this
                }
            )*
        }
    };
}
176
/// A boxed validation callback: inspects a `&T` and returns `Err` to reject it.
///
/// The `Send + Sync` bounds keep a `ValidatedBuilder<T>` `Send`/`Sync` whenever `T` is.
pub type ValidationFn<T> = Box<dyn Fn(&T) -> Result<()> + Send + Sync>;

/// A builder that owns a value of type `T` plus a list of validators that are run
/// against that value when it is validated or built.
pub struct ValidatedBuilder<T> {
    /// The value under construction.
    data: T,
    /// Registered validators, kept in the order they were added.
    validators: Vec<ValidationFn<T>>,
}
185
186impl<T> Default for ValidatedBuilder<T>
187where
188 T: Default,
189{
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195impl<T> ValidatedBuilder<T>
196where
197 T: Default,
198{
199 pub fn new() -> Self {
201 Self {
202 data: T::default(),
203 validators: Vec::new(),
204 }
205 }
206
207 pub fn add_validator<F>(mut self, validator: F) -> Self
209 where
210 F: Fn(&T) -> Result<()> + Send + Sync + 'static,
211 {
212 self.validators.push(Box::new(validator));
213 self
214 }
215
216 pub fn data(&self) -> &T {
218 &self.data
219 }
220
221 pub fn data_mut(&mut self) -> &mut T {
223 &mut self.data
224 }
225}
226
227impl<T> Builder<T> for ValidatedBuilder<T>
228where
229 T: Clone,
230{
231 fn build(self) -> Result<T> {
232 self.validate()?;
233 Ok(self.data)
234 }
235
236 fn validate(&self) -> Result<()> {
237 for validator in &self.validators {
238 validator(&self.data)?;
239 }
240 Ok(())
241 }
242
243 fn reset(mut self) -> Self
244 where
245 T: Default,
246 {
247 self.data = T::default();
248 self
249 }
250}
251
/// General-purpose [`ConfigBuilder`] holding an optional target value, an optional
/// configuration and some descriptive metadata.
///
/// NOTE(review): `name`, `description` and `tags` are recorded by the setters but
/// are not read by `build` in this file.
#[derive(Debug, Clone)]
pub struct ConfigBuilderImpl<T, C> {
    /// Value returned by `build`; `T::default()` is returned when this is `None`.
    target: Option<T>,
    /// Configuration validated during `build`, if present.
    config: Option<C>,
    /// Optional name.
    name: Option<String>,
    /// Optional free-form description.
    description: Option<String>,
    /// Tags, in the order they were added.
    tags: Vec<String>,
}
261
262impl<T, C> Default for ConfigBuilderImpl<T, C>
263where
264 C: StandardConfig,
265{
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271impl<T, C> ConfigBuilderImpl<T, C>
272where
273 C: StandardConfig,
274{
275 pub fn new() -> Self {
277 Self {
278 target: None,
279 config: None,
280 name: None,
281 description: None,
282 tags: Vec::new(),
283 }
284 }
285
286 pub fn name(mut self, name: impl Into<String>) -> Self {
288 self.name = Some(name.into());
289 self
290 }
291
292 pub fn description(mut self, description: impl Into<String>) -> Self {
294 self.description = Some(description.into());
295 self
296 }
297
298 pub fn tag(mut self, tag: impl Into<String>) -> Self {
300 self.tags.push(tag.into());
301 self
302 }
303
304 pub fn tags(mut self, tags: Vec<String>) -> Self {
306 self.tags.extend(tags);
307 self
308 }
309
310 pub fn target(mut self, target: T) -> Self {
312 self.target = Some(target);
313 self
314 }
315}
316
317impl<T, C> ConfigBuilder<T, C> for ConfigBuilderImpl<T, C>
318where
319 C: StandardConfig,
320 T: Default,
321{
322 fn config(mut self, config: C) -> Self {
323 self.config = Some(config);
324 self
325 }
326
327 fn get_config(&self) -> Option<&C> {
328 self.config.as_ref()
329 }
330}
331
332impl<T, C> Builder<T> for ConfigBuilderImpl<T, C>
333where
334 T: Default,
335 C: StandardConfig,
336{
337 fn build(self) -> Result<T> {
338 self.validate()?;
339 Ok(self.target.unwrap_or_default())
340 }
341
342 fn validate(&self) -> Result<()> {
343 if let Some(config) = &self.config {
344 config.validate()?;
345 }
346 Ok(())
347 }
348
349 fn reset(self) -> Self {
350 Self::new()
351 }
352}
353
/// Declares a simple builder struct for `$target`, together with its setters and a
/// [`Builder`] implementation.
///
/// `quick_builder!(FooBuilder for Foo { a: A, b: B })` expands to:
/// - `pub struct FooBuilder { a: Option<A>, b: Option<B> }` deriving `Debug`, `Clone`
///   and `Default`;
/// - `FooBuilder::new()` plus one chained setter per listed field;
/// - `impl Builder<Foo> for FooBuilder`. Its `build` starts from `Foo::default()` and
///   overwrites every field that was set on the builder.
///
/// Call-site requirements:
/// - `Foo: Default`;
/// - every listed field exists on `Foo` with the listed type and is visible at the call site;
/// - `Builder` and `Result` are in scope, because the expansion names them unqualified.
#[macro_export]
macro_rules! quick_builder {
    ($name:ident for $target:ty {
        $(
            $field:ident: $field_type:ty
        ),* $(,)?
    }) => {
        #[derive(Debug, Clone, Default)]
        pub struct $name {
            $(
                $field: Option<$field_type>,
            )*
        }

        impl $name {
            pub fn new() -> Self {
                Self::default()
            }

            $(
                pub fn $field(mut self, value: $field_type) -> Self {
                    self.$field = Some(value);
                    self
                }
            )*
        }

        impl Builder<$target> for $name {
            fn build(self) -> Result<$target> {
                // Start from the target's defaults and copy over every field the caller
                // actually set; unset fields keep their `Default` value.
                #[allow(unused_mut)]
                let mut built = <$target>::default();
                $(
                    if let Some(value) = self.$field {
                        built.$field = value;
                    }
                )*
                Ok(built)
            }

            fn reset(self) -> Self {
                Self::default()
            }
        }
    };
}
408
/// Builder-specific error kinds.
///
/// NOTE(review): none of these variants is constructed in this file; the builders
/// here report failures through the crate-wide `Result` instead.
#[derive(Debug, thiserror::Error)]
pub enum BuilderError {
    /// A required field was missing (`field` names it).
    #[error("Required field missing: {field}")]
    MissingField { field: String },
    /// A field held an unacceptable value; `reason` explains why.
    #[error("Invalid value for field {field}: {reason}")]
    InvalidValue { field: String, reason: String },
    /// Validation of the builder as a whole failed.
    #[error("Builder validation failed: {reason}")]
    ValidationFailed { reason: String },
    /// Free-form configuration error message.
    #[error("Configuration error: {0}")]
    ConfigError(String),
}

/// Result alias whose error type is [`BuilderError`].
pub type BuilderResult<T> = std::result::Result<T, BuilderError>;
424
425pub trait ConfigSerializable {
427 fn to_json(&self) -> Result<String>;
429
430 fn from_json(json: &str) -> Result<Self>
432 where
433 Self: Sized;
434
435 fn save_to_file(&self, path: &std::path::Path) -> Result<()> {
437 let json = self.to_json()?;
438 std::fs::write(path, json)?;
439 Ok(())
440 }
441
442 fn load_from_file(path: &std::path::Path) -> Result<Self>
444 where
445 Self: Sized,
446 {
447 let json = std::fs::read_to_string(path)?;
448 Self::from_json(&json)
449 }
450}
451
452impl<T> ConfigSerializable for T
454where
455 T: Serialize + for<'de> Deserialize<'de>,
456{
457 fn to_json(&self) -> Result<String> {
458 Ok(serde_json::to_string_pretty(self)?)
459 }
460
461 fn from_json(json: &str) -> Result<Self> {
462 Ok(serde_json::from_str(json)?)
463 }
464}
465
/// Model settings, validated by its `StandardConfig` implementation.
///
/// The derived `Default` has an empty `name` and a `max_length` of 0, neither of
/// which passes `validate`. Use [`ModelConfigBuilder`] to get usable defaults.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model name; must be non-empty.
    pub name: String,
    /// Model type label (the builder defaults it to `"transformer"`).
    pub model_type: String,
    /// Maximum length (unit not stated here); must be greater than 0.
    pub max_length: usize,
    /// Batch size (not validated).
    pub batch_size: usize,
    /// Temperature; validation requires it to lie in `0.0..=2.0`.
    pub temperature: f32,
    /// Top-p threshold; validation requires it to lie in `0.0..=1.0`.
    pub top_p: f32,
}
478
479impl StandardConfig for ModelConfig {
480 fn validate(&self) -> Result<()> {
481 if self.name.is_empty() {
482 return Err(crate::errors::TrustformersError::invalid_input(
483 "Model name cannot be empty".to_string(),
484 ));
485 }
486 if self.max_length == 0 {
487 return Err(crate::errors::TrustformersError::invalid_input(
488 "Max length must be greater than 0".to_string(),
489 ));
490 }
491 if self.temperature < 0.0 || self.temperature > 2.0 {
492 return Err(crate::errors::TrustformersError::invalid_input(
493 "Temperature must be between 0.0 and 2.0".to_string(),
494 ));
495 }
496 if self.top_p < 0.0 || self.top_p > 1.0 {
497 return Err(crate::errors::TrustformersError::invalid_input(
498 "Top-p must be between 0.0 and 1.0".to_string(),
499 ));
500 }
501 Ok(())
502 }
503}
504
/// Builder for [`ModelConfig`].
///
/// Every field is optional. `build` fills unset fields with defaults and then
/// validates the result.
#[derive(Debug, Clone, Default)]
pub struct ModelConfigBuilder {
    /// Defaults to an empty string, which fails validation.
    name: Option<String>,
    /// Defaults to `"transformer"`.
    model_type: Option<String>,
    /// Defaults to 2048.
    max_length: Option<usize>,
    /// Defaults to 1.
    batch_size: Option<usize>,
    /// Defaults to 1.0.
    temperature: Option<f32>,
    /// Defaults to 1.0.
    top_p: Option<f32>,
}
515
516impl ModelConfigBuilder {
517 pub fn new() -> Self {
518 Self::default()
519 }
520
521 pub fn name(mut self, name: impl Into<String>) -> Self {
522 self.name = Some(name.into());
523 self
524 }
525
526 pub fn model_type(mut self, model_type: impl Into<String>) -> Self {
527 self.model_type = Some(model_type.into());
528 self
529 }
530
531 pub fn max_length(mut self, max_length: usize) -> Self {
532 self.max_length = Some(max_length);
533 self
534 }
535
536 pub fn batch_size(mut self, batch_size: usize) -> Self {
537 self.batch_size = Some(batch_size);
538 self
539 }
540
541 pub fn temperature(mut self, temperature: f32) -> Self {
542 self.temperature = Some(temperature);
543 self
544 }
545
546 pub fn top_p(mut self, top_p: f32) -> Self {
547 self.top_p = Some(top_p);
548 self
549 }
550}
551
552impl Builder<ModelConfig> for ModelConfigBuilder {
553 fn build(self) -> Result<ModelConfig> {
554 let config = ModelConfig {
555 name: self.name.unwrap_or_default(),
556 model_type: self.model_type.unwrap_or_else(|| "transformer".to_string()),
557 max_length: self.max_length.unwrap_or(2048),
558 batch_size: self.batch_size.unwrap_or(1),
559 temperature: self.temperature.unwrap_or(1.0),
560 top_p: self.top_p.unwrap_or(1.0),
561 };
562
563 config.validate()?;
565 Ok(config)
566 }
567
568 fn reset(self) -> Self {
569 Self::default()
570 }
571}
572
573impl Buildable for ModelConfig {
574 type Builder = ModelConfigBuilder;
575
576 fn builder() -> Self::Builder {
577 ModelConfigBuilder::new()
578 }
579}
580
/// Training hyper-parameters, validated by its `StandardConfig` implementation.
///
/// The derived `Default` has `learning_rate == 0.0` and `epochs == 0`, neither of
/// which passes `validate`. Use [`TrainingConfigBuilder`] to get usable defaults.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate; validation requires it to be positive.
    pub learning_rate: f64,
    /// Number of epochs; validation requires it to be greater than 0.
    pub epochs: usize,
    /// Number of warm-up steps (not validated).
    pub warmup_steps: usize,
    /// Weight-decay coefficient (not validated).
    pub weight_decay: f64,
    /// Gradient-clipping value (not validated).
    pub gradient_clipping: f64,
}
590
591impl StandardConfig for TrainingConfig {
592 fn validate(&self) -> Result<()> {
593 if self.learning_rate <= 0.0 {
594 return Err(crate::errors::TrustformersError::invalid_input(
595 "Learning rate must be positive".to_string(),
596 ));
597 }
598 if self.epochs == 0 {
599 return Err(crate::errors::TrustformersError::invalid_input(
600 "Epochs must be greater than 0".to_string(),
601 ));
602 }
603 Ok(())
604 }
605}
606
/// Builder for [`TrainingConfig`].
///
/// Every field is optional. `build` fills unset fields with defaults and then
/// validates the result.
#[derive(Debug, Clone, Default)]
pub struct TrainingConfigBuilder {
    /// Defaults to `1e-4`.
    learning_rate: Option<f64>,
    /// Defaults to 10.
    epochs: Option<usize>,
    /// Defaults to 1000.
    warmup_steps: Option<usize>,
    /// Defaults to 0.01.
    weight_decay: Option<f64>,
    /// Defaults to 1.0.
    gradient_clipping: Option<f64>,
}
616
617impl TrainingConfigBuilder {
618 pub fn new() -> Self {
619 Self::default()
620 }
621
622 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
623 self.learning_rate = Some(learning_rate);
624 self
625 }
626
627 pub fn epochs(mut self, epochs: usize) -> Self {
628 self.epochs = Some(epochs);
629 self
630 }
631
632 pub fn warmup_steps(mut self, warmup_steps: usize) -> Self {
633 self.warmup_steps = Some(warmup_steps);
634 self
635 }
636
637 pub fn weight_decay(mut self, weight_decay: f64) -> Self {
638 self.weight_decay = Some(weight_decay);
639 self
640 }
641
642 pub fn gradient_clipping(mut self, gradient_clipping: f64) -> Self {
643 self.gradient_clipping = Some(gradient_clipping);
644 self
645 }
646}
647
648impl Builder<TrainingConfig> for TrainingConfigBuilder {
650 fn build(self) -> Result<TrainingConfig> {
651 let config = TrainingConfig {
652 learning_rate: self.learning_rate.unwrap_or(1e-4),
653 epochs: self.epochs.unwrap_or(10),
654 warmup_steps: self.warmup_steps.unwrap_or(1000),
655 weight_decay: self.weight_decay.unwrap_or(0.01),
656 gradient_clipping: self.gradient_clipping.unwrap_or(1.0),
657 };
658
659 config.validate()?;
660 Ok(config)
661 }
662
663 fn reset(self) -> Self {
664 Self::default()
665 }
666}
667
/// Unit tests for the builders, configuration traits and builder macros in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain target type shared by the builder tests.
    #[derive(Debug, Clone, Default, PartialEq)]
    struct TestObject {
        name: String,
        value: i32,
        enabled: bool,
    }

    /// Minimal configuration that relies on every `StandardConfig` default method.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    struct TestConfig {
        timeout: u64,
        retries: u32,
    }

    impl StandardConfig for TestConfig {}

    /// Data edited while the builder is incomplete survives `complete()` and `build()`.
    #[test]
    fn test_standard_builder() {
        let mut builder: StandardBuilder<TestObject, BuilderIncomplete> = StandardBuilder::new();
        builder.data_mut().name = "test".to_string();
        builder.data_mut().value = 42;
        builder.data_mut().enabled = true;

        let obj = builder.complete().build().expect("operation failed in test");
        assert_eq!(obj.name, "test");
        assert_eq!(obj.value, 42);
        assert!(obj.enabled);
    }

    /// A failing validator makes `build` return `Err`; once the data satisfies it,
    /// `build` succeeds.
    #[test]
    fn test_validated_builder() {
        // The default `TestObject` has an empty name, so this validator rejects it.
        let builder = ValidatedBuilder::new().add_validator(|obj: &TestObject| {
            if obj.name.is_empty() {
                return Err(anyhow::anyhow!("Name cannot be empty").into());
            }
            Ok(())
        });

        let result = builder.build();
        assert!(result.is_err());

        let mut builder = ValidatedBuilder::new().add_validator(|obj: &TestObject| {
            if obj.name.is_empty() {
                return Err(anyhow::anyhow!("Name cannot be empty").into());
            }
            Ok(())
        });

        // Same validator, but the data now satisfies it.
        builder.data_mut().name = "test".to_string();
        let result = builder.build();
        assert!(result.is_ok());
    }

    /// `ConfigBuilderImpl` accepts a config, metadata and a target, then builds.
    #[test]
    fn test_config_builder() {
        let config = TestConfig {
            timeout: 5000,
            retries: 3,
        };

        let builder = ConfigBuilderImpl::new()
            .config(config)
            .name("test_config")
            .description("A test configuration")
            .tag("test")
            .target(TestObject::default());

        let result = builder.build();
        assert!(result.is_ok());
    }

    /// JSON round-trip through the blanket `ConfigSerializable` impl.
    #[test]
    fn test_config_serialization() {
        let config = TestConfig {
            timeout: 5000,
            retries: 3,
        };

        let json = config.to_json().expect("operation failed in test");
        let deserialized = TestConfig::from_json(&json).expect("operation failed in test");

        assert_eq!(config.timeout, deserialized.timeout);
        assert_eq!(config.retries, deserialized.retries);
    }

    // Generates `TestObjectBuilder` with one optional field per listed `TestObject` field.
    quick_builder!(TestObjectBuilder for TestObject {
        name: String,
        value: i32,
        enabled: bool
    });

    /// The generated setters record values on the builder; `build` is not exercised here.
    #[test]
    fn test_quick_builder_creation() {
        let builder = TestObjectBuilder::new().name("test".to_string()).value(42).enabled(true);

        assert!(builder.name.is_some());
        assert!(builder.value.is_some());
        assert!(builder.enabled.is_some());
    }

    /// Explicit values given to `ModelConfig::builder()` appear unchanged in the result.
    #[test]
    fn test_model_config_builder() {
        let config = ModelConfig::builder()
            .name("test-model")
            .model_type("gpt")
            .max_length(1024)
            .batch_size(4)
            .temperature(0.7)
            .top_p(0.9)
            .build()
            .expect("operation failed in test");

        assert_eq!(config.name, "test-model");
        assert_eq!(config.model_type, "gpt");
        assert_eq!(config.max_length, 1024);
        assert_eq!(config.batch_size, 4);
        assert_eq!(config.temperature, 0.7);
        assert_eq!(config.top_p, 0.9);
    }

    /// Out-of-range `temperature` / `top_p` are rejected; in-range values pass.
    #[test]
    fn test_model_config_builder_validation() {
        let result = ModelConfig::builder()
            .name("test")
            .temperature(3.0)
            .build();
        assert!(result.is_err());

        let result = ModelConfig::builder()
            .name("test")
            .top_p(1.5)
            .build();
        assert!(result.is_err());

        let result = ModelConfig::builder().name("test").temperature(0.8).top_p(0.9).build();
        assert!(result.is_ok());
    }

    /// Explicit values given to `TrainingConfigBuilder` appear unchanged in the result.
    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfigBuilder::new()
            .learning_rate(1e-3)
            .epochs(5)
            .warmup_steps(500)
            .weight_decay(0.001)
            .gradient_clipping(0.5)
            .build()
            .expect("operation failed in test");

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.epochs, 5);
        assert_eq!(config.warmup_steps, 500);
        assert_eq!(config.weight_decay, 0.001);
        assert_eq!(config.gradient_clipping, 0.5);
    }

    /// An empty `TrainingConfigBuilder` produces the documented defaults.
    #[test]
    fn test_training_config_builder_defaults() {
        let config = TrainingConfigBuilder::new().build().expect("operation failed in test");

        assert_eq!(config.learning_rate, 1e-4);
        assert_eq!(config.epochs, 10);
        assert_eq!(config.warmup_steps, 1000);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.gradient_clipping, 1.0);
    }

    /// A negative learning rate and zero epochs are both rejected by `build`.
    #[test]
    fn test_training_config_validation() {
        let result = TrainingConfigBuilder::new()
            .learning_rate(-0.1)
            .build();
        assert!(result.is_err());

        let result = TrainingConfigBuilder::new()
            .epochs(0)
            .build();
        assert!(result.is_err());
    }
}
860}