1use crate::error::{CoinJoinError, Result};
7use crate::types::{InputRef, OutputDef};
8use serde::{Deserialize, Serialize};
9
/// A PayJoin request assembled on top of the sender's original PSBT.
///
/// Carries the inputs the receiver contributes and the outputs it wants in the
/// joined transaction. `PayJoinReceiver::create_request` builds one of these,
/// and `PayJoinSender::process_request` consumes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PayJoinRequest {
    /// The sender's original PSBT as a string.
    /// NOTE(review): presumably base64-encoded (the tests use a base64-looking
    /// literal) — confirm against producers.
    pub original_psbt: String,
    /// Inputs contributed by the receiver.
    pub receiver_inputs: Vec<InputRef>,
    /// Outputs requested for the joined transaction;
    /// `PayJoinReceiver::create_request` pushes the receiver's output here.
    pub outputs: Vec<OutputDef>,
    /// Additional fee amount; `PayJoinSender` rejects requests where this
    /// exceeds its configured maximum.
    pub fee_contribution: u64,
    /// Optional minimum fee rate.
    /// NOTE(review): units are not established in this file — TODO confirm.
    pub min_fee_rate: Option<f64>,
}
24
25impl PayJoinRequest {
26 pub fn new(original_psbt: String) -> Self {
28 Self {
29 original_psbt,
30 receiver_inputs: Vec::new(),
31 outputs: Vec::new(),
32 fee_contribution: 0,
33 min_fee_rate: None,
34 }
35 }
36
37 pub fn add_input(&mut self, input: InputRef) {
39 self.receiver_inputs.push(input);
40 }
41
42 pub fn set_fee_contribution(&mut self, amount: u64) {
44 self.fee_contribution = amount;
45 }
46
47 pub fn set_min_fee_rate(&mut self, rate: f64) {
49 self.min_fee_rate = Some(rate);
50 }
51
52 pub fn receiver_input_total(&self) -> u64 {
54 self.receiver_inputs.iter().map(|i| i.amount).sum()
55 }
56}
57
58pub struct PayJoinReceiver {
60 output_script: Vec<u8>,
62 expected_amount: u64,
64 utxos: Vec<InputRef>,
66}
67
68impl PayJoinReceiver {
69 pub fn new(output_script: Vec<u8>, expected_amount: u64) -> Self {
71 Self {
72 output_script,
73 expected_amount,
74 utxos: Vec::new(),
75 }
76 }
77
78 pub fn add_utxo(&mut self, utxo: InputRef) {
80 self.utxos.push(utxo);
81 }
82
83 pub fn add_utxos(&mut self, utxos: impl IntoIterator<Item = InputRef>) {
85 self.utxos.extend(utxos);
86 }
87
88 pub fn create_request(&self, original_psbt: &str) -> Result<PayJoinRequest> {
92 if self.utxos.is_empty() {
93 return Err(CoinJoinError::PayJoinError(
94 "No UTXOs available for PayJoin".into(),
95 ));
96 }
97
98 let mut request = PayJoinRequest::new(original_psbt.to_string());
99
100 let selected = self.select_inputs()?;
103 for input in selected {
104 request.add_input(input);
105 }
106
107 let receiver_total = request.receiver_input_total();
109 let receiver_output = OutputDef::new(
110 self.expected_amount + receiver_total,
111 self.output_script.clone(),
112 );
113 request.outputs.push(receiver_output);
114
115 Ok(request)
116 }
117
118 fn select_inputs(&self) -> Result<Vec<InputRef>> {
120 if let Some(utxo) = self.utxos.first() {
122 Ok(vec![utxo.clone()])
123 } else {
124 Err(CoinJoinError::PayJoinError("No UTXOs to select".into()))
125 }
126 }
127
128 pub fn verify_proposal(&self, request: &PayJoinRequest) -> Result<()> {
130 for input in &request.receiver_inputs {
132 let found = self.utxos.iter().any(|u| u.txid == input.txid && u.vout == input.vout);
133 if !found {
134 return Err(CoinJoinError::VerificationFailed(
135 "Unknown input in proposal".into(),
136 ));
137 }
138 }
139
140 Ok(())
141 }
142}
143
144pub struct PayJoinSender {
146 utxos: Vec<InputRef>,
148 max_additional_fee: u64,
150}
151
152impl PayJoinSender {
153 pub fn new() -> Self {
155 Self {
156 utxos: Vec::new(),
157 max_additional_fee: 10_000, }
159 }
160
161 pub fn add_utxo(&mut self, utxo: InputRef) {
163 self.utxos.push(utxo);
164 }
165
166 pub fn set_max_additional_fee(&mut self, amount: u64) {
168 self.max_additional_fee = amount;
169 }
170
171 pub fn process_request(&self, request: &PayJoinRequest) -> Result<PayJoinProposal> {
175 self.validate_request(request)?;
177
178 let mut all_inputs = Vec::new();
180
181 all_inputs.extend(self.utxos.clone());
183
184 all_inputs.extend(request.receiver_inputs.clone());
186
187 let proposal = PayJoinProposal {
188 inputs: all_inputs,
189 outputs: request.outputs.clone(),
190 original_psbt: request.original_psbt.clone(),
191 fee_contribution: request.fee_contribution,
192 };
193
194 Ok(proposal)
195 }
196
197 fn validate_request(&self, request: &PayJoinRequest) -> Result<()> {
199 if request.receiver_inputs.is_empty() {
201 return Err(CoinJoinError::PayJoinError(
202 "Receiver must contribute at least one input".into(),
203 ));
204 }
205
206 if request.fee_contribution > self.max_additional_fee {
208 return Err(CoinJoinError::PayJoinError(format!(
209 "Fee contribution {} exceeds maximum {}",
210 request.fee_contribution, self.max_additional_fee
211 )));
212 }
213
214 Ok(())
215 }
216
217 pub fn verify_final(&self, proposal: &PayJoinProposal) -> Result<()> {
219 for utxo in &self.utxos {
221 let found = proposal
222 .inputs
223 .iter()
224 .any(|i| i.txid == utxo.txid && i.vout == utxo.vout);
225 if !found {
226 return Err(CoinJoinError::VerificationFailed(
227 "Sender input missing from proposal".into(),
228 ));
229 }
230 }
231
232 Ok(())
233 }
234}
235
236impl Default for PayJoinSender {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
/// A merged PayJoin transaction description, as produced by
/// `PayJoinSender::process_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PayJoinProposal {
    /// All inputs. When built by `PayJoinSender::process_request`, these are the
    /// sender's UTXOs followed by the receiver's inputs.
    pub inputs: Vec<InputRef>,
    /// Transaction outputs. `process_request` copies them from the request.
    pub outputs: Vec<OutputDef>,
    /// The sender's original PSBT, carried through from the request.
    pub original_psbt: String,
    /// Fee contribution. `process_request` copies it from the request.
    pub fee_contribution: u64,
}
254
255impl PayJoinProposal {
256 pub fn total_input(&self) -> u64 {
258 self.inputs.iter().map(|i| i.amount).sum()
259 }
260
261 pub fn total_output(&self) -> u64 {
263 self.outputs.iter().map(|o| o.amount).sum()
264 }
265
266 pub fn fee(&self) -> u64 {
268 self.total_input().saturating_sub(self.total_output())
269 }
270
271 pub fn input_count(&self) -> usize {
273 self.inputs.len()
274 }
275
276 pub fn output_count(&self) -> usize {
278 self.outputs.len()
279 }
280}
281
#[cfg(test)]
mod tests {
    // Unit tests covering request building, receiver/sender flows and
    // proposal arithmetic.
    use super::*;

