1use crate::builder::{CoinJoinBuilder, CoinJoinTransaction};
6use crate::error::{CoinJoinError, Result};
7use crate::types::Participant;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10
/// Lifecycle of a `CoinJoinSession`.
///
/// A session starts in `Gathering` and moves forward through the variants in
/// declaration order; `Failed` can be entered at any time via cancellation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Accepting participants; the minimum participant count is not met yet.
    Gathering,
    /// Enough participants have joined for the transaction to be built.
    Ready,
    /// The transaction has been built and signatures are being collected.
    Signing,
    /// Every participant has submitted a signature.
    Complete,
    /// The session was cancelled.
    Failed,
}
25
26pub struct CoinJoinSession {
28 id: [u8; 32],
30 state: SessionState,
32 participants: Vec<Participant>,
34 output_amount: u64,
36 min_participants: usize,
38 max_participants: usize,
40 transaction: Option<CoinJoinTransaction>,
42 signatures: Vec<(String, Vec<u8>)>,
44}
45
46impl CoinJoinSession {
47 pub fn new(output_amount: u64) -> Self {
49 let id = Self::generate_session_id();
50 Self {
51 id,
52 state: SessionState::Gathering,
53 participants: Vec::new(),
54 output_amount,
55 min_participants: 2,
56 max_participants: 10,
57 transaction: None,
58 signatures: Vec::new(),
59 }
60 }
61
62 fn generate_session_id() -> [u8; 32] {
64 use sha2::Sha256;
65 let mut hasher = Sha256::new();
66 hasher.update(std::time::SystemTime::now()
67 .duration_since(std::time::UNIX_EPOCH)
68 .unwrap_or_default()
69 .as_nanos()
70 .to_le_bytes());
71 let result = hasher.finalize();
72 let mut id = [0u8; 32];
73 id.copy_from_slice(&result);
74 id
75 }
76
77 pub fn id(&self) -> &[u8; 32] {
79 &self.id
80 }
81
82 pub fn state(&self) -> SessionState {
84 self.state
85 }
86
87 pub fn output_amount(&self) -> u64 {
89 self.output_amount
90 }
91
92 pub fn participant_count(&self) -> usize {
94 self.participants.len()
95 }
96
97 pub fn set_min_participants(&mut self, min: usize) {
99 self.min_participants = min;
100 }
101
102 pub fn set_max_participants(&mut self, max: usize) {
104 self.max_participants = max;
105 }
106
107 pub fn join(&mut self, participant: Participant) -> Result<JoinResponse> {
109 if self.state != SessionState::Gathering {
110 return Err(CoinJoinError::PayJoinError(
111 "Session not accepting participants".into(),
112 ));
113 }
114
115 if self.participants.len() >= self.max_participants {
116 return Err(CoinJoinError::PayJoinError("Session is full".into()));
117 }
118
119 let total_input = participant.total_input();
121 if total_input < self.output_amount {
122 return Err(CoinJoinError::InsufficientFunds {
123 needed: self.output_amount,
124 available: total_input,
125 });
126 }
127
128 if self.participants.iter().any(|p| p.id == participant.id) {
130 return Err(CoinJoinError::InvalidParticipant(
131 "Already joined".into(),
132 ));
133 }
134
135 self.participants.push(participant.clone());
136
137 if self.participants.len() >= self.min_participants {
139 self.state = SessionState::Ready;
140 }
141
142 Ok(JoinResponse {
143 session_id: self.id,
144 participant_id: participant.id,
145 position: self.participants.len() - 1,
146 current_count: self.participants.len(),
147 ready: self.state == SessionState::Ready,
148 })
149 }
150
151 pub fn build_transaction(&mut self) -> Result<&CoinJoinTransaction> {
153 if self.state != SessionState::Ready {
154 return Err(CoinJoinError::PayJoinError(
155 "Session not ready to build".into(),
156 ));
157 }
158
159 let mut builder = CoinJoinBuilder::new();
160 builder.set_output_amount(self.output_amount);
161 builder.set_min_participants(self.min_participants);
162
163 for participant in &self.participants {
164 builder.add_participant(participant.clone());
165 }
166
167 let tx = builder.build()?;
168 self.transaction = Some(tx);
169 self.state = SessionState::Signing;
170
171 Ok(self.transaction.as_ref().unwrap())
172 }
173
174 pub fn submit_signature(&mut self, participant_id: &str, signature: Vec<u8>) -> Result<()> {
176 if self.state != SessionState::Signing {
177 return Err(CoinJoinError::PayJoinError(
178 "Session not accepting signatures".into(),
179 ));
180 }
181
182 if !self.participants.iter().any(|p| p.id == participant_id) {
184 return Err(CoinJoinError::InvalidParticipant(
185 "Unknown participant".into(),
186 ));
187 }
188
189 if self.signatures.iter().any(|(id, _)| id == participant_id) {
191 return Err(CoinJoinError::PayJoinError(
192 "Signature already submitted".into(),
193 ));
194 }
195
196 self.signatures.push((participant_id.to_string(), signature));
197
198 if self.signatures.len() == self.participants.len() {
200 self.state = SessionState::Complete;
201 }
202
203 Ok(())
204 }
205
206 pub fn transaction(&self) -> Option<&CoinJoinTransaction> {
208 self.transaction.as_ref()
209 }
210
211 pub fn signatures(&self) -> &[(String, Vec<u8>)] {
213 &self.signatures
214 }
215
216 pub fn is_complete(&self) -> bool {
218 self.state == SessionState::Complete
219 }
220
221 pub fn cancel(&mut self) {
223 self.state = SessionState::Failed;
224 }
225}
226
/// Reply produced by `CoinJoinSession::join` for a successfully added participant.
// Field order is kept as-is: with serde derives it determines the serialized layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JoinResponse {
    /// Id of the session that was joined.
    pub session_id: [u8; 32],
    /// Id of the participant that joined.
    pub participant_id: String,
    /// Zero-based index of this participant in join order.
    pub position: usize,
    /// Number of participants in the session after this join.
    pub current_count: usize,
    /// Whether the session was in the `Ready` state right after this join.
    pub ready: bool,
}
241
/// Serializable snapshot of a session's parameters and progress
/// (built from a `CoinJoinSession` via `From`).
// Field order is kept as-is: with serde derives it determines the serialized layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionAnnouncement {
    /// Id of the announced session.
    pub session_id: [u8; 32],
    /// Amount each participant's inputs must cover.
    pub output_amount: u64,
    /// Number of participants that had joined when the snapshot was taken.
    pub current_count: usize,
    /// Participants required for the session to become `Ready`.
    pub min_participants: usize,
    /// Maximum number of participants the session accepts.
    pub max_participants: usize,
    /// `Debug` rendering of the session's `SessionState` (e.g. `"Gathering"`).
    pub state: String,
}
258
259impl From<&CoinJoinSession> for SessionAnnouncement {
260 fn from(session: &CoinJoinSession) -> Self {
261 Self {
262 session_id: session.id,
263 output_amount: session.output_amount,
264 current_count: session.participants.len(),
265 min_participants: session.min_participants,
266 max_participants: session.max_participants,
267 state: format!("{:?}", session.state),
268 }
269 }
270}
271
272pub fn verify_commitment(
274 participant: &Participant,
275 commitment: &[u8; 32],
276) -> bool {
277 let computed = compute_commitment(participant);
278 &computed == commitment
279}
280
281pub fn compute_commitment(participant: &Participant) -> [u8; 32] {
283 let mut hasher = Sha256::new();
284 hasher.update(participant.id.as_bytes());
285 for input in &participant.inputs {
286 hasher.update(input.txid);
287 hasher.update(input.vout.to_le_bytes());
288 hasher.update(input.amount.to_le_bytes());
289 }
290 hasher.update(&participant.output_script);
291 let result = hasher.finalize();
292 let mut commitment = [0u8; 32];
293 commitment.copy_from_slice(&result);
294 commitment
295}
296
#[cfg(test)]
mod tests {
    // Glob-imports the parent module: session types, errors, `Participant`,
    // and the commitment helpers.
    use super::*;
    // NOTE(review): `InputRef` is used below, but none of this file's visible
    // `use` lines import it; confirm it is in scope for `super::*`, otherwise
    // this module will not compile.

