rustywallet_coinjoin/
payjoin.rs

1//! PayJoin (BIP78) implementation.
2//!
3//! PayJoin is a privacy technique where the receiver contributes inputs
4//! to a transaction, breaking the common-input-ownership heuristic.
5
6use crate::error::{CoinJoinError, Result};
7use crate::types::{InputRef, OutputDef};
8use serde::{Deserialize, Serialize};
9
10/// PayJoin request from receiver to sender.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PayJoinRequest {
13    /// Original PSBT (base64 encoded)
14    pub original_psbt: String,
15    /// Receiver's additional inputs
16    pub receiver_inputs: Vec<InputRef>,
17    /// Modified outputs
18    pub outputs: Vec<OutputDef>,
19    /// Fee contribution from receiver
20    pub fee_contribution: u64,
21    /// Minimum fee rate (sat/vB)
22    pub min_fee_rate: Option<f64>,
23}
24
25impl PayJoinRequest {
26    /// Create a new PayJoin request.
27    pub fn new(original_psbt: String) -> Self {
28        Self {
29            original_psbt,
30            receiver_inputs: Vec::new(),
31            outputs: Vec::new(),
32            fee_contribution: 0,
33            min_fee_rate: None,
34        }
35    }
36
37    /// Add receiver input.
38    pub fn add_input(&mut self, input: InputRef) {
39        self.receiver_inputs.push(input);
40    }
41
42    /// Set fee contribution.
43    pub fn set_fee_contribution(&mut self, amount: u64) {
44        self.fee_contribution = amount;
45    }
46
47    /// Set minimum fee rate.
48    pub fn set_min_fee_rate(&mut self, rate: f64) {
49        self.min_fee_rate = Some(rate);
50    }
51
52    /// Total receiver input amount.
53    pub fn receiver_input_total(&self) -> u64 {
54        self.receiver_inputs.iter().map(|i| i.amount).sum()
55    }
56}
57
58/// PayJoin receiver - creates PayJoin proposals.
59pub struct PayJoinReceiver {
60    /// Receiver's output script
61    output_script: Vec<u8>,
62    /// Expected payment amount
63    expected_amount: u64,
64    /// Available UTXOs
65    utxos: Vec<InputRef>,
66}
67
68impl PayJoinReceiver {
69    /// Create a new PayJoin receiver.
70    pub fn new(output_script: Vec<u8>, expected_amount: u64) -> Self {
71        Self {
72            output_script,
73            expected_amount,
74            utxos: Vec::new(),
75        }
76    }
77
78    /// Add available UTXO.
79    pub fn add_utxo(&mut self, utxo: InputRef) {
80        self.utxos.push(utxo);
81    }
82
83    /// Add multiple UTXOs.
84    pub fn add_utxos(&mut self, utxos: impl IntoIterator<Item = InputRef>) {
85        self.utxos.extend(utxos);
86    }
87
88    /// Create PayJoin request from original PSBT.
89    ///
90    /// The receiver selects inputs to contribute and modifies the transaction.
91    pub fn create_request(&self, original_psbt: &str) -> Result<PayJoinRequest> {
92        if self.utxos.is_empty() {
93            return Err(CoinJoinError::PayJoinError(
94                "No UTXOs available for PayJoin".into(),
95            ));
96        }
97
98        let mut request = PayJoinRequest::new(original_psbt.to_string());
99
100        // Select UTXOs to contribute (simple strategy: use first available)
101        // In production, use more sophisticated selection
102        let selected = self.select_inputs()?;
103        for input in selected {
104            request.add_input(input);
105        }
106
107        // Receiver's output (payment + contributed inputs)
108        let receiver_total = request.receiver_input_total();
109        let receiver_output = OutputDef::new(
110            self.expected_amount + receiver_total,
111            self.output_script.clone(),
112        );
113        request.outputs.push(receiver_output);
114
115        Ok(request)
116    }
117
118    /// Select inputs to contribute.
119    fn select_inputs(&self) -> Result<Vec<InputRef>> {
120        // Simple selection: contribute one input if available
121        if let Some(utxo) = self.utxos.first() {
122            Ok(vec![utxo.clone()])
123        } else {
124            Err(CoinJoinError::PayJoinError("No UTXOs to select".into()))
125        }
126    }
127
128    /// Verify a PayJoin proposal is valid.
129    pub fn verify_proposal(&self, request: &PayJoinRequest) -> Result<()> {
130        // Check receiver inputs are from our UTXOs
131        for input in &request.receiver_inputs {
132            let found = self.utxos.iter().any(|u| u.txid == input.txid && u.vout == input.vout);
133            if !found {
134                return Err(CoinJoinError::VerificationFailed(
135                    "Unknown input in proposal".into(),
136                ));
137            }
138        }
139
140        Ok(())
141    }
142}
143
144/// PayJoin sender - processes PayJoin requests.
145pub struct PayJoinSender {
146    /// Sender's UTXOs
147    utxos: Vec<InputRef>,
148    /// Maximum additional fee willing to pay
149    max_additional_fee: u64,
150}
151
152impl PayJoinSender {
153    /// Create a new PayJoin sender.
154    pub fn new() -> Self {
155        Self {
156            utxos: Vec::new(),
157            max_additional_fee: 10_000, // Default 10k sats
158        }
159    }
160
161    /// Add sender UTXO.
162    pub fn add_utxo(&mut self, utxo: InputRef) {
163        self.utxos.push(utxo);
164    }
165
166    /// Set maximum additional fee.
167    pub fn set_max_additional_fee(&mut self, amount: u64) {
168        self.max_additional_fee = amount;
169    }
170
171    /// Process a PayJoin request.
172    ///
173    /// Validates the request and creates the final PayJoin PSBT.
174    pub fn process_request(&self, request: &PayJoinRequest) -> Result<PayJoinProposal> {
175        // Validate request
176        self.validate_request(request)?;
177
178        // Create proposal with combined inputs
179        let mut all_inputs = Vec::new();
180
181        // Add sender inputs (from original PSBT - simplified)
182        all_inputs.extend(self.utxos.clone());
183
184        // Add receiver inputs
185        all_inputs.extend(request.receiver_inputs.clone());
186
187        let proposal = PayJoinProposal {
188            inputs: all_inputs,
189            outputs: request.outputs.clone(),
190            original_psbt: request.original_psbt.clone(),
191            fee_contribution: request.fee_contribution,
192        };
193
194        Ok(proposal)
195    }
196
197    /// Validate a PayJoin request.
198    fn validate_request(&self, request: &PayJoinRequest) -> Result<()> {
199        // Check receiver added inputs
200        if request.receiver_inputs.is_empty() {
201            return Err(CoinJoinError::PayJoinError(
202                "Receiver must contribute at least one input".into(),
203            ));
204        }
205
206        // Check fee contribution is reasonable
207        if request.fee_contribution > self.max_additional_fee {
208            return Err(CoinJoinError::PayJoinError(format!(
209                "Fee contribution {} exceeds maximum {}",
210                request.fee_contribution, self.max_additional_fee
211            )));
212        }
213
214        Ok(())
215    }
216
217    /// Verify the final PayJoin transaction.
218    pub fn verify_final(&self, proposal: &PayJoinProposal) -> Result<()> {
219        // Verify our inputs are included
220        for utxo in &self.utxos {
221            let found = proposal
222                .inputs
223                .iter()
224                .any(|i| i.txid == utxo.txid && i.vout == utxo.vout);
225            if !found {
226                return Err(CoinJoinError::VerificationFailed(
227                    "Sender input missing from proposal".into(),
228                ));
229            }
230        }
231
232        Ok(())
233    }
234}
235
236impl Default for PayJoinSender {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242/// PayJoin proposal ready for signing.
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct PayJoinProposal {
245    /// All inputs (sender + receiver)
246    pub inputs: Vec<InputRef>,
247    /// All outputs
248    pub outputs: Vec<OutputDef>,
249    /// Original PSBT for reference
250    pub original_psbt: String,
251    /// Fee contribution from receiver
252    pub fee_contribution: u64,
253}
254
255impl PayJoinProposal {
256    /// Total input amount.
257    pub fn total_input(&self) -> u64 {
258        self.inputs.iter().map(|i| i.amount).sum()
259    }
260
261    /// Total output amount.
262    pub fn total_output(&self) -> u64 {
263        self.outputs.iter().map(|o| o.amount).sum()
264    }
265
266    /// Implied fee.
267    pub fn fee(&self) -> u64 {
268        self.total_input().saturating_sub(self.total_output())
269    }
270
271    /// Number of inputs.
272    pub fn input_count(&self) -> usize {
273        self.inputs.len()
274    }
275
276    /// Number of outputs.
277    pub fn output_count(&self) -> usize {
278        self.outputs.len()
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_payjoin_request() {
        let mut request = PayJoinRequest::new("cHNidP8...".into());
        request.add_input(InputRef::from_outpoint([1u8; 32], 0, 50_000));
        request.set_fee_contribution(1000);

        assert_eq!(request.receiver_input_total(), 50_000);
        assert_eq!(request.fee_contribution, 1000);
    }

    #[test]
    fn test_payjoin_receiver() {
        let mut receiver = PayJoinReceiver::new(vec![0x00, 0x14], 100_000);
        receiver.add_utxo(InputRef::from_outpoint([1u8; 32], 0, 50_000));

        let request = receiver.create_request("cHNidP8...").unwrap();
        assert_eq!(request.receiver_inputs.len(), 1);
    }

    #[test]
    fn test_receiver_output_includes_contribution() {
        let mut receiver = PayJoinReceiver::new(vec![0x00, 0x14], 100_000);
        receiver.add_utxo(InputRef::from_outpoint([1u8; 32], 0, 50_000));

        let request = receiver.create_request("cHNidP8...").unwrap();
        assert_eq!(request.outputs.len(), 1);
        // Payment plus the value of the contributed input.
        assert_eq!(request.outputs[0].amount, 150_000);
    }

    #[test]
    fn test_receiver_without_utxos_fails() {
        let receiver = PayJoinReceiver::new(vec![0x00, 0x14], 100_000);
        let err = receiver.create_request("cHNidP8...").unwrap_err();
        assert!(matches!(err, CoinJoinError::PayJoinError(_)));
    }

    #[test]
    fn test_receiver_rejects_unknown_input() {
        let mut receiver = PayJoinReceiver::new(vec![0x00, 0x14], 100_000);
        receiver.add_utxo(InputRef::from_outpoint([1u8; 32], 0, 50_000));

        let request = receiver.create_request("cHNidP8...").unwrap();
        assert!(receiver.verify_proposal(&request).is_ok());

        let mut forged = request.clone();
        forged.add_input(InputRef::from_outpoint([9u8; 32], 0, 1));
        assert!(matches!(
            receiver.verify_proposal(&forged),
            Err(CoinJoinError::VerificationFailed(_))
        ));
    }

    #[test]
    fn test_payjoin_sender() {
        let mut sender = PayJoinSender::new();
        sender.add_utxo(InputRef::from_outpoint([2u8; 32], 0, 100_000));

        let mut request = PayJoinRequest::new("cHNidP8...".into());
        request.add_input(InputRef::from_outpoint([1u8; 32], 0, 50_000));
        request.outputs.push(OutputDef::new(140_000, vec![0x00, 0x14]));

        let proposal = sender.process_request(&request).unwrap();
        assert_eq!(proposal.input_count(), 2);
    }

    #[test]
    fn test_sender_rejects_request_without_receiver_inputs() {
        let mut sender = PayJoinSender::new();
        sender.add_utxo(InputRef::from_outpoint([2u8; 32], 0, 100_000));

        let request = PayJoinRequest::new("cHNidP8...".into());
        assert!(matches!(
            sender.process_request(&request),
            Err(CoinJoinError::PayJoinError(_))
        ));
    }

    #[test]
    fn test_sender_rejects_excessive_fee() {
        let mut sender = PayJoinSender::new();
        sender.set_max_additional_fee(500);

        let mut request = PayJoinRequest::new("cHNidP8...".into());
        request.add_input(InputRef::from_outpoint([1u8; 32], 0, 50_000));
        request.set_fee_contribution(501);

        assert!(matches!(
            sender.process_request(&request),
            Err(CoinJoinError::PayJoinError(_))
        ));
    }

    #[test]
    fn test_sender_verify_final() {
        let mut sender = PayJoinSender::new();
        sender.add_utxo(InputRef::from_outpoint([2u8; 32], 0, 100_000));

        let mut request = PayJoinRequest::new("cHNidP8...".into());
        request.add_input(InputRef::from_outpoint([1u8; 32], 0, 50_000));

        let mut proposal = sender.process_request(&request).unwrap();
        assert!(sender.verify_final(&proposal).is_ok());

        // Sender inputs are placed first; dropping it must fail verification.
        proposal.inputs.remove(0);
        assert!(matches!(
            sender.verify_final(&proposal),
            Err(CoinJoinError::VerificationFailed(_))
        ));
    }

    #[test]
    fn test_payjoin_proposal() {
        let proposal = PayJoinProposal {
            inputs: vec![
                InputRef::from_outpoint([1u8; 32], 0, 100_000),
                InputRef::from_outpoint([2u8; 32], 0, 50_000),
            ],
            outputs: vec![OutputDef::new(140_000, vec![0x00, 0x14])],
            original_psbt: "cHNidP8...".into(),
            fee_contribution: 1000,
        };

        assert_eq!(proposal.total_input(), 150_000);
        assert_eq!(proposal.total_output(), 140_000);
        assert_eq!(proposal.fee(), 10_000);
    }
}