    /// Placeholder PSBT string shared by every test.
    const PSBT: &str = "cHNidP8...";
    /// Placeholder output script shared by every test.
    const SCRIPT: [u8; 2] = [0x00, 0x14];

    /// Builds an input at output index 0 of a txid filled with `byte`.
    fn outpoint(byte: u8, amount: u64) -> InputRef {
        InputRef::from_outpoint([byte; 32], 0, amount)
    }

    #[test]
    fn test_payjoin_request() {
        let mut req = PayJoinRequest::new(PSBT.to_string());
        req.add_input(outpoint(1, 50_000));
        req.set_fee_contribution(1000);

        assert_eq!(req.receiver_input_total(), 50_000);
        assert_eq!(req.fee_contribution, 1000);
    }

    #[test]
    fn test_payjoin_receiver() {
        let mut receiver = PayJoinReceiver::new(SCRIPT.to_vec(), 100_000);
        receiver.add_utxo(outpoint(1, 50_000));

        let req = receiver
            .create_request(PSBT)
            .expect("receiver has a UTXO to contribute");
        assert_eq!(req.receiver_inputs.len(), 1);
    }

    #[test]
    fn test_payjoin_sender() {
        let mut sender = PayJoinSender::new();
        sender.add_utxo(outpoint(2, 100_000));

        let mut req = PayJoinRequest::new(PSBT.to_string());
        req.add_input(outpoint(1, 50_000));
        req.outputs.push(OutputDef::new(140_000, SCRIPT.to_vec()));

        let proposal = sender
            .process_request(&req)
            .expect("request is valid for this sender");
        assert_eq!(proposal.input_count(), 2);
    }

    #[test]
    fn test_payjoin_proposal() {
        let proposal = PayJoinProposal {
            inputs: vec![outpoint(1, 100_000), outpoint(2, 50_000)],
            outputs: vec![OutputDef::new(140_000, SCRIPT.to_vec())],
            original_psbt: PSBT.to_string(),
            fee_contribution: 1000,
        };

        assert_eq!(proposal.total_input(), 150_000);
        assert_eq!(proposal.total_output(), 140_000);
        assert_eq!(proposal.fee(), 10_000);
    }
}