1use crate::AggregateOrders;
2use crate::MarketError;
3use crate::SignedFill;
4use alloy::primitives::{Address, U256};
5use serde::{Deserialize, Serialize};
6use signet_zenith::RollupOrders;
7use std::collections::HashMap;
8
9#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
47pub struct AggregateFills {
48 fills: HashMap<(u64, Address), HashMap<Address, U256>>,
51}
52
53impl AggregateFills {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn filled(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
61 self.fills.get(output_asset).and_then(|m| m.get(&recipient)).copied().unwrap_or_default()
62 }
63
64 pub fn check_filled(
67 &self,
68 output_asset: &(u64, Address),
69 recipient: Address,
70 amount: U256,
71 ) -> Result<(), MarketError> {
72 if self.filled(output_asset, recipient) < amount {
73 return Err(MarketError::InsufficientBalance {
74 chain_id: output_asset.0,
75 asset: output_asset.1,
76 recipient,
77 amount,
78 });
79 }
80 Ok(())
81 }
82
83 pub fn add_raw_fill(
86 &mut self,
87 chain_id: u64,
88 asset: Address,
89 recipient: Address,
90 amount: U256,
91 ) {
92 let entry = self.fills.entry((chain_id, asset)).or_default().entry(recipient).or_default();
93 *entry = entry.saturating_add(amount);
94 }
95
96 fn add_fill_output(&mut self, chain_id: u64, output: &RollupOrders::Output) {
98 self.add_raw_fill(chain_id, output.token, output.recipient, output.amount)
99 }
100
101 pub fn add_fill(&mut self, chain_id: u64, fill: &RollupOrders::Filled) {
109 fill.outputs.iter().for_each(|o| self.add_fill_output(chain_id, o));
110 }
111
112 pub fn add_signed_fill(&mut self, chain_id: u64, fill: &SignedFill) {
120 fill.outputs.iter().for_each(|o| self.add_fill_output(chain_id, o));
121 }
122
123 pub fn absorb(&mut self, other: &Self) {
125 for (output_asset, recipients) in other.fills.iter() {
126 let context_recipients = self.fills.entry(*output_asset).or_default();
127 for (recipient, value) in recipients {
128 let filled = context_recipients.entry(*recipient).or_default();
129 *filled = filled.saturating_add(*value);
130 }
131 }
132 }
133
134 pub fn unchecked_unabsorb(&mut self, other: &Self) -> Result<(), MarketError> {
136 for (output_asset, recipients) in other.fills.iter() {
137 if let Some(context_recipients) = self.fills.get_mut(output_asset) {
138 for (recipient, value) in recipients {
139 if let Some(filled) = context_recipients.get_mut(recipient) {
140 *filled =
141 filled.checked_sub(*value).ok_or(MarketError::InsufficientBalance {
142 chain_id: output_asset.0,
143 asset: output_asset.1,
144 recipient: *recipient,
145 amount: *value,
146 })?;
147 }
148 }
149 }
150 }
151 Ok(())
152 }
153
154 pub fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
156 for (output_asset, recipients) in aggregate.outputs.iter() {
157 if !self.fills.contains_key(output_asset) {
158 return Err(MarketError::MissingAsset {
159 chain_id: output_asset.0,
160 asset: output_asset.1,
161 });
162 };
163
164 for (recipient, value) in recipients {
165 self.check_filled(output_asset, *recipient, *value)?;
166 }
167 }
168 Ok(())
169 }
170
171 pub fn unchecked_remove_aggregate(
176 &mut self,
177 aggregate: &AggregateOrders,
178 ) -> Result<(), MarketError> {
179 for (output_asset, recipients) in aggregate.outputs.iter() {
180 let context_recipients =
181 self.fills.get_mut(output_asset).ok_or(MarketError::MissingAsset {
182 chain_id: output_asset.0,
183 asset: output_asset.1,
184 })?;
185
186 for (recipient, amount) in recipients {
187 let filled = context_recipients.get_mut(recipient).unwrap();
188 *filled = filled.saturating_sub(*amount);
189 }
190 }
191
192 Ok(())
193 }
194
195 pub fn checked_remove_aggregate(
198 &mut self,
199 aggregate: &AggregateOrders,
200 ) -> Result<(), MarketError> {
201 self.check_aggregate(aggregate)?;
202
203 for (output_asset, recipients) in aggregate.outputs.iter() {
204 let context_recipients =
205 self.fills.get_mut(output_asset).expect("checked in check_aggregate");
206
207 for (recipient, amount) in recipients {
208 let filled = context_recipients.get_mut(recipient).unwrap();
209 *filled = filled.checked_sub(*amount).unwrap();
210 }
211 }
212
213 Ok(())
214 }
215
216 pub fn check_order(&self, order: &RollupOrders::Order) -> Result<(), MarketError> {
218 self.check_aggregate(&std::iter::once(order).collect())
219 }
220
221 pub fn checked_remove_order(&mut self, order: &RollupOrders::Order) -> Result<(), MarketError> {
224 let aggregate = std::iter::once(order).collect();
225 self.check_aggregate(&aggregate)?;
226 self.unchecked_remove_aggregate(&aggregate)
227 }
228
229 pub fn unchecked_remove_order(
234 &mut self,
235 order: &RollupOrders::Order,
236 ) -> Result<(), MarketError> {
237 let aggregate = std::iter::once(order).collect();
238 self.unchecked_remove_aggregate(&aggregate)
239 }
240
241 pub const fn fills(&self) -> &HashMap<(u64, Address), HashMap<Address, U256>> {
243 &self.fills
244 }
245
246 pub const fn fills_mut(&mut self) -> &mut HashMap<(u64, Address), HashMap<Address, U256>> {
248 &mut self.fills
249 }
250
251 pub fn check_ru_tx_events(
255 &self,
256 fills: &AggregateFills,
257 orders: &AggregateOrders,
258 ) -> Result<(), MarketError> {
259 let combined = CombinedContext { context: self, extra: fills };
261
262 combined.check_aggregate(orders)?;
263
264 Ok(())
265 }
266
267 pub fn checked_remove_ru_tx_events(
273 &mut self,
274 fills: &AggregateFills,
275 orders: &AggregateOrders,
276 ) -> Result<(), MarketError> {
277 self.check_ru_tx_events(fills, orders)?;
278 self.absorb(fills);
279 self.unchecked_remove_aggregate(orders)
280 }
281
282 pub fn unchecked_remove_ru_tx_events(
286 &mut self,
287 fills: &AggregateFills,
288 orders: &AggregateOrders,
289 ) -> Result<(), MarketError> {
290 self.absorb(fills);
291 self.unchecked_remove_aggregate(orders)
292 }
293}
294
295struct CombinedContext<'a, 'b> {
298 context: &'a AggregateFills,
299 extra: &'b AggregateFills,
300}
301
302impl CombinedContext<'_, '_> {
303 fn balance(&self, output_asset: &(u64, Address), recipient: Address) -> U256 {
305 self.context.filled(output_asset, recipient) + self.extra.filled(output_asset, recipient)
306 }
307
308 fn check_filled(
311 &self,
312 output_asset: &(u64, Address),
313 recipient: Address,
314 amount: U256,
315 ) -> Result<(), MarketError> {
316 if self.balance(output_asset, recipient) < amount {
317 return Err(MarketError::InsufficientBalance {
318 chain_id: output_asset.0,
319 asset: output_asset.1,
320 recipient,
321 amount,
322 });
323 }
324 Ok(())
325 }
326
327 fn check_aggregate(&self, aggregate: &AggregateOrders) -> Result<(), MarketError> {
329 for (output_asset, recipients) in aggregate.outputs.iter() {
330 for (recipient, amount) in recipients {
331 self.check_filled(output_asset, *recipient, *amount)?;
332 }
333 }
334 Ok(())
335 }
336}
337
#[cfg(test)]
mod test {
    use super::*;
    use signet_zenith::RollupOrders::{Filled, Order, Output};

