Skip to main content

rustledger_core/
display_context.rs

1//! Display context for formatting numbers with consistent precision.
2//!
3//! This module provides the [`DisplayContext`] type which tracks the precision
4//! (number of decimal places) seen for each currency during parsing. This allows
5//! numbers to be formatted consistently - for example, if a file contains both
6//! `100 USD` and `50.25 USD`, both should display with 2 decimal places.
7//!
8//! This matches Python beancount's `display_context` behavior.
9//!
10//! # Example
11//!
12//! ```
13//! use rustledger_core::DisplayContext;
14//! use rust_decimal_macros::dec;
15//!
16//! let mut ctx = DisplayContext::new();
17//!
18//! // Track precision from parsed numbers
19//! ctx.update(dec!(100), "USD");      // 0 decimal places
20//! ctx.update(dec!(50.25), "USD");    // 2 decimal places
21//! ctx.update(dec!(1.5), "EUR");      // 1 decimal place
22//!
23//! // Get the precision to use (maximum seen)
24//! assert_eq!(ctx.get_precision("USD"), Some(2));
25//! assert_eq!(ctx.get_precision("EUR"), Some(1));
26//! assert_eq!(ctx.get_precision("GBP"), None);  // Never seen
27//!
28//! // Format a number with the tracked precision
29//! assert_eq!(ctx.format(dec!(100), "USD"), "100.00");
30//! assert_eq!(ctx.format(dec!(50.25), "USD"), "50.25");
31//! assert_eq!(ctx.format(dec!(1.5), "EUR"), "1.5");
32//! ```
33
34use rust_decimal::Decimal;
35use std::collections::HashMap;
36
/// Per-currency number formatting state.
///
/// Remembers, for every currency, the largest number of fractional digits
/// encountered so far, so that amounts in that currency can be rendered with a
/// uniform number of decimal places. It also carries the ledger options that
/// influence rendering (`render_commas` and `display_precision`).
#[derive(Debug, Clone, Default)]
pub struct DisplayContext {
    /// Currency -> largest scale (digits after the decimal point) observed so far.
    precisions: HashMap<String, u32>,

    /// When `true`, thousands separators are inserted (`option "render_commas"`).
    render_commas: bool,

