1use rust_decimal::Decimal;
7use rustledger_core::{
8 Amount, Directive, InternedStr, NaiveDate, Price as PriceDirective, Transaction,
9};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub struct PriceEntry {
15 pub date: NaiveDate,
17 pub price: Decimal,
19 pub currency: InternedStr,
21}
22
23#[derive(Debug, Default)]
28pub struct PriceDatabase {
29 prices: HashMap<InternedStr, Vec<PriceEntry>>,
32}
33
34impl PriceDatabase {
35 pub fn new() -> Self {
37 Self {
38 prices: HashMap::new(),
39 }
40 }
41
42 pub fn from_directives(directives: &[Directive]) -> Self {
50 let mut db = Self::new();
51
52 for directive in directives {
53 match directive {
54 Directive::Price(price) => {
55 db.add_price(price);
56 }
57 Directive::Transaction(txn) => {
58 db.add_implicit_prices_from_transaction(txn);
59 }
60 _ => {}
61 }
62 }
63
64 db.sort_prices();
66
67 db
68 }
69
70 pub fn sort_prices(&mut self) {
74 for entries in self.prices.values_mut() {
75 entries.sort_by_key(|e| e.date);
76 }
77 }
78
79 pub fn add_price(&mut self, price: &PriceDirective) {
81 let entry = PriceEntry {
82 date: price.date,
83 price: price.amount.number,
84 currency: price.amount.currency.clone(),
85 };
86
87 self.prices
88 .entry(price.currency.clone())
89 .or_default()
90 .push(entry);
91 }
92
93 pub fn add_implicit_prices_from_transaction(&mut self, txn: &Transaction) {
101 for posting in &txn.postings {
102 if let Some(units) = posting.amount() {
104 if let Some(price_annotation) = &posting.price
107 && let Some(price_amount) = price_annotation.amount()
108 {
109 let per_unit_price = if price_annotation.is_unit() {
111 price_amount.number
112 } else if !units.number.is_zero() {
113 price_amount.number / units.number.abs()
115 } else {
116 continue;
117 };
118
119 self.add_implicit_price(
120 txn.date,
121 &units.currency,
122 per_unit_price,
123 &price_amount.currency,
124 );
125 continue;
127 }
128
129 if let Some(cost_spec) = &posting.cost {
131 if let (Some(number_per), Some(currency)) =
132 (&cost_spec.number_per, &cost_spec.currency)
133 {
134 self.add_implicit_price(txn.date, &units.currency, *number_per, currency);
135 } else if let (Some(number_total), Some(currency)) =
136 (&cost_spec.number_total, &cost_spec.currency)
137 {
138 if !units.number.is_zero() {
140 let per_unit = *number_total / units.number.abs();
141 self.add_implicit_price(txn.date, &units.currency, per_unit, currency);
142 }
143 }
144 }
145 }
146 }
147 }
148
149 fn add_implicit_price(
151 &mut self,
152 date: NaiveDate,
153 base_currency: &InternedStr,
154 price: Decimal,
155 quote_currency: &InternedStr,
156 ) {
157 let entry = PriceEntry {
158 date,
159 price,
160 currency: quote_currency.clone(),
161 };
162
163 self.prices
164 .entry(base_currency.clone())
165 .or_default()
166 .push(entry);
167 }
168
169 pub fn get_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
174 if base == quote {
176 return Some(Decimal::ONE);
177 }
178
179 if let Some(price) = self.get_direct_price(base, quote, date) {
181 return Some(price);
182 }
183
184 if let Some(price) = self.get_direct_price(quote, base, date)
186 && price != Decimal::ZERO
187 {
188 return Some(Decimal::ONE / price);
189 }
190
191 self.get_chained_price(base, quote, date)
193 }
194
195 fn get_direct_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
197 if let Some(entries) = self.prices.get(base) {
198 for entry in entries.iter().rev() {
199 if entry.date <= date && entry.currency == quote {
200 return Some(entry.price);
201 }
202 }
203 }
204 None
205 }
206
207 fn get_chained_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
210 let intermediates: Vec<InternedStr> = if let Some(entries) = self.prices.get(base) {
212 entries
213 .iter()
214 .filter(|e| e.date <= date)
215 .map(|e| e.currency.clone())
216 .collect()
217 } else {
218 Vec::new()
219 };
220
221 for intermediate in intermediates {
223 if intermediate == quote {
224 continue; }
226
227 if let Some(price1) = self.get_direct_price(base, &intermediate, date) {
229 if let Some(price2) = self.get_direct_price(&intermediate, quote, date) {
231 return Some(price1 * price2);
232 }
233 if let Some(price2) = self.get_direct_price(quote, &intermediate, date)
235 && price2 != Decimal::ZERO
236 {
237 return Some(price1 / price2);
238 }
239 }
240 }
241
242 for (currency, entries) in &self.prices {
244 for entry in entries.iter().rev() {
245 if entry.date <= date && entry.currency == base && entry.price != Decimal::ZERO {
246 let price1 = Decimal::ONE / entry.price;
248
249 if let Some(price2) = self.get_direct_price(currency, quote, date) {
251 return Some(price1 * price2);
252 }
253 if let Some(price2) = self.get_direct_price(quote, currency, date)
254 && price2 != Decimal::ZERO
255 {
256 return Some(price1 / price2);
257 }
258 }
259 }
260 }
261
262 None
263 }
264
265 pub fn get_latest_price(&self, base: &str, quote: &str) -> Option<Decimal> {
269 if base == quote {
271 return Some(Decimal::ONE);
272 }
273
274 if let Some(price) = self.get_direct_latest_price(base, quote) {
276 return Some(price);
277 }
278
279 if let Some(price) = self.get_direct_latest_price(quote, base)
281 && price != Decimal::ZERO
282 {
283 return Some(Decimal::ONE / price);
284 }
285
286 self.get_chained_latest_price(base, quote)
288 }
289
290 fn get_direct_latest_price(&self, base: &str, quote: &str) -> Option<Decimal> {
292 if let Some(entries) = self.prices.get(base) {
293 for entry in entries.iter().rev() {
295 if entry.currency == quote {
296 return Some(entry.price);
297 }
298 }
299 }
300 None
301 }
302
303 fn get_chained_latest_price(&self, base: &str, quote: &str) -> Option<Decimal> {
306 let intermediates: Vec<InternedStr> = if let Some(entries) = self.prices.get(base) {
308 entries.iter().map(|e| e.currency.clone()).collect()
309 } else {
310 Vec::new()
311 };
312
313 for intermediate in intermediates {
315 if intermediate == quote {
316 continue; }
318
319 if let Some(price1) = self.get_direct_latest_price(base, &intermediate) {
321 if let Some(price2) = self.get_direct_latest_price(&intermediate, quote) {
323 return Some(price1 * price2);
324 }
325 if let Some(price2) = self.get_direct_latest_price(quote, &intermediate)
327 && price2 != Decimal::ZERO
328 {
329 return Some(price1 / price2);
330 }
331 }
332 }
333
334 for (currency, entries) in &self.prices {
336 for entry in entries.iter().rev() {
337 if entry.currency == base && entry.price != Decimal::ZERO {
338 let price1 = Decimal::ONE / entry.price;
340
341 if let Some(price2) = self.get_direct_latest_price(currency, quote) {
343 return Some(price1 * price2);
344 }
345 if let Some(price2) = self.get_direct_latest_price(quote, currency)
346 && price2 != Decimal::ZERO
347 {
348 return Some(price1 / price2);
349 }
350 }
351 }
352 }
353
354 None
355 }
356
357 pub fn convert(&self, amount: &Amount, to_currency: &str, date: NaiveDate) -> Option<Amount> {
361 if amount.currency == to_currency {
362 return Some(amount.clone());
363 }
364
365 self.get_price(&amount.currency, to_currency, date)
366 .map(|price| Amount::new(amount.number * price, to_currency))
367 }
368
369 pub fn convert_latest(&self, amount: &Amount, to_currency: &str) -> Option<Amount> {
371 if amount.currency == to_currency {
372 return Some(amount.clone());
373 }
374
375 self.get_latest_price(&amount.currency, to_currency)
376 .map(|price| Amount::new(amount.number * price, to_currency))
377 }
378
379 pub fn currencies(&self) -> impl Iterator<Item = &str> {
381 self.prices.keys().map(InternedStr::as_str)
382 }
383
384 pub fn has_prices(&self, currency: &str) -> bool {
386 self.prices.contains_key(currency)
387 }
388
389 pub fn len(&self) -> usize {
391 self.prices.values().map(Vec::len).sum()
392 }
393
394 pub fn is_empty(&self) -> bool {
396 self.prices.is_empty()
397 }
398
399 pub fn iter_entries(&self) -> impl Iterator<Item = (&str, NaiveDate, Decimal, &str)> {
403 self.prices.iter().flat_map(|(base, entries)| {
404 entries
405 .iter()
406 .map(move |e| (base.as_str(), e.date, e.price, e.currency.as_str()))
407 })
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    fn date(y: i32, m: u32, d: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Builds an explicit price directive: one `base` is worth `number` `quote`.
    fn price_directive(
        on: NaiveDate,
        base: &'static str,
        number: Decimal,
        quote: &'static str,
    ) -> PriceDirective {
        PriceDirective {
            date: on,
            currency: base.into(),
            amount: Amount::new(number, quote),
            meta: Default::default(),
        }
    }

    #[test]
    fn test_price_lookup() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.add_price(&price_directive(date(2024, 6, 1), "AAPL", dec!(180.00), "USD"));
        db.sort_prices();

        // Exact date.
        assert_eq!(
            db.get_price("AAPL", "USD", date(2024, 1, 1)),
            Some(dec!(150.00))
        );
        // After the latest entry.
        assert_eq!(
            db.get_price("AAPL", "USD", date(2024, 6, 15)),
            Some(dec!(180.00))
        );
        // Between entries: the earlier one still applies.
        assert_eq!(
            db.get_price("AAPL", "USD", date(2024, 3, 15)),
            Some(dec!(150.00))
        );
        // Before any entry.
        assert_eq!(db.get_price("AAPL", "USD", date(2023, 12, 31)), None);
    }

    #[test]
    fn test_inverse_price() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "USD", dec!(0.92), "EUR"));
        db.sort_prices();

        assert_eq!(
            db.get_price("USD", "EUR", date(2024, 1, 1)),
            Some(dec!(0.92))
        );

        // 1 / 0.92 ≈ 1.0870
        let inverse = db.get_price("EUR", "USD", date(2024, 1, 1)).unwrap();
        assert!(inverse > dec!(1.08) && inverse < dec!(1.09));
    }

    #[test]
    fn test_convert() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.sort_prices();

        let shares = Amount::new(dec!(10), "AAPL");
        let usd = db.convert(&shares, "USD", date(2024, 1, 1)).unwrap();

        assert_eq!(usd.number, dec!(1500.00));
        assert_eq!(usd.currency, "USD");
    }

    #[test]
    fn test_same_currency_convert() {
        let db = PriceDatabase::new();
        let amount = Amount::new(dec!(100), "USD");

        let result = db.convert(&amount, "USD", date(2024, 1, 1)).unwrap();
        assert_eq!(result.number, dec!(100));
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_from_directives() {
        let directives = vec![
            Directive::Price(price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD")),
            Directive::Price(price_directive(date(2024, 1, 1), "EUR", dec!(1.10), "USD")),
        ];

        let db = PriceDatabase::from_directives(&directives);

        assert_eq!(db.len(), 2);
        assert!(db.has_prices("AAPL"));
        assert!(db.has_prices("EUR"));
    }

    #[test]
    fn test_chained_price_lookup() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "USD", dec!(0.92), "EUR"));
        db.sort_prices();

        assert_eq!(
            db.get_price("AAPL", "USD", date(2024, 1, 1)),
            Some(dec!(150.00))
        );
        assert_eq!(
            db.get_price("USD", "EUR", date(2024, 1, 1)),
            Some(dec!(0.92))
        );

        // AAPL -> USD -> EUR: 150 * 0.92
        let chained = db.get_price("AAPL", "EUR", date(2024, 1, 1)).unwrap();
        assert_eq!(chained, dec!(138.00));
    }

    #[test]
    fn test_chained_price_with_inverse() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "BTC", dec!(40000.00), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "EUR", dec!(1.10), "USD"));
        db.sort_prices();

        // BTC -> USD, then inverted EUR -> USD: 40000 / 1.10 ≈ 36363.64
        let chained = db.get_price("BTC", "EUR", date(2024, 1, 1)).unwrap();
        assert!(chained > dec!(36363) && chained < dec!(36364));
    }

    #[test]
    fn test_chained_price_via_inverted_first_leg() {
        // USD and GBP have no entries of their own; both legs come from EUR.
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "EUR", dec!(1.10), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "EUR", dec!(0.85), "GBP"));
        db.sort_prices();

        // (1 / 1.10) * 0.85 ≈ 0.7727
        let chained = db.get_price("USD", "GBP", date(2024, 1, 1)).unwrap();
        assert!(chained > dec!(0.7727) && chained < dec!(0.7728));
    }

    #[test]
    fn test_chained_price_no_path() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "GBP", dec!(1.17), "EUR"));
        db.sort_prices();

        assert_eq!(db.get_price("AAPL", "GBP", date(2024, 1, 1)), None);
    }

    #[test]
    fn test_latest_price_and_convert_latest() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.add_price(&price_directive(date(2024, 6, 1), "AAPL", dec!(180.00), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "USD", dec!(0.92), "EUR"));
        db.sort_prices();

        assert_eq!(db.get_latest_price("AAPL", "USD"), Some(dec!(180.00)));
        // Latest AAPL -> USD, then USD -> EUR: 180 * 0.92
        assert_eq!(db.get_latest_price("AAPL", "EUR"), Some(dec!(165.60)));

        let usd = db
            .convert_latest(&Amount::new(dec!(10), "AAPL"), "USD")
            .unwrap();
        assert_eq!(usd.number, dec!(1800.00));
        assert_eq!(usd.currency, "USD");
    }

    #[test]
    fn test_zero_price_is_not_inverted() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "XYZ", dec!(0), "USD"));
        db.sort_prices();

        assert_eq!(
            db.get_price("XYZ", "USD", date(2024, 1, 1)),
            Some(dec!(0))
        );
        assert_eq!(db.get_price("USD", "XYZ", date(2024, 1, 1)), None);
    }

    #[test]
    fn test_same_date_last_added_wins() {
        let mut db = PriceDatabase::new();
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(150.00), "USD"));
        db.add_price(&price_directive(date(2024, 1, 1), "AAPL", dec!(151.00), "USD"));
        db.sort_prices();

        assert_eq!(
            db.get_price("AAPL", "USD", date(2024, 1, 1)),
            Some(dec!(151.00))
        );
    }

    #[test]
    fn test_empty_database() {
        let db = PriceDatabase::new();

        assert!(db.is_empty());
        assert_eq!(db.len(), 0);
        assert_eq!(db.currencies().count(), 0);
        assert_eq!(db.get_price("AAPL", "USD", date(2024, 1, 1)), None);
        assert_eq!(db.get_latest_price("AAPL", "USD"), None);
    }

    #[test]
    fn test_implicit_price_from_annotation() {
        use rustledger_core::{Posting, PriceAnnotation, Transaction};

        // Per-unit annotation only, no cost spec.
        let txn = Transaction::new(date(2024, 1, 15), "Sell stock")
            .with_posting(
                Posting::new("Assets:Stocks", Amount::new(dec!(-5), "ABC"))
                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(1.40), "EUR"))),
            )
            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(7.00), "EUR")));

        let directives = vec![Directive::Transaction(txn)];
        let db = PriceDatabase::from_directives(&directives);

        let price = db.get_price("ABC", "EUR", date(2024, 1, 15));
        assert_eq!(price, Some(dec!(1.40)));
    }

    #[test]
    fn test_implicit_price_from_cost_only() {
        use rustledger_core::{CostSpec, Posting, Transaction};

        let txn = Transaction::new(date(2024, 1, 10), "Buy stock")
            .with_posting(
                Posting::new("Assets:Stocks", Amount::new(dec!(10), "XYZ")).with_cost(
                    CostSpec::default()
                        .with_number_per(dec!(50.00))
                        .with_currency("USD"),
                ),
            )
            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(-500), "USD")));

        let directives = vec![Directive::Transaction(txn)];
        let db = PriceDatabase::from_directives(&directives);

        let price = db.get_price("XYZ", "USD", date(2024, 1, 10));
        assert_eq!(price, Some(dec!(50.00)));
    }

    #[test]
    fn test_implicit_price_from_total_annotation() {
        use rustledger_core::{Posting, PriceAnnotation, Transaction};

        let txn = Transaction::new(date(2024, 1, 15), "Sell")
            .with_posting(
                Posting::new("Assets:Stocks", Amount::new(dec!(-10), "ABC"))
                    .with_price(PriceAnnotation::Total(Amount::new(dec!(1500), "USD"))),
            )
            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(1500), "USD")));

        let directives = vec![Directive::Transaction(txn)];
        let db = PriceDatabase::from_directives(&directives);

        // 1500 / |-10|
        let price = db.get_price("ABC", "USD", date(2024, 1, 15));
        assert_eq!(price, Some(dec!(150)));
    }

    #[test]
    fn test_implicit_price_annotation_takes_priority_over_cost() {
        use rustledger_core::{CostSpec, Posting, PriceAnnotation, Transaction};

        // Cost says 1.25, annotation says 1.40: the annotation must win.
        let txn = Transaction::new(date(2024, 1, 15), "Sell")
            .with_posting(
                Posting::new("Assets:Stocks", Amount::new(dec!(-5), "ABC"))
                    .with_cost(
                        CostSpec::default()
                            .with_number_per(dec!(1.25))
                            .with_currency("EUR"),
                    )
                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(1.40), "EUR"))),
            )
            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(7.00), "EUR")));

        let directives = vec![Directive::Transaction(txn)];
        let db = PriceDatabase::from_directives(&directives);

        let price = db.get_price("ABC", "EUR", date(2024, 1, 15));
        assert_eq!(price, Some(dec!(1.40)));
    }

    #[test]
    fn test_implicit_price_combined_with_explicit() {
        use rustledger_core::{CostSpec, Posting, PriceAnnotation, Transaction};

        let explicit_price = price_directive(date(2024, 1, 10), "ABC", dec!(1.30), "EUR");

        let txn = Transaction::new(date(2024, 1, 15), "Sell")
            .with_posting(
                Posting::new("Assets:Stocks", Amount::new(dec!(-5), "ABC"))
                    .with_cost(
                        CostSpec::default()
                            .with_number_per(dec!(1.25))
                            .with_currency("EUR"),
                    )
                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(1.40), "EUR"))),
            )
            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(7.00), "EUR")));

        let directives = vec![
            Directive::Price(explicit_price),
            Directive::Transaction(txn),
        ];
        let db = PriceDatabase::from_directives(&directives);

        // Before the transaction only the explicit price is visible.
        assert_eq!(
            db.get_price("ABC", "EUR", date(2024, 1, 10)),
            Some(dec!(1.30))
        );
        // The implicit price from the later transaction is the latest.
        assert_eq!(db.get_latest_price("ABC", "EUR"), Some(dec!(1.40)));
    }
}