    /// Read a recorded fill on chain 1, panicking if the entry is absent.
    fn balance(ctx: &AggregateFills, asset: Address, user: Address) -> U256 {
        *ctx.fills().get(&(1, asset)).unwrap().get(&user).unwrap()
    }

    /// Three fills are recorded and then exactly consumed by an order whose
    /// outputs are identical to the fill's outputs.
    #[test]
    fn basic_fills() {
        let (user_a, user_b) = (Address::with_last_byte(1), Address::with_last_byte(2));
        let (asset_a, asset_b) = (Address::with_last_byte(3), Address::with_last_byte(4));

        let outputs = vec![
            Output { token: asset_a, amount: U256::from(100), recipient: user_a, chainId: 1 },
            Output { token: asset_b, amount: U256::from(200), recipient: user_b, chainId: 1 },
            Output { token: asset_a, amount: U256::from(300), recipient: user_b, chainId: 1 },
        ];
        let fill = Filled { outputs: outputs.clone() };
        let order = Order { deadline: U256::ZERO, inputs: vec![], outputs };

        let mut context = AggregateFills::default();
        context.add_fill(1, &fill);

        assert_eq!(context.fills().len(), 2);
        for (asset, user, expected) in
            [(asset_a, user_a, 100u64), (asset_b, user_b, 200), (asset_a, user_b, 300)]
        {
            assert_eq!(balance(&context, asset, user), U256::from(expected));
        }

        context.checked_remove_order(&order).unwrap();

        assert_eq!(context.fills().len(), 2);
        for (asset, user) in [(asset_a, user_a), (asset_b, user_b), (asset_a, user_b)] {
            assert_eq!(balance(&context, asset, user), U256::ZERO);
        }
    }

    /// Removing nothing from an empty context succeeds.
    #[test]
    fn empty_everything() {
        let mut context = AggregateFills::new();
        let no_fills = AggregateFills::default();
        let no_orders = AggregateOrders::default();
        context.checked_remove_ru_tx_events(&no_fills, &no_orders).unwrap();
    }

    /// Absorbing and then unabsorbing the same fills restores the original.
    #[test]
    fn absorb_unabsorb() {
        let user = Address::with_last_byte(1);
        let asset = Address::with_last_byte(2);

        let mut base = AggregateFills::default();
        base.add_raw_fill(1, asset, user, U256::from(100));
        let snapshot = base.clone();

        let mut incoming = AggregateFills::default();
        incoming.add_raw_fill(1, asset, user, U256::from(200));

        base.absorb(&incoming);
        assert_eq!(base.filled(&(1, asset), user), U256::from(300));

        base.unchecked_unabsorb(&incoming).unwrap();
        assert_eq!(base, snapshot);
    }
}