    /// Currency -> explicitly configured precision (`option "display_precision"`).
    /// Looked up before `precisions`, so it overrides the inferred value.
    fixed_precisions: HashMap<String, u32>,
}
53
54impl DisplayContext {
55    /// Create a new empty display context.
56    #[must_use]
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Update the display context with a number for a currency.
62    ///
63    /// This records the decimal precision of the number (number of digits after
64    /// the decimal point) and updates the maximum precision seen for that currency.
65    /// Update the display context with a number for a currency.
66    ///
67    /// This records the decimal precision of the number (number of digits after
68    /// the decimal point) and updates the maximum precision seen for that currency.
69    pub fn update(&mut self, number: Decimal, currency: &str) {
70        let precision = Self::decimal_precision(number);
71        let entry = self.precisions.entry(currency.to_string()).or_insert(0);
72        *entry = (*entry).max(precision);
73    }
74
75    /// Update the display context from another display context.
76    ///
77    /// Takes the maximum precision for each currency from both contexts.
78    pub fn update_from(&mut self, other: &Self) {
79        for (currency, precision) in &other.precisions {
80            let entry = self.precisions.entry(currency.clone()).or_insert(0);
81            *entry = (*entry).max(*precision);
82        }
83    }
84
85    /// Set the `render_commas` flag.
86    pub const fn set_render_commas(&mut self, render_commas: bool) {
87        self.render_commas = render_commas;
88    }
89
90    /// Get the `render_commas` flag.
91    #[must_use]
92    pub const fn render_commas(&self) -> bool {
93        self.render_commas
94    }
95
96    /// Set a fixed precision for a currency (from `option "display_precision"`).
97    ///
98    /// Fixed precision takes precedence over inferred precision.
99    pub fn set_fixed_precision(&mut self, currency: &str, precision: u32) {
100        self.fixed_precisions
101            .insert(currency.to_string(), precision);
102    }
103
104    /// Get the precision for a currency.
105    ///
106    /// Returns the fixed precision if set, otherwise the maximum precision seen,
107    /// or None if the currency has never been seen.
108    #[must_use]
109    pub fn get_precision(&self, currency: &str) -> Option<u32> {
110        // Fixed precision takes precedence
111        if let Some(&precision) = self.fixed_precisions.get(currency) {
112            return Some(precision);
113        }
114        self.precisions.get(currency).copied()
115    }
116
117    /// Quantize a number to the tracked precision for a currency.
118    ///
119    /// Rounds the number to the maximum decimal places seen for the currency.
120    /// If the currency has no tracked precision, returns the number unchanged.
121    #[must_use]
122    pub fn quantize(&self, number: Decimal, currency: &str) -> Decimal {
123        if let Some(dp) = self.get_precision(currency) {
124            number.round_dp(dp)
125        } else {
126            number
127        }
128    }
129
130    /// Format a decimal number for a currency using the tracked precision.
131    ///
132    /// If the currency has been seen, formats with the maximum precision.
133    /// Otherwise, formats with the number's natural precision (no trailing zeros).
134    /// Uses half-up rounding to match Python beancount behavior.
135    #[must_use]
136    pub fn format(&self, number: Decimal, currency: &str) -> String {
137        let precision = self.get_precision(currency);
138
139        if let Some(dp) = precision {
140            // Round with half-up (MidpointAwayFromZero) to match Python behavior
141            // Note: format!("{:.N}", decimal) uses truncation which gives wrong results
142            // for values like -1202.00896 (would give -1202.00 instead of -1202.01)
143            let rounded = number.round_dp(dp);
144            let formatted = format!("{rounded}");
145            // Ensure we have the right number of decimal places (add trailing zeros if needed)
146            let formatted = Self::ensure_decimal_places(&formatted, dp);
147            if self.render_commas {
148                Self::add_commas(&formatted)
149            } else {
150                formatted
151            }
152        } else {
153            // No tracked precision - use natural formatting
154            let formatted = number.normalize().to_string();
155            if self.render_commas {
156                Self::add_commas(&formatted)
157            } else {
158                formatted
159            }
160        }
161    }
162
163    /// Format an amount (number + currency) using the tracked precision.
164    #[must_use]
165    pub fn format_amount(&self, number: Decimal, currency: &str) -> String {
166        format!("{} {}", self.format(number, currency), currency)
167    }
168
169    /// Get the decimal precision (number of digits after decimal point) of a number.
170    const fn decimal_precision(number: Decimal) -> u32 {
171        // scale() returns the number of decimal digits
172        number.scale()
173    }
174
175    /// Ensure a formatted number has exactly `dp` decimal places.
176    /// Adds trailing zeros if needed, or adds ".00..." if no decimal point.
177    fn ensure_decimal_places(s: &str, dp: u32) -> String {
178        if dp == 0 {
179            // No decimal places needed - remove any decimal point
180            return s.split('.').next().unwrap_or(s).to_string();
181        }
182
183        let dp = dp as usize;
184        if let Some(dot_pos) = s.find('.') {
185            let current_decimals = s.len() - dot_pos - 1;
186            if current_decimals >= dp {
187                // Already has enough or more decimals
188                s.to_string()
189            } else {
190                // Need to add trailing zeros
191                let zeros_needed = dp - current_decimals;
192                format!("{s}{}", "0".repeat(zeros_needed))
193            }
194        } else {
195            // No decimal point - add one with zeros
196            format!("{s}.{}", "0".repeat(dp))
197        }
198    }
199
200    /// Add thousand separators (commas) to a formatted number string.
201    fn add_commas(s: &str) -> String {
202        // Split on decimal point
203        let (integer_part, decimal_part) = match s.find('.') {
204            Some(pos) => (&s[..pos], Some(&s[pos..])),
205            None => (s, None),
206        };
207
208        // Handle negative sign
209        let (sign, digits) = if let Some(stripped) = integer_part.strip_prefix('-') {
210            ("-", stripped)
211        } else {
212            ("", integer_part)
213        };
214
215        // Add commas to integer part (from right to left)
216        let mut result = String::with_capacity(digits.len() + digits.len() / 3);
217        for (i, c) in digits.chars().rev().enumerate() {
218            if i > 0 && i % 3 == 0 {
219                result.push(',');
220            }
221            result.push(c);
222        }
223        let integer_with_commas: String = result.chars().rev().collect();
224
225        // Combine parts
226        match decimal_part {
227            Some(dec) => format!("{sign}{integer_with_commas}{dec}"),
228            None => format!("{sign}{integer_with_commas}"),
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[test]
    fn test_update_and_get_precision() {
        let mut ctx = DisplayContext::new();

        ctx.update(dec!(100), "USD");
        assert_eq!(ctx.get_precision("USD"), Some(0));

        ctx.update(dec!(50.25), "USD");
        assert_eq!(ctx.get_precision("USD"), Some(2));

        // Maximum is kept
        ctx.update(dec!(1), "USD");
        assert_eq!(ctx.get_precision("USD"), Some(2));

        // Unknown currency
        assert_eq!(ctx.get_precision("EUR"), None);
    }

    #[test]
    fn test_format_with_precision() {
        let mut ctx = DisplayContext::new();
        ctx.update(dec!(100), "USD");
        ctx.update(dec!(50.25), "USD");

        // Formats to max precision (2)
        assert_eq!(ctx.format(dec!(100), "USD"), "100.00");
        assert_eq!(ctx.format(dec!(50.25), "USD"), "50.25");
        assert_eq!(ctx.format(dec!(7.5), "USD"), "7.50");
    }

    #[test]
    fn test_format_rounding_mode() {
        let mut ctx = DisplayContext::new();
        ctx.update(dec!(50.25), "USD"); // 2 decimal places
        ctx.update(dec!(100), "JPY"); // 0 decimal places

        // `round_dp` rounds exact midpoints to the nearest even digit.
        assert_eq!(ctx.format(dec!(0.125), "USD"), "0.12");
        assert_eq!(ctx.format(dec!(0.135), "USD"), "0.14");
        assert_eq!(ctx.format(dec!(2.5), "JPY"), "2");
        assert_eq!(ctx.format(dec!(3.5), "JPY"), "4");

        // Non-midpoint values round to nearest instead of truncating.
        assert_eq!(ctx.format(dec!(-1202.00896), "USD"), "-1202.01");
    }

    #[test]
    fn test_format_unknown_currency() {
        let ctx = DisplayContext::new();

        // Unknown currency uses natural formatting
        assert_eq!(ctx.format(dec!(100), "EUR"), "100");
        assert_eq!(ctx.format(dec!(50.25), "EUR"), "50.25");
    }

    #[test]
    fn test_fixed_precision_override() {
        let mut ctx = DisplayContext::new();
        ctx.update(dec!(100), "USD");
        ctx.update(dec!(50.25), "USD");

        // Inferred precision is 2
        assert_eq!(ctx.get_precision("USD"), Some(2));

        // Set fixed precision to 4
        ctx.set_fixed_precision("USD", 4);
        assert_eq!(ctx.get_precision("USD"), Some(4));

        // Formatting uses fixed precision
        assert_eq!(ctx.format(dec!(100), "USD"), "100.0000");
    }

    #[test]
    fn test_quantize() {
        let mut ctx = DisplayContext::new();
        ctx.update(dec!(50.25), "USD");

        // Rounded to the inferred precision.
        assert_eq!(ctx.quantize(dec!(1.23456), "USD"), dec!(1.23));
        // Unknown currency is returned unchanged.
        assert_eq!(ctx.quantize(dec!(1.23456), "EUR"), dec!(1.23456));

        // A fixed precision overrides the inferred one.
        ctx.set_fixed_precision("USD", 4);
        assert_eq!(ctx.quantize(dec!(1.23456), "USD"), dec!(1.2346));
    }

    #[test]
    fn test_render_commas() {
        let mut ctx = DisplayContext::new();
        ctx.set_render_commas(true);
        ctx.update(dec!(1234567.89), "USD");

        assert_eq!(ctx.format(dec!(1234567.89), "USD"), "1,234,567.89");
        assert_eq!(ctx.format(dec!(1000), "USD"), "1,000.00");
    }

    #[test]
    fn test_render_commas_negative_and_unknown() {
        let mut ctx = DisplayContext::new();
        ctx.set_render_commas(true);
        ctx.update(dec!(0.01), "USD");

        assert!(ctx.render_commas());
        assert_eq!(ctx.format(dec!(-1234567.891), "USD"), "-1,234,567.89");
        // Unknown currency: natural formatting, still comma-separated.
        assert_eq!(ctx.format(dec!(1234.50), "EUR"), "1,234.5");
    }

    #[test]
    fn test_add_commas() {
        assert_eq!(DisplayContext::add_commas("1234567"), "1,234,567");
        assert_eq!(DisplayContext::add_commas("1234567.89"), "1,234,567.89");
        assert_eq!(DisplayContext::add_commas("-1234567.89"), "-1,234,567.89");
        assert_eq!(DisplayContext::add_commas("123"), "123");
        assert_eq!(DisplayContext::add_commas("1"), "1");
        assert_eq!(DisplayContext::add_commas("1000"), "1,000");
        assert_eq!(DisplayContext::add_commas("-999"), "-999");
        assert_eq!(DisplayContext::add_commas("100000.5"), "100,000.5");
    }

    #[test]
    fn test_ensure_decimal_places() {
        assert_eq!(DisplayContext::ensure_decimal_places("1.5", 3), "1.500");
        assert_eq!(DisplayContext::ensure_decimal_places("2", 2), "2.00");
        assert_eq!(DisplayContext::ensure_decimal_places("1.25", 0), "1");
        // Existing extra digits are never removed.
        assert_eq!(DisplayContext::ensure_decimal_places("1.2345", 2), "1.2345");
    }

    #[test]
    fn test_update_from() {
        let mut ctx1 = DisplayContext::new();
        ctx1.update(dec!(100), "USD");

        let mut ctx2 = DisplayContext::new();
        ctx2.update(dec!(50.25), "USD");
        ctx2.update(dec!(1.5), "EUR");

        ctx1.update_from(&ctx2);

        assert_eq!(ctx1.get_precision("USD"), Some(2));
        assert_eq!(ctx1.get_precision("EUR"), Some(1));
    }

    #[test]
    fn test_update_from_keeps_larger_local_precision() {
        let mut ctx1 = DisplayContext::new();
        ctx1.update(dec!(1.234), "USD");

        let mut ctx2 = DisplayContext::new();
        ctx2.update(dec!(1.5), "USD");

        ctx1.update_from(&ctx2);

        assert_eq!(ctx1.get_precision("USD"), Some(3));
    }

    #[test]
    fn test_format_amount() {
        let mut ctx = DisplayContext::new();
        ctx.update(dec!(50.25), "USD");

        assert_eq!(ctx.format_amount(dec!(100), "USD"), "100.00 USD");
    }
}