Skip to main content

sheetkit_core/
pivot.rs

//! Pivot table configuration and builder.
2
3use crate::error::{Error, Result};
4
/// Aggregate function for pivot table data fields.
///
/// Each variant is written to the pivot table XML as the `subtotal` value of a
/// data field; [`AggregateFunction::to_xml_str`] gives the exact spelling.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// Sum of the values (`sum`).
    Sum,
    /// Number of non-empty values, like Excel's `COUNTA` (`count`).
    Count,
    /// Arithmetic mean of the values (`average`).
    Average,
    /// Largest value (`max`).
    Max,
    /// Smallest value (`min`).
    Min,
    /// Product of the values (`product`).
    Product,
    /// Number of numeric values, like Excel's `COUNT` (`countNums`).
    CountNums,
    /// Sample estimate of the standard deviation (`stdDev`).
    StdDev,
    /// Population standard deviation (`stdDevp`).
    StdDevP,
    /// Sample estimate of the variance (`var`).
    Var,
    /// Population variance (`varp`).
    VarP,
}

impl AggregateFunction {
    /// Returns the XML string representation of this aggregate function.
    ///
    /// Every spelling is a string literal, so the result is `'static` and may
    /// be kept after `self` has been dropped.
    pub fn to_xml_str(&self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Count => "count",
            Self::Average => "average",
            Self::Max => "max",
            Self::Min => "min",
            Self::Product => "product",
            Self::CountNums => "countNums",
            Self::StdDev => "stdDev",
            Self::StdDevP => "stdDevp",
            Self::Var => "var",
            Self::VarP => "varp",
        }
    }

    /// Parses an aggregate function from its XML string representation.
    ///
    /// Matching is exact and case-sensitive: only the spellings produced by
    /// [`to_xml_str`](Self::to_xml_str) are accepted; anything else yields
    /// `None`.
    pub fn from_xml_str(s: &str) -> Option<Self> {
        match s {
            "sum" => Some(Self::Sum),
            "count" => Some(Self::Count),
            "average" => Some(Self::Average),
            "max" => Some(Self::Max),
            "min" => Some(Self::Min),
            "product" => Some(Self::Product),
            "countNums" => Some(Self::CountNums),
            "stdDev" => Some(Self::StdDev),
            "stdDevp" => Some(Self::StdDevP),
            "var" => Some(Self::Var),
            "varp" => Some(Self::VarP),
            _ => None,
        }
    }
}
57
58/// Configuration for adding a pivot table.
59#[derive(Debug, Clone)]
60pub struct PivotTableConfig {
61    /// Name of the pivot table.
62    pub name: String,
63    /// Source data sheet name.
64    pub source_sheet: String,
65    /// Source data range (e.g., "A1:D10").
66    pub source_range: String,
67    /// Target sheet name where the pivot table will be placed.
68    pub target_sheet: String,
69    /// Target cell (top-left corner of pivot table, e.g., "A1").
70    pub target_cell: String,
71    /// Row fields (column names from source data).
72    pub rows: Vec<PivotField>,
73    /// Column fields.
74    pub columns: Vec<PivotField>,
75    /// Data/value fields.
76    pub data: Vec<PivotDataField>,
77}
78
/// A field used as row or column in the pivot table.
///
/// Fields compare and hash by name, so a configuration's fields can be
/// checked for equality or collected into sets (e.g. to spot duplicates).
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct PivotField {
    /// Column name from the source data header row. It is matched against the
    /// header names exactly (case-sensitive).
    pub name: String,
}
85
86/// A data/value field in the pivot table.
87#[derive(Debug, Clone)]
88pub struct PivotDataField {
89    /// Column name from the source data header row.
90    pub name: String,
91    /// Aggregate function to apply.
92    pub function: AggregateFunction,
93    /// Optional custom display name.
94    pub display_name: Option<String>,
95}
96
/// Information about an existing pivot table.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct PivotTableInfo {
    /// Name of the pivot table.
    pub name: String,
    /// Name of the sheet holding the source data.
    pub source_sheet: String,
    /// Source data range on `source_sheet` (e.g. "A1:D10").
    pub source_range: String,
    /// Name of the sheet the pivot table is placed on.
    pub target_sheet: String,
    /// Location reference of the pivot table on `target_sheet`
    /// (e.g. "A3:E20").
    pub location: String,
}
106
107/// Builds the pivot table definition XML from config and field names.
108pub fn build_pivot_table_xml(
109    config: &PivotTableConfig,
110    cache_id: u32,
111    field_names: &[String],
112) -> Result<sheetkit_xml::pivot_table::PivotTableDefinition> {
113    use sheetkit_xml::pivot_table::*;
114
115    let ns = sheetkit_xml::namespaces::SPREADSHEET_ML;
116
117    let find_field_index = |name: &str| -> Result<usize> {
118        field_names.iter().position(|n| n == name).ok_or_else(|| {
119            Error::Internal(format!("pivot field '{}' not found in source data", name))
120        })
121    };
122
123    let mut pivot_field_defs = Vec::new();
124    for field_name in field_names {
125        let is_row = config.rows.iter().any(|r| r.name == *field_name);
126        let is_col = config.columns.iter().any(|c| c.name == *field_name);
127        let is_data = config.data.iter().any(|d| d.name == *field_name);
128
129        let axis = if is_row {
130            Some("axisRow".to_string())
131        } else if is_col {
132            Some("axisCol".to_string())
133        } else {
134            None
135        };
136
137        pivot_field_defs.push(PivotFieldDef {
138            axis,
139            data_field: if is_data { Some(true) } else { None },
140            show_all: Some(false),
141            items: None,
142        });
143    }
144
145    let row_fields = if config.rows.is_empty() {
146        None
147    } else {
148        let fields: Result<Vec<FieldRef>> = config
149            .rows
150            .iter()
151            .map(|r| find_field_index(&r.name).map(|i| FieldRef { index: i as i32 }))
152            .collect();
153        Some(FieldList {
154            count: Some(config.rows.len() as u32),
155            fields: fields?,
156        })
157    };
158
159    let col_fields = if config.columns.is_empty() {
160        None
161    } else {
162        let fields: Result<Vec<FieldRef>> = config
163            .columns
164            .iter()
165            .map(|c| find_field_index(&c.name).map(|i| FieldRef { index: i as i32 }))
166            .collect();
167        Some(FieldList {
168            count: Some(config.columns.len() as u32),
169            fields: fields?,
170        })
171    };
172
173    let data_fields = if config.data.is_empty() {
174        None
175    } else {
176        let fields: Result<Vec<DataFieldDef>> = config
177            .data
178            .iter()
179            .map(|d| {
180                let idx = find_field_index(&d.name)?;
181                Ok(DataFieldDef {
182                    name: d.display_name.clone().or_else(|| {
183                        Some(format!(
184                            "{} of {}",
185                            capitalize_first(d.function.to_xml_str()),
186                            d.name
187                        ))
188                    }),
189                    field_index: idx as u32,
190                    subtotal: Some(d.function.to_xml_str().to_string()),
191                    base_field: Some(0),
192                    base_item: Some(0),
193                })
194            })
195            .collect();
196        Some(DataFields {
197            count: Some(config.data.len() as u32),
198            fields: fields?,
199        })
200    };
201
202    Ok(PivotTableDefinition {
203        xmlns: ns.to_string(),
204        name: config.name.clone(),
205        cache_id,
206        data_on_rows: Some(false),
207        apply_number_formats: Some(false),
208        apply_border_formats: Some(false),
209        apply_font_formats: Some(false),
210        apply_pattern_formats: Some(false),
211        apply_alignment_formats: Some(false),
212        apply_width_height_formats: Some(true),
213        location: PivotLocation {
214            reference: config.target_cell.clone(),
215            first_header_row: 1,
216            first_data_row: 1,
217            first_data_col: 1,
218        },
219        pivot_fields: PivotFields {
220            count: Some(field_names.len() as u32),
221            fields: pivot_field_defs,
222        },
223        row_fields,
224        col_fields,
225        data_fields,
226    })
227}
228
229/// Builds the pivot cache definition XML.
230pub fn build_pivot_cache_definition(
231    source_sheet: &str,
232    source_range: &str,
233    field_names: &[String],
234) -> sheetkit_xml::pivot_cache::PivotCacheDefinition {
235    use sheetkit_xml::pivot_cache::*;
236
237    let cache_fields = CacheFields {
238        count: Some(field_names.len() as u32),
239        fields: field_names
240            .iter()
241            .map(|name| CacheField {
242                name: name.clone(),
243                num_fmt_id: Some(0),
244                shared_items: Some(SharedItems {
245                    contains_semi_mixed_types: None,
246                    contains_string: None,
247                    contains_number: None,
248                    contains_blank: None,
249                    count: Some(0),
250                    string_items: vec![],
251                    number_items: vec![],
252                }),
253            })
254            .collect(),
255    };
256
257    PivotCacheDefinition {
258        xmlns: sheetkit_xml::namespaces::SPREADSHEET_ML.to_string(),
259        xmlns_r: sheetkit_xml::namespaces::RELATIONSHIPS.to_string(),
260        r_id: None,
261        record_count: Some(0),
262        cache_source: CacheSource {
263            source_type: "worksheet".to_string(),
264            worksheet_source: Some(WorksheetSource {
265                reference: source_range.to_string(),
266                sheet: source_sheet.to_string(),
267            }),
268        },
269        cache_fields,
270    }
271}
272
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// Upper-casing follows Unicode rules, so the first character may expand into
/// several (e.g. `'ß'` becomes `"SS"`). An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    s.chars().next().map_or_else(String::new, |head| {
        let tail = &s[head.len_utf8()..];
        let mut out = String::with_capacity(s.len());
        out.extend(head.to_uppercase());
        out.push_str(tail);
        out
    })
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Every aggregate function paired with its expected XML spelling.
    fn function_table() -> Vec<(AggregateFunction, &'static str)> {
        vec![
            (AggregateFunction::Sum, "sum"),
            (AggregateFunction::Count, "count"),
            (AggregateFunction::Average, "average"),
            (AggregateFunction::Max, "max"),
            (AggregateFunction::Min, "min"),
            (AggregateFunction::Product, "product"),
            (AggregateFunction::CountNums, "countNums"),
            (AggregateFunction::StdDev, "stdDev"),
            (AggregateFunction::StdDevP, "stdDevp"),
            (AggregateFunction::Var, "var"),
            (AggregateFunction::VarP, "varp"),
        ]
    }

    /// Shorthand for a row/column field.
    fn field(name: &str) -> PivotField {
        PivotField {
            name: name.to_string(),
        }
    }

    /// Shorthand for a data field with an optional display name.
    fn data_field(
        name: &str,
        function: AggregateFunction,
        display: Option<&str>,
    ) -> PivotDataField {
        PivotDataField {
            name: name.to_string(),
            function,
            display_name: display.map(str::to_string),
        }
    }

    /// Turns header literals into the owned names the builders take.
    fn headers(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| (*n).to_string()).collect()
    }

    #[test]
    fn test_aggregate_function_to_xml_str() {
        for (function, xml) in function_table() {
            assert_eq!(function.to_xml_str(), xml);
        }
    }

    #[test]
    fn test_aggregate_function_from_xml_str() {
        for (function, xml) in function_table() {
            assert_eq!(AggregateFunction::from_xml_str(xml), Some(function));
        }
    }

    #[test]
    fn test_aggregate_function_from_xml_str_unknown() {
        // Lookup is exact and case-sensitive.
        for input in ["unknown", "", "SUM"] {
            assert_eq!(AggregateFunction::from_xml_str(input), None);
        }
    }

    #[test]
    fn test_aggregate_function_roundtrip() {
        for (function, _) in function_table() {
            let reparsed = AggregateFunction::from_xml_str(function.to_xml_str());
            assert_eq!(reparsed.as_ref(), Some(&function));
        }
    }

    #[test]
    fn test_capitalize_first() {
        let cases = [
            ("sum", "Sum"),
            ("count", "Count"),
            ("average", "Average"),
            ("", ""),
            ("a", "A"),
        ];
        for (input, expected) in cases {
            assert_eq!(capitalize_first(input), expected);
        }
    }

    #[test]
    fn test_build_pivot_table_xml_basic() {
        let config = PivotTableConfig {
            name: "PivotTable1".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:C5".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![field("Region")],
            columns: Vec::new(),
            data: vec![data_field("Sales", AggregateFunction::Sum, None)],
        };

        let def = build_pivot_table_xml(&config, 0, &headers(&["Region", "Product", "Sales"]))
            .expect("all referenced fields exist");
        assert_eq!(def.name, "PivotTable1");
        assert_eq!(def.cache_id, 0);
        assert_eq!(def.pivot_fields.count, Some(3));

        let fields = &def.pivot_fields.fields;
        assert_eq!(fields.len(), 3);
        // Region sits on the row axis and is not aggregated.
        assert_eq!(fields[0].axis.as_deref(), Some("axisRow"));
        assert_eq!(fields[0].data_field, None);
        // Product is not used at all.
        assert_eq!(fields[1].axis, None);
        // Sales is aggregated but has no axis.
        assert_eq!(fields[2].axis, None);
        assert_eq!(fields[2].data_field, Some(true));

        let rows = def.row_fields.expect("row fields present");
        assert_eq!(rows.count, Some(1));
        assert_eq!(rows.fields[0].index, 0);

        assert!(def.col_fields.is_none());

        let data = def.data_fields.expect("data fields present");
        assert_eq!(data.count, Some(1));
        let sales = &data.fields[0];
        assert_eq!(sales.field_index, 2);
        assert_eq!(sales.subtotal.as_deref(), Some("sum"));
        assert_eq!(sales.name.as_deref(), Some("Sum of Sales"));
    }

    #[test]
    fn test_build_pivot_table_xml_with_columns() {
        let config = PivotTableConfig {
            name: "SalesReport".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:D10".to_string(),
            target_sheet: "Report".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![field("Region")],
            columns: vec![field("Quarter")],
            data: vec![data_field(
                "Revenue",
                AggregateFunction::Average,
                Some("Avg Revenue"),
            )],
        };

        let def = build_pivot_table_xml(&config, 1, &headers(&["Region", "Quarter", "Revenue"]))
            .expect("all referenced fields exist");
        assert_eq!(def.cache_id, 1);

        let fields = &def.pivot_fields.fields;
        assert_eq!(fields[0].axis.as_deref(), Some("axisRow"));
        assert_eq!(fields[1].axis.as_deref(), Some("axisCol"));

        let columns = def.col_fields.expect("column fields present");
        assert_eq!(columns.count, Some(1));
        assert_eq!(columns.fields[0].index, 1);

        let data = def.data_fields.expect("data fields present");
        assert_eq!(data.fields[0].name.as_deref(), Some("Avg Revenue"));
        assert_eq!(data.fields[0].subtotal.as_deref(), Some("average"));
    }

    #[test]
    fn test_build_pivot_table_xml_unknown_field() {
        let config = PivotTableConfig {
            name: "Bad".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:B2".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: vec![field("NonExistent")],
            columns: Vec::new(),
            data: Vec::new(),
        };

        let message = match build_pivot_table_xml(&config, 0, &headers(&["Actual"])) {
            Ok(_) => panic!("an unknown field name must be rejected"),
            Err(e) => e.to_string(),
        };
        assert!(message.contains("NonExistent"));
    }

    #[test]
    fn test_build_pivot_table_xml_no_rows_or_cols() {
        let config = PivotTableConfig {
            name: "DataOnly".to_string(),
            source_sheet: "Sheet1".to_string(),
            source_range: "A1:B5".to_string(),
            target_sheet: "Pivot".to_string(),
            target_cell: "A1".to_string(),
            rows: Vec::new(),
            columns: Vec::new(),
            data: vec![data_field("Amount", AggregateFunction::Count, None)],
        };

        let def = build_pivot_table_xml(&config, 0, &headers(&["Amount"]))
            .expect("all referenced fields exist");
        assert!(def.row_fields.is_none());
        assert!(def.col_fields.is_none());
        assert!(def.data_fields.is_some());
    }

    #[test]
    fn test_build_pivot_cache_definition() {
        let def = build_pivot_cache_definition(
            "Sheet1",
            "A1:C10",
            &headers(&["Name", "Region", "Sales"]),
        );

        assert_eq!(def.xmlns, sheetkit_xml::namespaces::SPREADSHEET_ML);
        assert_eq!(def.cache_source.source_type, "worksheet");
        let source = def
            .cache_source
            .worksheet_source
            .expect("worksheet source present");
        assert_eq!(source.sheet, "Sheet1");
        assert_eq!(source.reference, "A1:C10");

        assert_eq!(def.cache_fields.count, Some(3));
        let names: Vec<&str> = def
            .cache_fields
            .fields
            .iter()
            .map(|f| f.name.as_str())
            .collect();
        assert_eq!(names, ["Name", "Region", "Sales"]);

        // Nothing has been cached yet, so every field's shared items are empty.
        for cache_field in &def.cache_fields.fields {
            let shared = cache_field
                .shared_items
                .as_ref()
                .expect("shared items present");
            assert_eq!(shared.count, Some(0));
        }

        assert_eq!(def.record_count, Some(0));
        assert!(def.r_id.is_none());
    }

    #[test]
    fn test_build_pivot_cache_definition_empty_fields() {
        let def = build_pivot_cache_definition("Sheet1", "A1:A1", &[]);
        assert_eq!(def.cache_fields.count, Some(0));
        assert!(def.cache_fields.fields.is_empty());
    }

    #[test]
    fn test_pivot_table_info_struct() {
        let info = PivotTableInfo {
            name: "PT1".to_string(),
            source_sheet: "Data".to_string(),
            source_range: "A1:D10".to_string(),
            target_sheet: "Report".to_string(),
            location: "A3:E20".to_string(),
        };

        let PivotTableInfo {
            name,
            source_sheet,
            source_range,
            target_sheet,
            location,
        } = info;
        assert_eq!(name, "PT1");
        assert_eq!(source_sheet, "Data");
        assert_eq!(source_range, "A1:D10");
        assert_eq!(target_sheet, "Report");
        assert_eq!(location, "A3:E20");
    }

    #[test]
    fn test_build_pivot_table_xml_generates_default_display_name() {
        let config = PivotTableConfig {
            name: "PT".to_string(),
            source_sheet: "S".to_string(),
            source_range: "A1:B2".to_string(),
            target_sheet: "T".to_string(),
            target_cell: "A1".to_string(),
            rows: Vec::new(),
            columns: Vec::new(),
            data: vec![
                data_field("Amount", AggregateFunction::Sum, None),
                data_field("Count", AggregateFunction::Count, Some("Total Count")),
            ],
        };

        let def = build_pivot_table_xml(&config, 0, &headers(&["Amount", "Count"]))
            .expect("all referenced fields exist");
        let data = def.data_fields.expect("data fields present");

        // Without a display name, one is generated from function and column.
        assert_eq!(data.fields[0].name.as_deref(), Some("Sum of Amount"));
        // An explicit display name is kept verbatim.
        assert_eq!(data.fields[1].name.as_deref(), Some("Total Count"));
    }

    #[test]
    fn test_error_pivot_table_not_found() {
        let message = Error::PivotTableNotFound {
            name: "Missing".to_string(),
        }
        .to_string();
        assert_eq!(message, "pivot table 'Missing' not found");
    }

    #[test]
    fn test_error_pivot_table_already_exists() {
        let message = Error::PivotTableAlreadyExists {
            name: "PT1".to_string(),
        }
        .to_string();
        assert_eq!(message, "pivot table 'PT1' already exists");
    }

    #[test]
    fn test_error_invalid_source_range() {
        let message = Error::InvalidSourceRange("bad range".to_string()).to_string();
        assert_eq!(message, "invalid source range: bad range");
    }
}