    /// Builds a participant with the given id, a single input constructed with
    /// `amount`, and a fixed two-byte output script.
    fn create_test_participant(id: &str, amount: u64) -> Participant {
        Participant::new(
            id,
            // presumably (txid, vout, amount) — TODO confirm against `InputRef::from_outpoint`
            vec![InputRef::from_outpoint([1u8; 32], 0, amount)],
            // looks like an OP_0 / push-20 script prefix — TODO confirm intent
            vec![0x00, 0x14],
        )
    }

    /// A fresh session starts in `Gathering` with the requested amount.
    #[test]
    fn test_session_creation() {
        let session = CoinJoinSession::new(50_000);
        assert_eq!(session.state(), SessionState::Gathering);
        assert_eq!(session.output_amount(), 50_000);
    }

    /// The first joiner gets position 0 and does not yet make the session ready.
    #[test]
    fn test_join_session() {
        let mut session = CoinJoinSession::new(50_000);

        let alice = create_test_participant("alice", 100_000);
        let response = session.join(alice).unwrap();

        assert_eq!(response.position, 0);
        assert!(!response.ready);
        assert_eq!(session.participant_count(), 1);
    }

    /// Reaching the default minimum of two participants moves the session to `Ready`.
    #[test]
    fn test_session_ready() {
        let mut session = CoinJoinSession::new(50_000);

        session.join(create_test_participant("alice", 100_000)).unwrap();
        let response = session.join(create_test_participant("bob", 100_000)).unwrap();

        assert!(response.ready);
        assert_eq!(session.state(), SessionState::Ready);
    }

