zzstat/transform.rs
//! Stat transforms module.
//!
//! Transforms modify a stat's value after its sources have been collected.
//! A transform may read other stats (its dependencies), and it must declare
//! them explicitly via `depends_on()`.
6
7use crate::context::StatContext;
8use crate::error::StatError;
9use crate::stat_id::StatId;
10use std::collections::HashMap;
11
12/// Trait for stat transforms that modify stat values.
13///
14/// Transforms can read other stats (dependencies) and must declare
15/// them explicitly. The resolver ensures dependencies are resolved
16/// before applying the transform.
17///
18/// # Examples
19///
20/// ```rust
21/// use zzstat::transform::{StatTransform, MultiplicativeTransform};
22/// use zzstat::{StatContext, StatId};
23/// use std::collections::HashMap;
24///
25/// let transform = MultiplicativeTransform::new(1.5);
26/// let context = StatContext::new();
27/// let deps = HashMap::new();
28///
29/// let result = transform.apply(100.0, &deps, &context).unwrap();
30/// assert_eq!(result, 150.0);
31/// ```
32pub trait StatTransform: Send + Sync {
33 /// Get the list of stat IDs this transform depends on.
34 ///
35 /// These stats must be resolved before this transform can be applied.
36 /// The resolver uses this information to build the dependency graph
37 /// and determine resolution order.
38 ///
39 /// # Returns
40 ///
41 /// A vector of stat IDs that this transform depends on.
42 fn depends_on(&self) -> Vec<StatId>;
43
44 /// Apply the transform to an input value.
45 ///
46 /// # Arguments
47 ///
48 /// * `input` - The current stat value (after sources and previous transforms)
49 /// * `dependencies` - Map of resolved dependency stats (keyed by StatId)
50 /// * `context` - The stat context (for conditional transforms)
51 ///
52 /// # Returns
53 ///
54 /// The transformed value, or an error if the transform cannot be applied.
55 fn apply(
56 &self,
57 input: f64,
58 dependencies: &HashMap<StatId, f64>,
59 context: &StatContext,
60 ) -> Result<f64, StatError>;
61
62 /// Get a human-readable description of this transform.
63 ///
64 /// Used for debugging and breakdown information in `ResolvedStat`.
65 ///
66 /// # Returns
67 ///
68 /// A string describing what this transform does.
69 fn description(&self) -> String;
70}
71
72/// A multiplicative transform (percentage modifier).
73///
74/// Multiplies the input value by a constant factor.
75///
76/// # Examples
77///
78/// ```rust
79/// use zzstat::transform::{StatTransform, MultiplicativeTransform};
80/// use zzstat::StatContext;
81/// use std::collections::HashMap;
82///
83/// let transform = MultiplicativeTransform::new(1.5);
84/// let context = StatContext::new();
85/// let deps = HashMap::new();
86///
87/// // 100 * 1.5 = 150
88/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 150.0);
89/// ```
90#[derive(Debug, Clone)]
91pub struct MultiplicativeTransform {
92 multiplier: f64,
93}
94
95impl MultiplicativeTransform {
96 /// Create a new multiplicative transform.
97 ///
98 /// # Arguments
99 ///
100 /// * `multiplier` - The multiplier to apply (e.g., 1.5 for +50%)
101 ///
102 /// # Examples
103 ///
104 /// ```rust
105 /// use zzstat::transform::MultiplicativeTransform;
106 ///
107 /// // +50% bonus
108 /// let bonus = MultiplicativeTransform::new(1.5);
109 ///
110 /// // -20% penalty
111 /// let penalty = MultiplicativeTransform::new(0.8);
112 /// ```
113 pub fn new(multiplier: f64) -> Self {
114 Self { multiplier }
115 }
116}
117
118impl StatTransform for MultiplicativeTransform {
119 fn depends_on(&self) -> Vec<StatId> {
120 Vec::new()
121 }
122
123 fn apply(
124 &self,
125 input: f64,
126 _dependencies: &HashMap<StatId, f64>,
127 _context: &StatContext,
128 ) -> Result<f64, StatError> {
129 Ok(input * self.multiplier)
130 }
131
132 fn description(&self) -> String {
133 format!("×{:.2}", self.multiplier)
134 }
135}
136
137/// An additive transform (flat bonus).
138///
139/// Adds a constant value to the input.
140///
141/// # Examples
142///
143/// ```rust
144/// use zzstat::transform::{StatTransform, AdditiveTransform};
145/// use zzstat::StatContext;
146/// use std::collections::HashMap;
147///
148/// let transform = AdditiveTransform::new(25.0);
149/// let context = StatContext::new();
150/// let deps = HashMap::new();
151///
152/// // 100 + 25 = 125
153/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 125.0);
154/// ```
155#[derive(Debug, Clone)]
156pub struct AdditiveTransform {
157 bonus: f64,
158}
159
160impl AdditiveTransform {
161 /// Create a new additive transform.
162 ///
163 /// # Arguments
164 ///
165 /// * `bonus` - The flat bonus to add (can be negative for penalties)
166 ///
167 /// # Examples
168 ///
169 /// ```rust
170 /// use zzstat::transform::AdditiveTransform;
171 ///
172 /// // +25 flat bonus
173 /// let bonus = AdditiveTransform::new(25.0);
174 ///
175 /// // -10 flat penalty
176 /// let penalty = AdditiveTransform::new(-10.0);
177 /// ```
178 pub fn new(bonus: f64) -> Self {
179 Self { bonus }
180 }
181}
182
183impl StatTransform for AdditiveTransform {
184 fn depends_on(&self) -> Vec<StatId> {
185 Vec::new()
186 }
187
188 fn apply(
189 &self,
190 input: f64,
191 _dependencies: &HashMap<StatId, f64>,
192 _context: &StatContext,
193 ) -> Result<f64, StatError> {
194 Ok(input + self.bonus)
195 }
196
197 fn description(&self) -> String {
198 format!("+{:.2}", self.bonus)
199 }
200}
201
202/// A clamp transform that restricts values to a range.
203///
204/// Ensures the output value is between `min` and `max` (inclusive).
205///
206/// # Examples
207///
208/// ```rust
209/// use zzstat::transform::{StatTransform, ClampTransform};
210/// use zzstat::StatContext;
211/// use std::collections::HashMap;
212///
213/// let transform = ClampTransform::new(0.0, 100.0);
214/// let context = StatContext::new();
215/// let deps = HashMap::new();
216///
217/// assert_eq!(transform.apply(150.0, &deps, &context).unwrap(), 100.0);
218/// assert_eq!(transform.apply(-10.0, &deps, &context).unwrap(), 0.0);
219/// assert_eq!(transform.apply(50.0, &deps, &context).unwrap(), 50.0);
220/// ```
221#[derive(Debug, Clone)]
222pub struct ClampTransform {
223 min: f64,
224 max: f64,
225}
226
227impl ClampTransform {
228 /// Create a new clamp transform.
229 ///
230 /// # Arguments
231 ///
232 /// * `min` - Minimum allowed value (inclusive)
233 /// * `max` - Maximum allowed value (inclusive)
234 ///
235 /// # Panics
236 ///
237 /// This function does not panic, but if `min > max`, the behavior
238 /// is undefined (values will never pass the clamp).
239 ///
240 /// # Examples
241 ///
242 /// ```rust
243 /// use zzstat::transform::ClampTransform;
244 ///
245 /// // Clamp between 0 and 100
246 /// let clamp = ClampTransform::new(0.0, 100.0);
247 /// ```
248 pub fn new(min: f64, max: f64) -> Self {
249 Self { min, max }
250 }
251}
252
253impl StatTransform for ClampTransform {
254 fn depends_on(&self) -> Vec<StatId> {
255 Vec::new()
256 }
257
258 fn apply(
259 &self,
260 input: f64,
261 _dependencies: &HashMap<StatId, f64>,
262 _context: &StatContext,
263 ) -> Result<f64, StatError> {
264 Ok(input.clamp(self.min, self.max))
265 }
266
267 fn description(&self) -> String {
268 format!("clamp({:.2}, {:.2})", self.min, self.max)
269 }
270}
271
272/// A conditional transform that applies another transform based on a condition.
273///
274/// Only applies the inner transform if the condition function returns `true`
275/// when called with the current `StatContext`. Otherwise, returns the input
276/// value unchanged.
277///
278/// # Examples
279///
280/// ```rust
281/// use zzstat::transform::{StatTransform, ConditionalTransform, MultiplicativeTransform};
282/// use zzstat::StatContext;
283/// use std::collections::HashMap;
284///
285/// let mut context = StatContext::new();
286/// context.set("in_combat", true);
287///
288/// let inner_transform = Box::new(MultiplicativeTransform::new(1.2));
289/// let transform = ConditionalTransform::new(
290/// |ctx| ctx.get::<bool>("in_combat").unwrap_or(false),
291/// inner_transform,
292/// "combat bonus",
293/// );
294///
295/// let deps = HashMap::new();
296/// // In combat: 100 * 1.2 = 120
297/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 120.0);
298///
299/// context.set("in_combat", false);
300/// // Out of combat: 100 (unchanged)
301/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 100.0);
302/// ```
303pub struct ConditionalTransform {
304 condition: Box<dyn Fn(&StatContext) -> bool + Send + Sync>,
305 transform: Box<dyn StatTransform>,
306 description: String,
307}
308
309impl ConditionalTransform {
310 /// Create a new conditional transform.
311 ///
312 /// # Arguments
313 ///
314 /// * `condition` - A function that takes `&StatContext` and returns `bool`
315 /// * `transform` - The transform to apply when condition is `true`
316 /// * `description` - Human-readable description for debugging
317 ///
318 /// # Examples
319 ///
320 /// ```rust
321 /// use zzstat::transform::{ConditionalTransform, MultiplicativeTransform};
322 ///
323 /// let inner = Box::new(MultiplicativeTransform::new(1.5));
324 /// let transform = ConditionalTransform::new(
325 /// |ctx| ctx.get::<bool>("in_combat").unwrap_or(false),
326 /// inner,
327 /// "combat bonus +50%",
328 /// );
329 /// ```
330 pub fn new<F>(
331 condition: F,
332 transform: Box<dyn StatTransform>,
333 description: impl Into<String>,
334 ) -> Self
335 where
336 F: Fn(&StatContext) -> bool + Send + Sync + 'static,
337 {
338 Self {
339 condition: Box::new(condition),
340 transform,
341 description: description.into(),
342 }
343 }
344}
345
346impl StatTransform for ConditionalTransform {
347 fn depends_on(&self) -> Vec<StatId> {
348 self.transform.depends_on()
349 }
350
351 fn apply(
352 &self,
353 input: f64,
354 dependencies: &HashMap<StatId, f64>,
355 context: &StatContext,
356 ) -> Result<f64, StatError> {
357 if (self.condition)(context) {
358 self.transform.apply(input, dependencies, context)
359 } else {
360 Ok(input)
361 }
362 }
363
364 fn description(&self) -> String {
365 self.description.clone()
366 }
367}
368
369/// A transform that scales based on another stat.
370///
371/// Adds `dependency_value * scale_factor` to the input value.
372/// This is commonly used for derived stats (e.g., ATK = base + STR * 2).
373///
374/// # Examples
375///
376/// ```rust
377/// use zzstat::transform::{StatTransform, ScalingTransform};
378/// use zzstat::{StatId, StatContext};
379/// use std::collections::HashMap;
380///
381/// let str_id = StatId::from_str("STR");
382/// let transform = ScalingTransform::new(str_id.clone(), 2.0);
383///
384/// let mut deps = HashMap::new();
385/// deps.insert(str_id.clone(), 10.0);
386///
387/// let context = StatContext::new();
388/// // 100 (base) + 10 (STR) * 2 = 120
389/// assert_eq!(transform.apply(100.0, &deps, &context).unwrap(), 120.0);
390/// ```
391#[derive(Debug, Clone)]
392pub struct ScalingTransform {
393 dependency: StatId,
394 scale_factor: f64,
395}
396
397impl ScalingTransform {
398 /// Create a new scaling transform.
399 ///
400 /// # Arguments
401 ///
402 /// * `dependency` - The stat ID this transform depends on
403 /// * `scale_factor` - The multiplier to apply to the dependency value
404 ///
405 /// # Examples
406 ///
407 /// ```rust
408 /// use zzstat::transform::ScalingTransform;
409 /// use zzstat::StatId;
410 ///
411 /// let str_id = StatId::from_str("STR");
412 /// // ATK scales with STR: ATK = base + STR * 2
413 /// let transform = ScalingTransform::new(str_id, 2.0);
414 /// ```
415 pub fn new(dependency: StatId, scale_factor: f64) -> Self {
416 Self {
417 dependency,
418 scale_factor,
419 }
420 }
421}
422
423impl StatTransform for ScalingTransform {
424 fn depends_on(&self) -> Vec<StatId> {
425 vec![self.dependency.clone()]
426 }
427
428 fn apply(
429 &self,
430 input: f64,
431 dependencies: &HashMap<StatId, f64>,
432 _context: &StatContext,
433 ) -> Result<f64, StatError> {
434 let dep_value = dependencies
435 .get(&self.dependency)
436 .ok_or_else(|| StatError::MissingDependency(self.dependency.clone()))?;
437 Ok(input + (dep_value * self.scale_factor))
438 }
439
440 fn description(&self) -> String {
441 format!("scale({}, {:.2})", self.dependency, self.scale_factor)
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `transform` to `input` with no dependencies and a fresh context.
    fn run_without_deps(transform: &dyn StatTransform, input: f64) -> Result<f64, StatError> {
        transform.apply(input, &HashMap::new(), &StatContext::new())
    }

    #[test]
    fn test_multiplicative_transform() {
        let boosted = run_without_deps(&MultiplicativeTransform::new(1.5), 100.0).unwrap();
        assert_eq!(boosted, 150.0);
    }

    #[test]
    fn test_additive_transform() {
        let bonus = AdditiveTransform::new(25.0);
        assert_eq!(run_without_deps(&bonus, 100.0).unwrap(), 125.0);
    }

    #[test]
    fn test_clamp_transform() {
        let clamp = ClampTransform::new(0.0, 100.0);
        // (input, expected output) pairs: above, below and inside the range.
        for &(input, expected) in &[(150.0, 100.0), (-10.0, 0.0), (50.0, 50.0)] {
            assert_eq!(run_without_deps(&clamp, input).unwrap(), expected);
        }
    }

    #[test]
    fn test_scaling_transform() {
        let strength = StatId::from_str("STR");
        let scaling = ScalingTransform::new(strength.clone(), 2.0);
        let resolved: HashMap<StatId, f64> =
            std::iter::once((strength.clone(), 10.0)).collect();

        assert_eq!(scaling.depends_on(), vec![strength]);
        let total = scaling
            .apply(100.0, &resolved, &StatContext::new())
            .unwrap();
        assert_eq!(total, 120.0);
    }

    #[test]
    fn test_scaling_transform_missing_dependency() {
        let scaling = ScalingTransform::new(StatId::from_str("STR"), 2.0);
        // No value for STR is supplied, so the lookup must fail.
        assert!(run_without_deps(&scaling, 100.0).is_err());
    }

    #[test]
    fn test_conditional_transform() {
        let mut ctx = StatContext::new();
        ctx.set("in_combat", true);

        let combat_bonus = ConditionalTransform::new(
            |c| c.get::<bool>("in_combat").unwrap_or(false),
            Box::new(MultiplicativeTransform::new(1.2)),
            "combat bonus",
        );
        let no_deps = HashMap::new();

        assert_eq!(combat_bonus.apply(100.0, &no_deps, &ctx).unwrap(), 120.0);

        ctx.set("in_combat", false);
        assert_eq!(combat_bonus.apply(100.0, &no_deps, &ctx).unwrap(), 100.0);
    }
}