1use crate::AggregateOrders;
2use crate::MarketError;
3use alloy::primitives::{Address, U256};
4use serde::{Deserialize, Serialize};
5use signet_zenith::RollupOrders;
6use std::collections::HashMap;
7
8#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
46pub struct AggregateFills {
47 fills: HashMap<(u64, Address), HashMap<Address, U256>>,
50}
51
52impl AggregateFills {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn filled(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
60 self.fills.get(output_asset).and_then(|m| m.get(&recipient)).copied().unwrap_or_default()
61 }
62
63 pub fn check_filled(
66 &self,
67 output_asset: &(u64, Address),
68 recipient: Address,
69 amount: U256,
70 ) -> Result<(), MarketError> {
71 if self.filled(output_asset, recipient) < amount {
72 return Err(MarketError::InsufficientBalance {
73 chain_id: output_asset.0,
74 asset: output_asset.1,
75 recipient,
76 amount,
77 });
78 }
79 Ok(())
80 }
81
82 pub fn add_raw_fill(
85 &mut self,
86 chain_id: u64,
87 asset: Address,
88 recipient: Address,
89 amount: U256,
90 ) {
91 let entry = self.fills.entry((chain_id, asset)).or_default().entry(recipient).or_default();
92 *entry = entry.saturating_add(amount);
93 }
94
95 fn add_fill_output(&mut self, chain_id: u64, output: &RollupOrders::Output) {
97 self.add_raw_fill(chain_id, output.token, output.recipient, output.amount)
98 }
99
100 pub fn add_fill(&mut self, chain_id: u64, fill: &RollupOrders::Filled) {
108 fill.outputs.iter().for_each(|o| self.add_fill_output(chain_id, o));
109 }
110
111 fn absorb(&mut self, other: &Self) {
113 for (output_asset, recipients) in other.fills.iter() {
114 let context_recipients = self.fills.entry(*output_asset).or_default();
115 for (recipient, value) in recipients {
116 let filled = context_recipients.entry(*recipient).or_default();
117 *filled = filled.saturating_add(*value);
118 }
119 }
120 }
121
122 pub fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
124 for (output_asset, recipients) in aggregate.outputs.iter() {
125 if !self.fills.contains_key(output_asset) {
126 return Err(MarketError::MissingAsset {
127 chain_id: output_asset.0,
128 asset: output_asset.1,
129 });
130 };
131
132 for (recipient, value) in recipients {
133 self.check_filled(output_asset, *recipient, *value)?;
134 }
135 }
136 Ok(())
137 }
138
139 pub fn unchecked_remove_aggregate(
144 &mut self,
145 aggregate: &AggregateOrders,
146 ) -> Result<(), MarketError> {
147 for (output_asset, recipients) in aggregate.outputs.iter() {
148 let context_recipients =
149 self.fills.get_mut(output_asset).ok_or(MarketError::MissingAsset {
150 chain_id: output_asset.0,
151 asset: output_asset.1,
152 })?;
153
154 for (recipient, amount) in recipients {
155 let filled = context_recipients.get_mut(recipient).unwrap();
156 *filled = filled.saturating_sub(*amount);
157 }
158 }
159
160 Ok(())
161 }
162
163 pub fn checked_remove_aggregate(
166 &mut self,
167 aggregate: &AggregateOrders,
168 ) -> Result<(), MarketError> {
169 self.check_aggregate(aggregate)?;
170
171 for (output_asset, recipients) in aggregate.outputs.iter() {
172 let context_recipients =
173 self.fills.get_mut(output_asset).expect("checked in check_aggregate");
174
175 for (recipient, amount) in recipients {
176 let filled = context_recipients.get_mut(recipient).unwrap();
177 *filled = filled.checked_sub(*amount).unwrap();
178 }
179 }
180
181 Ok(())
182 }
183
184 pub fn check_order(&self, order: &RollupOrders::Order) -> Result<(), MarketError> {
186 self.check_aggregate(&std::iter::once(order).collect())
187 }
188
189 pub fn checked_remove_order(&mut self, order: &RollupOrders::Order) -> Result<(), MarketError> {
192 let aggregate = std::iter::once(order).collect();
193 self.check_aggregate(&aggregate)?;
194 self.unchecked_remove_aggregate(&aggregate)
195 }
196
197 pub fn unchecked_remove_order(
202 &mut self,
203 order: &RollupOrders::Order,
204 ) -> Result<(), MarketError> {
205 let aggregate = std::iter::once(order).collect();
206 self.unchecked_remove_aggregate(&aggregate)
207 }
208
209 pub const fn fills(&self) -> &HashMap<(u64, Address), HashMap<Address, U256>> {
211 &self.fills
212 }
213
214 pub fn fills_mut(&mut self) -> &mut HashMap<(u64, Address), HashMap<Address, U256>> {
216 &mut self.fills
217 }
218
219 pub fn check_ru_tx_events(
223 &self,
224 fills: &AggregateFills,
225 aggregate: &AggregateOrders,
226 ) -> Result<(), MarketError> {
227 let combined = CombinedContext { context: self, extra: fills };
229
230 combined.check_aggregate(aggregate)?;
231
232 Ok(())
233 }
234
235 pub fn checked_remove_ru_tx_events(
241 &mut self,
242 aggregate: &AggregateOrders,
243 fills: &AggregateFills,
244 ) -> Result<(), MarketError> {
245 self.check_ru_tx_events(fills, aggregate)?;
246 self.absorb(fills);
247 self.unchecked_remove_aggregate(aggregate)
248 }
249}
250
251struct CombinedContext<'a, 'b> {
254 context: &'a AggregateFills,
255 extra: &'b AggregateFills,
256}
257
258impl CombinedContext<'_, '_> {
259 fn balance(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
261 self.context.filled(output_asset, recipient) + self.extra.filled(output_asset, recipient)
262 }
263
264 fn check_filled(
267 &self,
268 output_asset: &(u64, Address),
269 recipient: Address,
270 amount: U256,
271 ) -> Result<(), MarketError> {
272 if self.balance(output_asset, recipient) < amount {
273 return Err(MarketError::InsufficientBalance {
274 chain_id: output_asset.0,
275 asset: output_asset.1,
276 recipient,
277 amount,
278 });
279 }
280 Ok(())
281 }
282
283 fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
285 for (output_asset, recipients) in aggregate.outputs.iter() {
286 for (recipient, amount) in recipients {
287 self.check_filled(output_asset, *recipient, *amount)?;
288 }
289 }
290 Ok(())
291 }
292}
293
#[cfg(test)]
mod test {
    use super::*;
    use signet_zenith::RollupOrders::{Filled, Order, Output};