    /// Building from a ready session yields a transaction and enters `Signing`.
    #[test]
    fn test_build_transaction() {
        let mut session = CoinJoinSession::new(50_000);

        session.join(create_test_participant("alice", 100_000)).unwrap();
        session.join(create_test_participant("bob", 100_000)).unwrap();

        let tx = session.build_transaction().unwrap();
        assert_eq!(tx.participant_count, 2);
        assert_eq!(session.state(), SessionState::Signing);
    }

    /// The session completes only after every participant has signed.
    #[test]
    fn test_submit_signatures() {
        let mut session = CoinJoinSession::new(50_000);

        session.join(create_test_participant("alice", 100_000)).unwrap();
        session.join(create_test_participant("bob", 100_000)).unwrap();
        session.build_transaction().unwrap();

        session.submit_signature("alice", vec![1, 2, 3]).unwrap();
        assert!(!session.is_complete());

        session.submit_signature("bob", vec![4, 5, 6]).unwrap();
        assert!(session.is_complete());
    }

    /// A participant whose inputs cannot cover the output amount is rejected.
    #[test]
    fn test_insufficient_funds() {
        let mut session = CoinJoinSession::new(50_000);

        let poor = create_test_participant("poor", 10_000);
        let result = session.join(poor);

        assert!(matches!(result, Err(CoinJoinError::InsufficientFunds { .. })));
    }

    /// A commitment verifies for the participant it was computed from and not
    /// for a participant with a different id.
    #[test]
    fn test_commitment() {
        let participant = create_test_participant("alice", 100_000);
        let commitment = compute_commitment(&participant);

        assert!(verify_commitment(&participant, &commitment));

        let other = create_test_participant("bob", 100_000);
        assert!(!verify_commitment(&other, &commitment));
    }
}