// zzstat/transform.rs

//! Stat transforms.
//!
//! A transform post-processes a stat's value after its sources have been
//! collected. Transforms may read other stats (dependencies), but must
//! declare them explicitly via `depends_on()`.

use crate::context::StatContext;
use crate::error::StatError;
use crate::stat_id::StatId;
use std::collections::HashMap;
use std::fmt;

12/// Trait for stat transforms that modify stat values.
13///
14/// Transforms can read other stats (dependencies) and must declare
15/// them explicitly. The resolver ensures dependencies are resolved
16/// before applying the transform.
17///
18/// # Examples
19///
20/// ```rust
21/// use zzstat::transform::{StatTransform, MultiplicativeTransform};
22/// use zzstat::{StatContext, StatId};
23/// use std::collections::HashMap;
24///
25/// let transform = MultiplicativeTransform::new(1.5);
26/// let context = StatContext::new();
27/// let deps = HashMap::new();
28///
29/// let result = transform.apply(100.0, &deps, &context).unwrap();
30/// assert_eq!(result, 150.0);
31/// ```
32pub trait StatTransform: Send + Sync {
33    /// Get the list of stat IDs this transform depends on.
34    ///
35    /// These stats must be resolved before this transform can be applied.
36    /// The resolver uses this information to build the dependency graph
37    /// and determine resolution order.
38    ///
39    /// # Returns
40    ///
41    /// A vector of stat IDs that this transform depends on.
42    fn depends_on(&self) -> Vec<StatId>;
43
44    /// Apply the transform to an input value.
45    ///
46    /// # Arguments
47    ///
48    /// * `input` - The current stat value (after sources and previous transforms)
49    /// * `dependencies` - Map of resolved dependency stats (keyed by StatId)
50    /// * `context` - The stat context (for conditional transforms)
51    ///
52    /// # Returns
53    ///
54    /// The transformed value, or an error if the transform cannot be applied.
55    fn apply(
56        &self,
57        input: f64,
58        dependencies: &HashMap<StatId, f64>,
59        context: &StatContext,
60    ) -> Result<f64, StatError>;
61
62    /// Get a human-readable description of this transform.
63    ///
64    /// Used for debugging and breakdown information in `ResolvedStat`.
65    ///
66    /// # Returns
67    ///
68    /// A string describing what this transform does.
69    fn description(&self) -> String;
70}
71
72/// A multiplicative transform (percentage modifier).
73///
74/// Multiplies the input value by a constant factor.
75///
76/// # Examples
77///
78/// ```rust
79/// use zzstat::transform::{StatTransform, MultiplicativeTransform};
80/// use zzstat::StatContext;
81/// use std::collections::HashMap;
82///
83/// let transform = MultiplicativeTransform::new(1.5);
84/// let context = StatContext::new();
85/// let deps = HashMap::new();
86///
87/// // 100 * 1.5 = 150
88/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 150.0);
89/// ```
90#[derive(Debug, Clone)]
91pub struct MultiplicativeTransform {
92    multiplier: f64,
93}
94
95impl MultiplicativeTransform {
96    /// Create a new multiplicative transform.
97    ///
98    /// # Arguments
99    ///
100    /// * `multiplier` - The multiplier to apply (e.g., 1.5 for +50%)
101    ///
102    /// # Examples
103    ///
104    /// ```rust
105    /// use zzstat::transform::MultiplicativeTransform;
106    ///
107    /// // +50% bonus
108    /// let bonus = MultiplicativeTransform::new(1.5);
109    ///
110    /// // -20% penalty
111    /// let penalty = MultiplicativeTransform::new(0.8);
112    /// ```
113    pub fn new(multiplier: f64) -> Self {
114        Self { multiplier }
115    }
116}
117
118impl StatTransform for MultiplicativeTransform {
119    fn depends_on(&self) -> Vec<StatId> {
120        Vec::new()
121    }
122
123    fn apply(
124        &self,
125        input: f64,
126        _dependencies: &HashMap<StatId, f64>,
127        _context: &StatContext,
128    ) -> Result<f64, StatError> {
129        Ok(input * self.multiplier)
130    }
131
132    fn description(&self) -> String {
133        format!("×{:.2}", self.multiplier)
134    }
135}
136
137/// An additive transform (flat bonus).
138///
139/// Adds a constant value to the input.
140///
141/// # Examples
142///
143/// ```rust
144/// use zzstat::transform::{StatTransform, AdditiveTransform};
145/// use zzstat::StatContext;
146/// use std::collections::HashMap;
147///
148/// let transform = AdditiveTransform::new(25.0);
149/// let context = StatContext::new();
150/// let deps = HashMap::new();
151///
152/// // 100 + 25 = 125
153/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 125.0);
154/// ```
155#[derive(Debug, Clone)]
156pub struct AdditiveTransform {
157    bonus: f64,
158}
159
160impl AdditiveTransform {
161    /// Create a new additive transform.
162    ///
163    /// # Arguments
164    ///
165    /// * `bonus` - The flat bonus to add (can be negative for penalties)
166    ///
167    /// # Examples
168    ///
169    /// ```rust
170    /// use zzstat::transform::AdditiveTransform;
171    ///
172    /// // +25 flat bonus
173    /// let bonus = AdditiveTransform::new(25.0);
174    ///
175    /// // -10 flat penalty
176    /// let penalty = AdditiveTransform::new(-10.0);
177    /// ```
178    pub fn new(bonus: f64) -> Self {
179        Self { bonus }
180    }
181}
182
183impl StatTransform for AdditiveTransform {
184    fn depends_on(&self) -> Vec<StatId> {
185        Vec::new()
186    }
187
188    fn apply(
189        &self,
190        input: f64,
191        _dependencies: &HashMap<StatId, f64>,
192        _context: &StatContext,
193    ) -> Result<f64, StatError> {
194        Ok(input + self.bonus)
195    }
196
197    fn description(&self) -> String {
198        format!("+{:.2}", self.bonus)
199    }
200}
201
202/// A clamp transform that restricts values to a range.
203///
204/// Ensures the output value is between `min` and `max` (inclusive).
205///
206/// # Examples
207///
208/// ```rust
209/// use zzstat::transform::{StatTransform, ClampTransform};
210/// use zzstat::StatContext;
211/// use std::collections::HashMap;
212///
213/// let transform = ClampTransform::new(0.0, 100.0);
214/// let context = StatContext::new();
215/// let deps = HashMap::new();
216///
217/// assert_eq!(transform.apply(150.0, &deps, &context).unwrap(), 100.0);
218/// assert_eq!(transform.apply(-10.0, &deps, &context).unwrap(), 0.0);
219/// assert_eq!(transform.apply(50.0, &deps, &context).unwrap(), 50.0);
220/// ```
221#[derive(Debug, Clone)]
222pub struct ClampTransform {
223    min: f64,
224    max: f64,
225}
226
227impl ClampTransform {
228    /// Create a new clamp transform.
229    ///
230    /// # Arguments
231    ///
232    /// * `min` - Minimum allowed value (inclusive)
233    /// * `max` - Maximum allowed value (inclusive)
234    ///
235    /// # Panics
236    ///
237    /// This function does not panic, but if `min > max`, the behavior
238    /// is undefined (values will never pass the clamp).
239    ///
240    /// # Examples
241    ///
242    /// ```rust
243    /// use zzstat::transform::ClampTransform;
244    ///
245    /// // Clamp between 0 and 100
246    /// let clamp = ClampTransform::new(0.0, 100.0);
247    /// ```
248    pub fn new(min: f64, max: f64) -> Self {
249        Self { min, max }
250    }
251}
252
253impl StatTransform for ClampTransform {
254    fn depends_on(&self) -> Vec<StatId> {
255        Vec::new()
256    }
257
258    fn apply(
259        &self,
260        input: f64,
261        _dependencies: &HashMap<StatId, f64>,
262        _context: &StatContext,
263    ) -> Result<f64, StatError> {
264        Ok(input.clamp(self.min, self.max))
265    }
266
267    fn description(&self) -> String {
268        format!("clamp({:.2}, {:.2})", self.min, self.max)
269    }
270}
271
272/// A conditional transform that applies another transform based on a condition.
273///
274/// Only applies the inner transform if the condition function returns `true`
275/// when called with the current `StatContext`. Otherwise, returns the input
276/// value unchanged.
277///
278/// # Examples
279///
280/// ```rust
281/// use zzstat::transform::{StatTransform, ConditionalTransform, MultiplicativeTransform};
282/// use zzstat::StatContext;
283/// use std::collections::HashMap;
284///
285/// let mut context = StatContext::new();
286/// context.set("in_combat", true);
287///
288/// let inner_transform = Box::new(MultiplicativeTransform::new(1.2));
289/// let transform = ConditionalTransform::new(
290///     |ctx| ctx.get::<bool>("in_combat").unwrap_or(false),
291///     inner_transform,
292///     "combat bonus",
293/// );
294///
295/// let deps = HashMap::new();
296/// // In combat: 100 * 1.2 = 120
297/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 120.0);
298///
299/// context.set("in_combat", false);
300/// // Out of combat: 100 (unchanged)
301/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 100.0);
302/// ```
303pub struct ConditionalTransform {
304    condition: Box<dyn Fn(&StatContext) -> bool + Send + Sync>,
305    transform: Box<dyn StatTransform>,
306    description: String,
307}
308
309impl ConditionalTransform {
310    /// Create a new conditional transform.
311    ///
312    /// # Arguments
313    ///
314    /// * `condition` - A function that takes `&StatContext` and returns `bool`
315    /// * `transform` - The transform to apply when condition is `true`
316    /// * `description` - Human-readable description for debugging
317    ///
318    /// # Examples
319    ///
320    /// ```rust
321    /// use zzstat::transform::{ConditionalTransform, MultiplicativeTransform};
322    ///
323    /// let inner = Box::new(MultiplicativeTransform::new(1.5));
324    /// let transform = ConditionalTransform::new(
325    ///     |ctx| ctx.get::<bool>("in_combat").unwrap_or(false),
326    ///     inner,
327    ///     "combat bonus +50%",
328    /// );
329    /// ```
330    pub fn new<F>(
331        condition: F,
332        transform: Box<dyn StatTransform>,
333        description: impl Into<String>,
334    ) -> Self
335    where
336        F: Fn(&StatContext) -> bool + Send + Sync + 'static,
337    {
338        Self {
339            condition: Box::new(condition),
340            transform,
341            description: description.into(),
342        }
343    }
344}
345
346impl StatTransform for ConditionalTransform {
347    fn depends_on(&self) -> Vec<StatId> {
348        self.transform.depends_on()
349    }
350
351    fn apply(
352        &self,
353        input: f64,
354        dependencies: &HashMap<StatId, f64>,
355        context: &StatContext,
356    ) -> Result<f64, StatError> {
357        if (self.condition)(context) {
358            self.transform.apply(input, dependencies, context)
359        } else {
360            Ok(input)
361        }
362    }
363
364    fn description(&self) -> String {
365        self.description.clone()
366    }
367}
368
369/// A transform that scales based on another stat.
370///
371/// Adds `dependency_value * scale_factor` to the input value.
372/// This is commonly used for derived stats (e.g., ATK = base + STR * 2).
373///
374/// # Examples
375///
376/// ```rust
377/// use zzstat::transform::{StatTransform, ScalingTransform};
378/// use zzstat::{StatId, StatContext};
379/// use std::collections::HashMap;
380///
381/// let str_id = StatId::from_str("STR");
382/// let transform = ScalingTransform::new(str_id.clone(), 2.0);
383///
384/// let mut deps = HashMap::new();
385/// deps.insert(str_id.clone(), 10.0);
386///
387/// let context = StatContext::new();
388/// // 100 (base) + 10 (STR) * 2 = 120
389/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 120.0);
390/// ```
391#[derive(Debug, Clone)]
392pub struct ScalingTransform {
393    dependency: StatId,
394    scale_factor: f64,
395}
396
397impl ScalingTransform {
398    /// Create a new scaling transform.
399    ///
400    /// # Arguments
401    ///
402    /// * `dependency` - The stat ID this transform depends on
403    /// * `scale_factor` - The multiplier to apply to the dependency value
404    ///
405    /// # Examples
406    ///
407    /// ```rust
408    /// use zzstat::transform::ScalingTransform;
409    /// use zzstat::StatId;
410    ///
411    /// let str_id = StatId::from_str("STR");
412    /// // ATK scales with STR: ATK = base + STR * 2
413    /// let transform = ScalingTransform::new(str_id, 2.0);
414    /// ```
415    pub fn new(dependency: StatId, scale_factor: f64) -> Self {
416        Self {
417            dependency,
418            scale_factor,
419        }
420    }
421}
422
423impl StatTransform for ScalingTransform {
424    fn depends_on(&self) -> Vec<StatId> {
425        vec![self.dependency.clone()]
426    }
427
428    fn apply(
429        &self,
430        input: f64,
431        dependencies: &HashMap<StatId, f64>,
432        _context: &StatContext,
433    ) -> Result<f64, StatError> {
434        let dep_value = dependencies
435            .get(&self.dependency)
436            .ok_or_else(|| StatError::MissingDependency(self.dependency.clone()))?;
437        Ok(input + (dep_value * self.scale_factor))
438    }
439
440    fn description(&self) -> String {
441        format!("scale({}, {:.2})", self.dependency, self.scale_factor)
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `transform` to `input` with no dependencies and a fresh context.
    fn run_plain(transform: &dyn StatTransform, input: f64) -> f64 {
        let ctx = StatContext::new();
        let no_deps = HashMap::new();
        transform.apply(input, &no_deps, &ctx).unwrap()
    }

    #[test]
    fn test_multiplicative_transform() {
        assert_eq!(run_plain(&MultiplicativeTransform::new(1.5), 100.0), 150.0);
    }

    #[test]
    fn test_additive_transform() {
        assert_eq!(run_plain(&AdditiveTransform::new(25.0), 100.0), 125.0);
    }

    #[test]
    fn test_clamp_transform() {
        let clamp = ClampTransform::new(0.0, 100.0);
        for &(input, expected) in &[(150.0, 100.0), (-10.0, 0.0), (50.0, 50.0)] {
            assert_eq!(run_plain(&clamp, input), expected);
        }
    }

    #[test]
    fn test_scaling_transform() {
        let strength = StatId::from_str("STR");
        let scaling = ScalingTransform::new(strength.clone(), 2.0);
        let ctx = StatContext::new();
        let mut resolved = HashMap::new();
        resolved.insert(strength.clone(), 10.0);

        assert_eq!(scaling.depends_on(), vec![strength]);
        assert_eq!(scaling.apply(100.0, &resolved, &ctx).unwrap(), 120.0);
    }

    #[test]
    fn test_scaling_transform_missing_dependency() {
        let scaling = ScalingTransform::new(StatId::from_str("STR"), 2.0);
        let ctx = StatContext::new();
        let empty = HashMap::new();

        let outcome = scaling.apply(100.0, &empty, &ctx);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_conditional_transform() {
        let mut ctx = StatContext::new();
        ctx.set("in_combat", true);

        let combat_bonus = ConditionalTransform::new(
            |c| c.get::<bool>("in_combat").unwrap_or(false),
            Box::new(MultiplicativeTransform::new(1.2)),
            "combat bonus",
        );
        let no_deps = HashMap::new();

        // Condition holds: the inner multiplier is applied.
        assert_eq!(combat_bonus.apply(100.0, &no_deps, &ctx).unwrap(), 120.0);

        // Condition fails: the input passes through unchanged.
        ctx.set("in_combat", false);
        assert_eq!(combat_bonus.apply(100.0, &no_deps, &ctx).unwrap(), 100.0);
    }
}