solverforge_scoring/api/
weight_overrides.rs

1//! Runtime constraint weight configuration.
2//!
3//! Allows dynamic adjustment of constraint weights without recompiling.
4
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::Arc;
8
9use solverforge_core::Score;
10
/// Runtime overrides for constraint weights, keyed by constraint name.
///
/// Use this to adjust constraint weights without recompiling. Weights can be
/// changed between solver runs or even during solving (if you rebuild constraints).
///
/// The type parameter `Sc` is the weight type stored for each constraint. The
/// `Score` requirement is stated on the `impl` blocks that need it rather than
/// on the struct itself, so the data type carries no bound of its own.
///
/// # Example
///
/// ```
/// use solverforge_scoring::ConstraintWeightOverrides;
/// use solverforge_core::score::HardSoftScore;
///
/// let mut overrides = ConstraintWeightOverrides::<HardSoftScore>::new();
///
/// // Override specific constraint weights
/// overrides.put("room_conflict", HardSoftScore::of_hard(2));
/// overrides.put("preferred_room", HardSoftScore::of_soft(5));
///
/// // Get weight with fallback to default
/// let weight = overrides.get_or_default(
///     "room_conflict",
///     HardSoftScore::of_hard(1), // default if not overridden
/// );
/// assert_eq!(weight, HardSoftScore::of_hard(2));
///
/// // Non-overridden constraint uses default
/// let other = overrides.get_or_default(
///     "other_constraint",
///     HardSoftScore::of_soft(10),
/// );
/// assert_eq!(other, HardSoftScore::of_soft(10));
/// ```
#[derive(Clone)]
pub struct ConstraintWeightOverrides<Sc> {
    /// Override weight for each constraint name that has one.
    weights: HashMap<String, Sc>,
}
46
47impl<Sc: Score> Debug for ConstraintWeightOverrides<Sc> {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("ConstraintWeightOverrides")
50            .field("count", &self.weights.len())
51            .finish()
52    }
53}
54
55impl<Sc: Score> Default for ConstraintWeightOverrides<Sc> {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl<Sc: Score> ConstraintWeightOverrides<Sc> {
62    /// Creates an empty overrides container.
63    pub fn new() -> Self {
64        Self {
65            weights: HashMap::new(),
66        }
67    }
68
69    /// Creates overrides from an iterator of (name, weight) pairs.
70    ///
71    /// # Example
72    ///
73    /// ```
74    /// use solverforge_scoring::ConstraintWeightOverrides;
75    /// use solverforge_core::score::HardSoftScore;
76    ///
77    /// let overrides = ConstraintWeightOverrides::from_pairs([
78    ///     ("hard_constraint", HardSoftScore::of_hard(1)),
79    ///     ("soft_constraint", HardSoftScore::of_soft(10)),
80    /// ]);
81    /// assert_eq!(overrides.len(), 2);
82    /// ```
83    pub fn from_pairs<I, N>(iter: I) -> Self
84    where
85        I: IntoIterator<Item = (N, Sc)>,
86        N: Into<String>,
87    {
88        let weights = iter.into_iter().map(|(n, w)| (n.into(), w)).collect();
89        Self { weights }
90    }
91
92    /// Sets the weight for a constraint.
93    pub fn put<N: Into<String>>(&mut self, name: N, weight: Sc) {
94        self.weights.insert(name.into(), weight);
95    }
96
97    /// Removes the override for a constraint.
98    pub fn remove(&mut self, name: &str) -> Option<Sc> {
99        self.weights.remove(name)
100    }
101
102    /// Gets the overridden weight, or returns the default if not overridden.
103    pub fn get_or_default(&self, name: &str, default: Sc) -> Sc {
104        self.weights.get(name).cloned().unwrap_or(default)
105    }
106
107    /// Gets the overridden weight if present.
108    pub fn get(&self, name: &str) -> Option<&Sc> {
109        self.weights.get(name)
110    }
111
112    /// Returns true if this constraint has an override.
113    pub fn contains(&self, name: &str) -> bool {
114        self.weights.contains_key(name)
115    }
116
117    /// Returns the number of overrides.
118    pub fn len(&self) -> usize {
119        self.weights.len()
120    }
121
122    /// Returns true if there are no overrides.
123    pub fn is_empty(&self) -> bool {
124        self.weights.is_empty()
125    }
126
127    /// Clears all overrides.
128    pub fn clear(&mut self) {
129        self.weights.clear();
130    }
131
132    /// Creates an `Arc`-wrapped version for sharing across threads.
133    pub fn into_arc(self) -> Arc<Self> {
134        Arc::new(self)
135    }
136}
137
138/// Helper trait for creating weight functions from overrides.
139///
140/// This enables zero-erasure constraint building with runtime weight lookup.
141pub trait WeightProvider<Sc: Score>: Send + Sync {
142    /// Gets the weight for a constraint by name.
143    fn weight(&self, name: &str) -> Option<Sc>;
144
145    /// Gets the weight or returns the default.
146    fn weight_or_default(&self, name: &str, default: Sc) -> Sc {
147        self.weight(name).unwrap_or(default)
148    }
149}
150
151impl<Sc: Score> WeightProvider<Sc> for ConstraintWeightOverrides<Sc> {
152    fn weight(&self, name: &str) -> Option<Sc> {
153        self.get(name).cloned()
154    }
155}
156
157impl<Sc: Score> WeightProvider<Sc> for Arc<ConstraintWeightOverrides<Sc>> {
158    fn weight(&self, name: &str) -> Option<Sc> {
159        self.get(name).cloned()
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::score::{HardSoftScore, SimpleScore};

    /// Builds a container holding the single override `"test" -> SimpleScore::of(5)`.
    fn single_override() -> ConstraintWeightOverrides<SimpleScore> {
        let mut overrides = ConstraintWeightOverrides::new();
        overrides.put("test", SimpleScore::of(5));
        overrides
    }

    // A freshly constructed container holds nothing.
    #[test]
    fn test_new_is_empty() {
        let empty: ConstraintWeightOverrides<SimpleScore> = ConstraintWeightOverrides::new();

        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    // A stored weight is visible through both `contains` and `get`.
    #[test]
    fn test_put_and_get() {
        let overrides = single_override();
        let expected = SimpleScore::of(5);

        assert!(overrides.contains("test"));
        assert_eq!(overrides.get("test"), Some(&expected));
    }

    // An existing override wins over the supplied default.
    #[test]
    fn test_get_or_default_with_override() {
        let overrides = single_override();

        let resolved = overrides.get_or_default("test", SimpleScore::of(1));

        assert_eq!(resolved, SimpleScore::of(5));
    }

    // Without an override the supplied default is returned.
    #[test]
    fn test_get_or_default_without_override() {
        let empty: ConstraintWeightOverrides<SimpleScore> = ConstraintWeightOverrides::new();

        let resolved = empty.get_or_default("test", SimpleScore::of(1));

        assert_eq!(resolved, SimpleScore::of(1));
    }

    // Removing hands back the stored weight and forgets the override.
    #[test]
    fn test_remove() {
        let mut overrides = single_override();

        let taken = overrides.remove("test");

        assert_eq!(taken, Some(SimpleScore::of(5)));
        assert!(!overrides.contains("test"));
    }

    // Every pair passed to `from_pairs` ends up as an override.
    #[test]
    fn test_from_pairs() {
        let pairs = [
            ("hard_constraint", HardSoftScore::of_hard(1)),
            ("soft_constraint", HardSoftScore::of_soft(10)),
        ];

        let overrides = ConstraintWeightOverrides::<HardSoftScore>::from_pairs(pairs);

        assert_eq!(overrides.len(), 2);

        let expected_hard = HardSoftScore::of_hard(1);
        assert_eq!(overrides.get("hard_constraint"), Some(&expected_hard));

        let expected_soft = HardSoftScore::of_soft(10);
        assert_eq!(overrides.get("soft_constraint"), Some(&expected_soft));
    }

    // The container is usable through a `WeightProvider` trait object.
    #[test]
    fn test_weight_provider_trait() {
        let overrides = single_override();
        let provider: &dyn WeightProvider<SimpleScore> = &overrides;

        assert_eq!(provider.weight("test"), Some(SimpleScore::of(5)));
        assert_eq!(provider.weight("other"), None);

        let fallback = provider.weight_or_default("other", SimpleScore::of(1));
        assert_eq!(fallback, SimpleScore::of(1));
    }

    // An `Arc`-wrapped container is usable as a `WeightProvider` as well.
    #[test]
    fn test_arc_weight_provider() {
        let shared = single_override().into_arc();
        let provider: &dyn WeightProvider<SimpleScore> = &shared;

        assert_eq!(provider.weight("test"), Some(SimpleScore::of(5)));
    }
}