1use crate::ant_protocol::{
30 ChunkGetRequest, ChunkGetResponse, ChunkMessage, ChunkMessageBody, ChunkPutRequest,
31 ChunkPutResponse, ChunkQuoteRequest, ChunkQuoteResponse, ProtocolError, CHUNK_PROTOCOL_ID,
32 DATA_TYPE_CHUNK, MAX_CHUNK_SIZE,
33};
34use crate::error::Result;
35use crate::payment::{PaymentVerifier, QuoteGenerator};
36use crate::storage::disk::DiskStorage;
37use bytes::Bytes;
38use std::sync::Arc;
39use tracing::{debug, info, warn};
40
41pub struct AntProtocol {
46 storage: Arc<DiskStorage>,
48 payment_verifier: Arc<PaymentVerifier>,
50 quote_generator: Arc<QuoteGenerator>,
52}
53
54impl AntProtocol {
55 #[must_use]
63 pub fn new(
64 storage: Arc<DiskStorage>,
65 payment_verifier: Arc<PaymentVerifier>,
66 quote_generator: Arc<QuoteGenerator>,
67 ) -> Self {
68 Self {
69 storage,
70 payment_verifier,
71 quote_generator,
72 }
73 }
74
75 #[must_use]
77 pub fn protocol_id(&self) -> &'static str {
78 CHUNK_PROTOCOL_ID
79 }
80
81 pub async fn handle_message(&self, data: &[u8]) -> Result<Bytes> {
95 let message = ChunkMessage::decode(data)
96 .map_err(|e| crate::error::Error::Protocol(format!("Failed to decode message: {e}")))?;
97
98 let request_id = message.request_id;
99
100 let response_body = match message.body {
101 ChunkMessageBody::PutRequest(req) => {
102 ChunkMessageBody::PutResponse(self.handle_put(req).await)
103 }
104 ChunkMessageBody::GetRequest(req) => {
105 ChunkMessageBody::GetResponse(self.handle_get(req).await)
106 }
107 ChunkMessageBody::QuoteRequest(ref req) => {
108 ChunkMessageBody::QuoteResponse(self.handle_quote(req))
109 }
110 ChunkMessageBody::PutResponse(_)
112 | ChunkMessageBody::GetResponse(_)
113 | ChunkMessageBody::QuoteResponse(_) => {
114 let error = ProtocolError::Internal("Unexpected response message".to_string());
115 ChunkMessageBody::PutResponse(ChunkPutResponse::Error(error))
116 }
117 };
118
119 let response = ChunkMessage {
120 request_id,
121 body: response_body,
122 };
123
124 response
125 .encode()
126 .map(Bytes::from)
127 .map_err(|e| crate::error::Error::Protocol(format!("Failed to encode response: {e}")))
128 }
129
130 async fn handle_put(&self, request: ChunkPutRequest) -> ChunkPutResponse {
132 let address = request.address;
133 debug!("Handling PUT request for {}", hex::encode(address));
134
135 if request.content.len() > MAX_CHUNK_SIZE {
137 return ChunkPutResponse::Error(ProtocolError::ChunkTooLarge {
138 size: request.content.len(),
139 max_size: MAX_CHUNK_SIZE,
140 });
141 }
142
143 let computed = crate::client::compute_address(&request.content);
145 if computed != address {
146 return ChunkPutResponse::Error(ProtocolError::AddressMismatch {
147 expected: address,
148 actual: computed,
149 });
150 }
151
152 if self.storage.exists(&address) {
154 debug!("Chunk {} already exists", hex::encode(address));
155 return ChunkPutResponse::AlreadyExists { address };
156 }
157
158 let payment_result = self
160 .payment_verifier
161 .verify_payment(&address, request.payment_proof.as_deref())
162 .await;
163
164 match payment_result {
165 Ok(status) if status.can_store() => {
166 }
168 Ok(_) => {
169 return ChunkPutResponse::PaymentRequired {
170 message: "Payment required for new chunk".to_string(),
171 };
172 }
173 Err(e) => {
174 return ChunkPutResponse::Error(ProtocolError::PaymentFailed(e.to_string()));
175 }
176 }
177
178 match self.storage.put(&address, &request.content).await {
180 Ok(_) => {
181 info!(
182 "Stored chunk {} ({} bytes)",
183 hex::encode(address),
184 request.content.len()
185 );
186 self.quote_generator.record_store(DATA_TYPE_CHUNK);
188 ChunkPutResponse::Success { address }
189 }
190 Err(e) => {
191 warn!("Failed to store chunk {}: {}", hex::encode(address), e);
192 ChunkPutResponse::Error(ProtocolError::StorageFailed(e.to_string()))
193 }
194 }
195 }
196
197 async fn handle_get(&self, request: ChunkGetRequest) -> ChunkGetResponse {
199 let address = request.address;
200 debug!("Handling GET request for {}", hex::encode(address));
201
202 match self.storage.get(&address).await {
203 Ok(Some(content)) => {
204 debug!(
205 "Retrieved chunk {} ({} bytes)",
206 hex::encode(address),
207 content.len()
208 );
209 ChunkGetResponse::Success { address, content }
210 }
211 Ok(None) => {
212 debug!("Chunk {} not found", hex::encode(address));
213 ChunkGetResponse::NotFound { address }
214 }
215 Err(e) => {
216 warn!("Failed to retrieve chunk {}: {}", hex::encode(address), e);
217 ChunkGetResponse::Error(ProtocolError::StorageFailed(e.to_string()))
218 }
219 }
220 }
221
222 fn handle_quote(&self, request: &ChunkQuoteRequest) -> ChunkQuoteResponse {
224 debug!(
225 "Handling quote request for {} (size: {})",
226 hex::encode(request.address),
227 request.data_size
228 );
229
230 let data_size_usize = usize::try_from(request.data_size).unwrap_or(usize::MAX);
232 if data_size_usize > MAX_CHUNK_SIZE {
233 return ChunkQuoteResponse::Error(ProtocolError::ChunkTooLarge {
234 size: data_size_usize,
235 max_size: MAX_CHUNK_SIZE,
236 });
237 }
238
239 match self
240 .quote_generator
241 .create_quote(request.address, data_size_usize, request.data_type)
242 {
243 Ok(quote) => {
244 match rmp_serde::to_vec("e) {
246 Ok(quote_bytes) => ChunkQuoteResponse::Success { quote: quote_bytes },
247 Err(e) => ChunkQuoteResponse::Error(ProtocolError::QuoteFailed(format!(
248 "Failed to serialize quote: {e}"
249 ))),
250 }
251 }
252 Err(e) => ChunkQuoteResponse::Error(ProtocolError::QuoteFailed(e.to_string())),
253 }
254 }
255
256 #[must_use]
258 pub fn storage_stats(&self) -> crate::storage::StorageStats {
259 self.storage.stats()
260 }
261
262 #[must_use]
264 pub fn payment_cache_stats(&self) -> crate::payment::CacheStats {
265 self.payment_verifier.cache_stats()
266 }
267
268 #[must_use]
270 pub fn exists(&self, address: &[u8; 32]) -> bool {
271 self.storage.exists(address)
272 }
273
274 pub async fn get_local(&self, address: &[u8; 32]) -> Result<Option<Vec<u8>>> {
280 self.storage.get(address).await
281 }
282
283 pub async fn put_local(&self, address: &[u8; 32], content: &[u8]) -> Result<bool> {
291 self.storage.put(address, content).await
292 }
293}
294
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::payment::metrics::QuotingMetricsTracker;
    use crate::payment::{EvmVerifierConfig, PaymentVerifierConfig};
    use crate::storage::DiskStorageConfig;
    use ant_evm::RewardsAddress;
    use tempfile::TempDir;

    /// Builds a protocol handler over a fresh temporary directory, with EVM
    /// payment verification disabled. Keep the returned `TempDir` alive for as
    /// long as the protocol is used.
    async fn create_test_protocol() -> (AntProtocol, TempDir) {
        let temp_dir = TempDir::new().expect("create temp dir");

        let storage_config = DiskStorageConfig {
            root_dir: temp_dir.path().to_path_buf(),
            verify_on_read: true,
            max_chunks: 0,
        };
        let storage = Arc::new(
            DiskStorage::new(storage_config)
                .await
                .expect("create storage"),
        );

        let payment_config = PaymentVerifierConfig {
            evm: EvmVerifierConfig {
                enabled: false,
                ..Default::default()
            },
            cache_capacity: 100,
        };
        let payment_verifier = Arc::new(PaymentVerifier::new(payment_config));

        let rewards_address = RewardsAddress::new([1u8; 20]);
        let metrics_tracker = QuotingMetricsTracker::new(1000, 100);
        let quote_generator = Arc::new(QuoteGenerator::new(rewards_address, metrics_tracker));

        let protocol = AntProtocol::new(storage, payment_verifier, quote_generator);
        (protocol, temp_dir)
    }

    /// MessagePack-encoded empty `ProofOfPayment`, used as a placeholder proof.
    fn empty_payment_proof() -> Vec<u8> {
        rmp_serde::to_vec(&ant_evm::ProofOfPayment {
            peer_quotes: vec![],
        })
        .unwrap()
    }

    /// Feeds already-encoded request bytes through `handle_message` and
    /// decodes the reply.
    async fn send_bytes(protocol: &AntProtocol, request_bytes: &[u8]) -> ChunkMessage {
        let response_bytes = protocol
            .handle_message(request_bytes)
            .await
            .expect("handle message");
        ChunkMessage::decode(&response_bytes).expect("decode response")
    }

    /// Encodes `message`, sends it through `handle_message`, and decodes the reply.
    async fn send(protocol: &AntProtocol, message: ChunkMessage) -> ChunkMessage {
        let request_bytes = message.encode().expect("encode request");
        send_bytes(protocol, &request_bytes).await
    }

    #[tokio::test]
    async fn test_put_and_get_chunk() {
        let (protocol, _temp) = create_test_protocol().await;

        let content = b"hello world";
        let address = DiskStorage::compute_address(content);

        let put_request =
            ChunkPutRequest::with_payment(address, content.to_vec(), empty_payment_proof());
        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 1,
                body: ChunkMessageBody::PutRequest(put_request),
            },
        )
        .await;

        assert_eq!(response.request_id, 1);
        if let ChunkMessageBody::PutResponse(ChunkPutResponse::Success { address: addr }) =
            response.body
        {
            assert_eq!(addr, address);
        } else {
            panic!("expected PutResponse::Success, got: {response:?}");
        }

        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 2,
                body: ChunkMessageBody::GetRequest(ChunkGetRequest::new(address)),
            },
        )
        .await;

        assert_eq!(response.request_id, 2);
        if let ChunkMessageBody::GetResponse(ChunkGetResponse::Success {
            address: addr,
            content: data,
        }) = response.body
        {
            assert_eq!(addr, address);
            assert_eq!(data, content.to_vec());
        } else {
            panic!("expected GetResponse::Success");
        }
    }

    #[tokio::test]
    async fn test_get_not_found() {
        let (protocol, _temp) = create_test_protocol().await;

        let address = [0xAB; 32];
        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 10,
                body: ChunkMessageBody::GetRequest(ChunkGetRequest::new(address)),
            },
        )
        .await;

        assert_eq!(response.request_id, 10);
        if let ChunkMessageBody::GetResponse(ChunkGetResponse::NotFound { address: addr }) =
            response.body
        {
            assert_eq!(addr, address);
        } else {
            panic!("expected GetResponse::NotFound");
        }
    }

    #[tokio::test]
    async fn test_put_address_mismatch() {
        let (protocol, _temp) = create_test_protocol().await;

        let content = b"test content";
        let wrong_address = [0xFF; 32];
        let put_request =
            ChunkPutRequest::with_payment(wrong_address, content.to_vec(), empty_payment_proof());
        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 20,
                body: ChunkMessageBody::PutRequest(put_request),
            },
        )
        .await;

        assert_eq!(response.request_id, 20);
        if let ChunkMessageBody::PutResponse(ChunkPutResponse::Error(
            ProtocolError::AddressMismatch { expected, actual },
        )) = response.body
        {
            assert_eq!(expected, wrong_address);
            assert_eq!(actual, DiskStorage::compute_address(content));
        } else {
            panic!("expected AddressMismatch error, got: {response:?}");
        }
    }

    #[tokio::test]
    async fn test_put_chunk_too_large() {
        let (protocol, _temp) = create_test_protocol().await;

        let content = vec![0u8; MAX_CHUNK_SIZE + 1];
        let address = DiskStorage::compute_address(&content);

        // No payment attached: the size check must reject the request first.
        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 30,
                body: ChunkMessageBody::PutRequest(ChunkPutRequest::new(address, content)),
            },
        )
        .await;

        assert_eq!(response.request_id, 30);
        if let ChunkMessageBody::PutResponse(ChunkPutResponse::Error(
            ProtocolError::ChunkTooLarge { size, max_size },
        )) = response.body
        {
            assert_eq!(size, MAX_CHUNK_SIZE + 1);
            assert_eq!(max_size, MAX_CHUNK_SIZE);
        } else {
            panic!("expected ChunkTooLarge error");
        }
    }

    #[tokio::test]
    async fn test_put_already_exists() {
        let (protocol, _temp) = create_test_protocol().await;

        let content = b"duplicate content";
        let address = DiskStorage::compute_address(content);

        let put_request =
            ChunkPutRequest::with_payment(address, content.to_vec(), empty_payment_proof());
        let put_bytes = ChunkMessage {
            request_id: 40,
            body: ChunkMessageBody::PutRequest(put_request),
        }
        .encode()
        .expect("encode put");

        // The first PUT stores the chunk; the identical second PUT must be
        // acknowledged as already present.
        let _ = send_bytes(&protocol, &put_bytes).await;
        let response = send_bytes(&protocol, &put_bytes).await;

        assert_eq!(response.request_id, 40);
        if let ChunkMessageBody::PutResponse(ChunkPutResponse::AlreadyExists { address: addr }) =
            response.body
        {
            assert_eq!(addr, address);
        } else {
            panic!("expected AlreadyExists");
        }
    }

    #[tokio::test]
    async fn test_unexpected_response_message() {
        let (protocol, _temp) = create_test_protocol().await;

        // A response sent *to* the handler is a peer error and must be answered
        // with an Internal error that keeps the request id.
        let response = send(
            &protocol,
            ChunkMessage {
                request_id: 50,
                body: ChunkMessageBody::GetResponse(ChunkGetResponse::NotFound {
                    address: [0x01; 32],
                }),
            },
        )
        .await;

        assert_eq!(response.request_id, 50);
        assert!(
            matches!(
                response.body,
                ChunkMessageBody::PutResponse(ChunkPutResponse::Error(ProtocolError::Internal(_)))
            ),
            "expected Internal error, got: {response:?}"
        );
    }

    #[tokio::test]
    async fn test_protocol_id() {
        let (protocol, _temp) = create_test_protocol().await;
        assert_eq!(protocol.protocol_id(), CHUNK_PROTOCOL_ID);
    }

    #[tokio::test]
    async fn test_exists_and_local_access() {
        let (protocol, _temp) = create_test_protocol().await;

        let content = b"local access test";
        let address = DiskStorage::compute_address(content);

        assert!(!protocol.exists(&address));

        protocol
            .put_local(&address, content)
            .await
            .expect("put local");

        assert!(protocol.exists(&address));

        let retrieved = protocol.get_local(&address).await.expect("get local");
        assert_eq!(retrieved, Some(content.to_vec()));
    }
}