    /// Filling an order's outputs and then removing the order leaves every
    /// recorded entry in place with a zero balance.
    #[test]
    fn basic_fills() {
        let (user_a, user_b) = (Address::with_last_byte(1), Address::with_last_byte(2));
        let (asset_a, asset_b) = (Address::with_last_byte(3), Address::with_last_byte(4));

        // Every output in this test lives on chain 1.
        let output = |token: Address, amount: u64, recipient: Address| Output {
            token,
            amount: U256::from(amount),
            recipient,
            chainId: 1,
        };

        let outputs =
            vec![output(asset_a, 100, user_a), output(asset_b, 200, user_b), output(asset_a, 300, user_b)];
        let fill = Filled { outputs: outputs.clone() };
        let order = Order { deadline: U256::ZERO, inputs: vec![], outputs };

        let mut context = AggregateFills::default();
        context.add_fill(1, &fill);

        // Fetch a recorded amount, panicking if either map entry is absent.
        let recorded = |ctx: &AggregateFills, asset: Address, user: Address| -> U256 {
            *ctx.fills().get(&(1, asset)).unwrap().get(&user).unwrap()
        };

        assert_eq!(context.fills().len(), 2);
        for (asset, user, amount) in
            [(asset_a, user_a, 100u64), (asset_b, user_b, 200), (asset_a, user_b, 300)]
        {
            assert_eq!(recorded(&context, asset, user), U256::from(amount));
        }

        context.checked_remove_order(&order).unwrap();

        assert_eq!(context.fills().len(), 2);
        for (asset, user) in [(asset_a, user_a), (asset_b, user_b), (asset_a, user_b)] {
            assert_eq!(recorded(&context, asset, user), U256::ZERO);
        }
    }
}