1use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
8
/// Everything a language model is told about one transaction it should categorize.
///
/// Rendered into prompt text by [`build_categorization_prompt`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CategorizationRequest {
    /// Payee of the transaction, if one was recorded.
    pub payee: Option<String>,
    /// Free-text narration; shown to the model as the transaction description.
    pub narration: String,
    /// Amount as text (e.g. `"-85.23"`); rendered verbatim, never parsed.
    pub amount: Option<String>,
    /// Currency/commodity of `amount`. When `amount` is set but this is `None`,
    /// the prompt builder assumes `USD`.
    pub currency: Option<String>,
    /// Transaction date as text (e.g. `"2024-01-15"`).
    pub date: String,
    /// Candidate account names listed in the prompt for the model to choose from.
    pub known_accounts: Vec<String>,
}
25
/// A categorization decision extracted from a model reply by
/// [`parse_categorization_response`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CategorizationResponse {
    /// Account name chosen by the model (e.g. `Expenses:Groceries`).
    pub account: String,
    /// Short justification given by the model; empty when none was supplied.
    pub reasoning: String,
}
34
35#[must_use]
40pub fn build_categorization_prompt(request: &CategorizationRequest) -> String {
41 let mut prompt = String::new();
42
43 prompt.push_str("Categorize this financial transaction into the most appropriate account.\n\n");
44 prompt.push_str("Transaction:\n");
45 prompt.push_str(&format!(" Date: {}\n", request.date));
46 if let Some(ref payee) = request.payee {
47 prompt.push_str(&format!(" Payee: {payee}\n"));
48 }
49 prompt.push_str(&format!(" Description: {}\n", request.narration));
50 if let Some(ref amount) = request.amount {
51 let currency = request.currency.as_deref().unwrap_or("USD");
52 prompt.push_str(&format!(" Amount: {amount} {currency}\n"));
53 }
54
55 prompt.push_str("\nAvailable accounts:\n");
56 for account in &request.known_accounts {
57 prompt.push_str(&format!(" - {account}\n"));
58 }
59
60 prompt.push_str("\nRespond with ONLY the account name on the first line, ");
61 prompt.push_str("followed by a brief reason on the second line.\n");
62 prompt.push_str("Example:\n");
63 prompt.push_str("Expenses:Groceries\n");
64 prompt.push_str("Whole Foods is a grocery store\n");
65
66 prompt
67}
68
69#[must_use]
74pub fn parse_categorization_response(response: &str) -> Option<CategorizationResponse> {
75 let mut lines = response.trim().lines();
76 let account = lines.next()?.trim().to_string();
77
78 if !account.contains(':') {
80 return None;
81 }
82
83 let reasoning = lines.next().unwrap_or("").trim().to_string();
84
85 Some(CategorizationResponse { account, reasoning })
86}
87
88#[must_use]
90pub fn extract_known_accounts(directives: &[DirectiveWrapper]) -> Vec<String> {
91 let mut accounts = std::collections::BTreeSet::new();
92
93 for d in directives {
94 match &d.data {
95 DirectiveData::Transaction(txn) => {
96 for posting in &txn.postings {
97 if posting.account.starts_with("Expenses:")
98 || posting.account.starts_with("Income:")
99 {
100 accounts.insert(posting.account.clone());
101 }
102 }
103 }
104 DirectiveData::Open(open)
105 if (open.account.starts_with("Expenses:")
106 || open.account.starts_with("Income:")) =>
107 {
108 accounts.insert(open.account.clone());
109 }
110 _ => {}
111 }
112 }
113
114 accounts.into_iter().collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::OpenData;

    /// Builds an `open` directive for `account`, dated 2024-01-01, with no
    /// currency constraints, booking method, or metadata.
    fn open_directive(account: &str) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: String::from("open"),
            date: String::from("2024-01-01"),
            filename: None,
            lineno: None,
            data: DirectiveData::Open(OpenData {
                account: String::from(account),
                currencies: vec![],
                booking: None,
                metadata: vec![],
            }),
        }
    }

    #[test]
    fn build_prompt_basic() {
        let req = CategorizationRequest {
            payee: Some(String::from("WHOLE FOODS MARKET")),
            narration: String::from("Groceries"),
            amount: Some(String::from("-85.23")),
            currency: Some(String::from("USD")),
            date: String::from("2024-01-15"),
            known_accounts: ["Expenses:Groceries", "Expenses:Dining", "Expenses:Transport"]
                .iter()
                .map(|&name| String::from(name))
                .collect(),
        };
        let text = build_categorization_prompt(&req);
        for needle in [
            "WHOLE FOODS MARKET",
            "-85.23 USD",
            "Expenses:Groceries",
            "Expenses:Dining",
        ] {
            assert!(text.contains(needle), "prompt is missing {needle:?}");
        }
    }

    #[test]
    fn parse_response_valid() {
        let result =
            parse_categorization_response("Expenses:Groceries\nWhole Foods is a grocery store")
                .expect("well-formed reply should parse");
        assert_eq!(result.account, "Expenses:Groceries");
        assert_eq!(result.reasoning, "Whole Foods is a grocery store");
    }

    #[test]
    fn parse_response_no_reasoning() {
        let result = parse_categorization_response("Expenses:Dining\n")
            .expect("account-only reply should parse");
        assert_eq!(result.account, "Expenses:Dining");
        assert_eq!(result.reasoning, "");
    }

    #[test]
    fn parse_response_invalid() {
        let reply = "This is not an account";
        assert!(parse_categorization_response(reply).is_none());
    }

    #[test]
    fn extract_accounts() {
        let found = extract_known_accounts(&[
            open_directive("Expenses:Groceries"),
            open_directive("Assets:Bank"),
        ]);
        assert_eq!(found, vec!["Expenses:Groceries"]);
    }
}