solverforge_scoring/api/
weight_overrides.rs1use std::collections::HashMap;
6use std::fmt::Debug;
7use std::sync::Arc;
8
9use solverforge_core::Score;
10
11#[derive(Clone)]
43pub struct ConstraintWeightOverrides<Sc: Score> {
44 weights: HashMap<String, Sc>,
45}
46
47impl<Sc: Score> Debug for ConstraintWeightOverrides<Sc> {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("ConstraintWeightOverrides")
50 .field("count", &self.weights.len())
51 .finish()
52 }
53}
54
55impl<Sc: Score> Default for ConstraintWeightOverrides<Sc> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<Sc: Score> ConstraintWeightOverrides<Sc> {
62 pub fn new() -> Self {
64 Self {
65 weights: HashMap::new(),
66 }
67 }
68
69 pub fn from_pairs<I, N>(iter: I) -> Self
84 where
85 I: IntoIterator<Item = (N, Sc)>,
86 N: Into<String>,
87 {
88 let weights = iter.into_iter().map(|(n, w)| (n.into(), w)).collect();
89 Self { weights }
90 }
91
92 pub fn put<N: Into<String>>(&mut self, name: N, weight: Sc) {
94 self.weights.insert(name.into(), weight);
95 }
96
97 pub fn remove(&mut self, name: &str) -> Option<Sc> {
99 self.weights.remove(name)
100 }
101
102 pub fn get_or_default(&self, name: &str, default: Sc) -> Sc {
104 self.weights.get(name).cloned().unwrap_or(default)
105 }
106
107 pub fn get(&self, name: &str) -> Option<&Sc> {
109 self.weights.get(name)
110 }
111
112 pub fn contains(&self, name: &str) -> bool {
114 self.weights.contains_key(name)
115 }
116
117 pub fn len(&self) -> usize {
119 self.weights.len()
120 }
121
122 pub fn is_empty(&self) -> bool {
124 self.weights.is_empty()
125 }
126
127 pub fn clear(&mut self) {
129 self.weights.clear();
130 }
131
132 pub fn into_arc(self) -> Arc<Self> {
134 Arc::new(self)
135 }
136}
137
138pub trait WeightProvider<Sc: Score>: Send + Sync {
142 fn weight(&self, name: &str) -> Option<Sc>;
144
145 fn weight_or_default(&self, name: &str, default: Sc) -> Sc {
147 self.weight(name).unwrap_or(default)
148 }
149}
150
151impl<Sc: Score> WeightProvider<Sc> for ConstraintWeightOverrides<Sc> {
152 fn weight(&self, name: &str) -> Option<Sc> {
153 self.get(name).cloned()
154 }
155}
156
157impl<Sc: Score> WeightProvider<Sc> for Arc<ConstraintWeightOverrides<Sc>> {
158 fn weight(&self, name: &str) -> Option<Sc> {
159 self.get(name).cloned()
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::score::{HardSoftScore, SimpleScore};

    #[test]
    fn test_new_is_empty() {
        let overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        assert!(overrides.is_empty());
        assert_eq!(overrides.len(), 0);
    }

    #[test]
    fn test_default_is_empty() {
        let overrides = ConstraintWeightOverrides::<SimpleScore>::default();
        assert!(overrides.is_empty());
        assert_eq!(overrides.len(), 0);
    }

    #[test]
    fn test_put_and_get() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));

        assert!(overrides.contains("test"));
        assert_eq!(overrides.get("test"), Some(&SimpleScore::of(5)));
    }

    #[test]
    fn test_put_overwrites_existing() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));
        overrides.put(String::from("test"), SimpleScore::of(7));

        assert_eq!(overrides.len(), 1);
        assert_eq!(overrides.get("test"), Some(&SimpleScore::of(7)));
    }

    #[test]
    fn test_get_or_default_with_override() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));

        let weight = overrides.get_or_default("test", SimpleScore::of(1));
        assert_eq!(weight, SimpleScore::of(5));
    }

    #[test]
    fn test_get_or_default_without_override() {
        let overrides = ConstraintWeightOverrides::<SimpleScore>::new();

        let weight = overrides.get_or_default("test", SimpleScore::of(1));
        assert_eq!(weight, SimpleScore::of(1));
    }

    #[test]
    fn test_remove() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));

        let removed = overrides.remove("test");
        assert_eq!(removed, Some(SimpleScore::of(5)));
        assert!(!overrides.contains("test"));
    }

    #[test]
    fn test_remove_missing_returns_none() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        assert_eq!(overrides.remove("missing"), None);
        assert!(overrides.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::from_pairs([
            ("a", SimpleScore::of(1)),
            ("b", SimpleScore::of(2)),
        ]);
        assert_eq!(overrides.len(), 2);

        overrides.clear();
        assert!(overrides.is_empty());
        assert!(!overrides.contains("a"));
        assert_eq!(overrides.get("b"), None);
    }

    #[test]
    fn test_from_pairs() {
        let overrides = ConstraintWeightOverrides::<HardSoftScore>::from_pairs([
            ("hard_constraint", HardSoftScore::of_hard(1)),
            ("soft_constraint", HardSoftScore::of_soft(10)),
        ]);

        assert_eq!(overrides.len(), 2);
        assert_eq!(
            overrides.get("hard_constraint"),
            Some(&HardSoftScore::of_hard(1))
        );
        assert_eq!(
            overrides.get("soft_constraint"),
            Some(&HardSoftScore::of_soft(10))
        );
    }

    #[test]
    fn test_from_pairs_last_duplicate_wins() {
        let overrides = ConstraintWeightOverrides::<SimpleScore>::from_pairs([
            ("dup", SimpleScore::of(1)),
            ("dup", SimpleScore::of(2)),
        ]);

        assert_eq!(overrides.len(), 1);
        assert_eq!(overrides.get("dup"), Some(&SimpleScore::of(2)));
    }

    #[test]
    fn test_clone_is_independent() {
        let mut original = ConstraintWeightOverrides::<SimpleScore>::new();
        original.put("test", SimpleScore::of(5));

        let copy = original.clone();
        original.put("test", SimpleScore::of(9));

        assert_eq!(copy.get("test"), Some(&SimpleScore::of(5)));
        assert_eq!(original.get("test"), Some(&SimpleScore::of(9)));
    }

    #[test]
    fn test_debug_shows_count_only() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));

        assert_eq!(
            format!("{overrides:?}"),
            "ConstraintWeightOverrides { count: 1 }"
        );
    }

    #[test]
    fn test_weight_provider_trait() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));

        let provider: &dyn WeightProvider<SimpleScore> = &overrides;
        assert_eq!(provider.weight("test"), Some(SimpleScore::of(5)));
        assert_eq!(provider.weight("other"), None);
        assert_eq!(
            provider.weight_or_default("other", SimpleScore::of(1)),
            SimpleScore::of(1)
        );
    }

    #[test]
    fn test_arc_weight_provider() {
        let mut overrides = ConstraintWeightOverrides::<SimpleScore>::new();
        overrides.put("test", SimpleScore::of(5));
        let arc_overrides = overrides.into_arc();

        let provider: &dyn WeightProvider<SimpleScore> = &arc_overrides;
        assert_eq!(provider.weight("test"), Some(SimpleScore::of(5)));
        assert_eq!(provider.weight("other"), None);
        assert_eq!(
            provider.weight_or_default("test", SimpleScore::of(1)),
            SimpleScore::of(5)
        );
    }
}