Skip to main content

tap_aggregator/
server.rs

1// Copyright 2023-, Semiotic AI, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{collections::HashSet, fmt::Debug, str::FromStr};
5
6use anyhow::Result;
7use axum::{error_handling::HandleError, routing::post_service, BoxError, Router};
8use hyper::StatusCode;
9use jsonrpsee::{
10    proc_macros::rpc,
11    server::{ServerBuilder, ServerConfig, ServerHandle, TowerService},
12};
13use lazy_static::lazy_static;
14use log::{error, info};
15use prometheus::{register_counter, register_int_counter, Counter, IntCounter};
16use tap_core::signed_message::Eip712SignedMessage;
17use tap_graph::{Receipt, ReceiptAggregateVoucher, SignedReceipt};
18use thegraph_core::alloy::{
19    dyn_abi::Eip712Domain, primitives::Address, signers::local::PrivateKeySigner,
20};
21use tokio::{net::TcpListener, signal, task::JoinHandle};
22use tonic::{codec::CompressionEncoding, service::Routes, Request, Response, Status};
23use tower::{layer::util::Identity, make::Shared};
24
25use crate::{
26    aggregator,
27    api_versioning::{
28        tap_rpc_api_versions_info, TapRpcApiVersion, TapRpcApiVersionsInfo,
29        TAP_RPC_API_VERSIONS_DEPRECATED,
30    },
31    error_codes::{JsonRpcErrorCode, JsonRpcWarningCode},
32    grpc::{v1, v2},
33    jsonrpsee_helpers::{JsonRpcError, JsonRpcResponse, JsonRpcResult, JsonRpcWarning},
34};
35
// Register the metrics into the global metrics registry.
// Every `register_*!(...)` result is `unwrap`ped, so a failed registration (e.g. a metric
// name that is already registered in this process) panics on first access of the static.
lazy_static! {
    // Incremented by every JSON-RPC and gRPC aggregation handler when a RAV is returned.
    static ref AGGREGATION_SUCCESS_COUNTER: IntCounter = register_int_counter!(
        "aggregation_success_count",
        "Number of successful receipt aggregation requests."
    )
    .unwrap();
    // Incremented when aggregation fails. On the JSON-RPC paths this also covers API-version
    // and batch-validation failures; on the gRPC paths, request-decoding failures return
    // before this counter is touched.
    static ref AGGREGATION_FAILURE_COUNTER: IntCounter = register_int_counter!(
        "aggregation_failure_count",
        "Number of failed receipt aggregation requests (for any reason)."
    )
    .unwrap();
    // Incremented by the JSON-RPC aggregation helpers when the client's API version is
    // listed as deprecated (a warning is attached to the response).
    static ref DEPRECATION_WARNING_COUNT: IntCounter = register_int_counter!(
        "deprecation_warning_count",
        "Number of deprecation warnings sent to clients."
    )
    .unwrap();
    // Incremented by the JSON-RPC aggregation helpers when the API version string cannot
    // be parsed.
    static ref VERSION_ERROR_COUNT: IntCounter = register_int_counter!(
        "version_error_count",
        "Number of API version errors sent to clients."
    )
    .unwrap();
    // Number of receipts in successful requests (a previous RAV is not counted).
    static ref TOTAL_AGGREGATED_RECEIPTS: IntCounter = register_int_counter!(
        "total_aggregated_receipts",
        "Total number of receipts successfully aggregated."
    )
    .unwrap();
    // Sum of the receipt values in successful requests (a previous RAV's aggregate value is
    // not added). Using float for the GRT value because it can somewhat easily exceed the
    // maximum value of int64.
    static ref TOTAL_GRT_AGGREGATED: Counter = register_counter!(
        "total_aggregated_grt",
        "Total successfully aggregated GRT value (wei)."
    )
    .unwrap();
}
70
/// Generates the `RpcServer` trait that is used to define the JSON-RPC API.
///
/// Note that because of the way the `rpc` macro works, we cannot document the RpcServer trait here.
/// (So even this very docstring will not appear in the generated documentation...)
/// As a result, we document the JSON-RPC API in the `tap_aggregator/README.md` file.
/// Do not forget to update the documentation there if you make any changes to the JSON-RPC API.
///
/// The `name` given in each `#[method(...)]` attribute is the JSON-RPC method name that clients
/// call; the Rust method name is only what the server implementation (`impl RpcServer`) uses.
/// Methods marked `#[cfg(feature = "v2")]` exist only when the crate is built with `v2`.
#[rpc(server)]
pub trait Rpc {
    /// Returns the versions of the TAP JSON-RPC API implemented by this server.
    #[method(name = "api_versions")]
    fn api_versions(&self) -> JsonRpcResult<TapRpcApiVersionsInfo>;

    /// Returns the EIP-712 domain separator information used by this server.
    /// The client is able to verify the signatures of the receipts and receipt aggregate vouchers.
    #[method(name = "eip712domain_info")]
    fn eip712_domain_info(&self) -> JsonRpcResult<Eip712Domain>;

    /// Returns the v2 EIP-712 domain separator information used by this server.
    /// The client is able to verify the signatures of the receipts and receipt aggregate vouchers.
    #[cfg(feature = "v2")]
    #[method(name = "eip712domain_info_v2")]
    fn eip712_domain_info_v2(&self) -> JsonRpcResult<Eip712Domain>;

    /// Aggregates the given receipts into a receipt aggregate voucher.
    /// Returns an error if the user expected API version is not supported.
    ///
    /// - `api_version`: the API version string the client targets.
    /// - `receipts`: signed v1 receipts to aggregate.
    /// - `previous_rav`: a previously signed v1 RAV to extend, if any.
    #[method(name = "aggregate_receipts")]
    fn aggregate_receipts(
        &self,
        api_version: String,
        receipts: Vec<Eip712SignedMessage<Receipt>>,
        previous_rav: Option<Eip712SignedMessage<ReceiptAggregateVoucher>>,
    ) -> JsonRpcResult<Eip712SignedMessage<ReceiptAggregateVoucher>>;

    /// Aggregates the given v2 receipts into a v2 receipt aggregate voucher.
    /// Uses the Horizon protocol for collection-based aggregation.
    ///
    /// - `api_version`: the API version string the client targets.
    /// - `receipts`: signed v2 receipts to aggregate.
    /// - `previous_rav`: a previously signed v2 RAV to extend, if any.
    #[cfg(feature = "v2")]
    #[method(name = "aggregate_receipts_v2")]
    fn aggregate_receipts_v2(
        &self,
        api_version: String,
        receipts: Vec<Eip712SignedMessage<tap_graph::v2::Receipt>>,
        previous_rav: Option<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>>,
    ) -> JsonRpcResult<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>>;
}
115
116#[derive(Clone)]
117struct RpcImpl {
118    wallet: PrivateKeySigner,
119    accepted_addresses: HashSet<Address>,
120    domain_separator: Eip712Domain,
121    #[cfg(feature = "v2")]
122    domain_separator_v2: Eip712Domain,
123    kafka: Option<rdkafka::producer::ThreadedProducer<rdkafka::producer::DefaultProducerContext>>,
124}
125
126/// Helper method that checks if the given API version is supported.
127/// Returns an error if the API version is not supported.
128fn parse_api_version(api_version: &str) -> Result<TapRpcApiVersion, JsonRpcError> {
129    TapRpcApiVersion::from_str(api_version).map_err(|_| {
130        jsonrpsee::types::ErrorObject::owned(
131            JsonRpcErrorCode::InvalidVersion as i32,
132            format!("Unsupported API version: \"{api_version}\"."),
133            Some(tap_rpc_api_versions_info()),
134        )
135    })
136}
137
138/// Helper method that checks if the given API version has a deprecation warning.
139/// Returns a warning if the API version is deprecated.
140fn check_api_version_deprecation(api_version: &TapRpcApiVersion) -> Option<JsonRpcWarning> {
141    if TAP_RPC_API_VERSIONS_DEPRECATED.contains(api_version) {
142        Some(JsonRpcWarning::new(
143            JsonRpcWarningCode::DeprecatedVersion as i32,
144            format!(
145                "The API version {api_version} will be deprecated. \
146                Please check https://github.com/semiotic-ai/timeline_aggregation_protocol for more information."
147            ),
148            Some(tap_rpc_api_versions_info()),
149        ))
150    } else {
151        None
152    }
153}
154
155fn aggregate_receipts_(
156    api_version: String,
157    wallet: &PrivateKeySigner,
158    accepted_addresses: &HashSet<Address>,
159    domain_separator: &Eip712Domain,
160    receipts: Vec<Eip712SignedMessage<Receipt>>,
161    previous_rav: Option<Eip712SignedMessage<ReceiptAggregateVoucher>>,
162) -> JsonRpcResult<Eip712SignedMessage<ReceiptAggregateVoucher>> {
163    use crate::receipt_classifier::validate_v1_receipt_batch;
164
165    // Return an error if the API version is not supported.
166    let api_version = match parse_api_version(api_version.as_str()) {
167        Ok(v) => v,
168        Err(e) => {
169            VERSION_ERROR_COUNT.inc();
170            return Err(e);
171        }
172    };
173
174    // Add a warning if the API version is to be deprecated.
175    let mut warnings: Vec<JsonRpcWarning> = Vec::new();
176    if let Some(w) = check_api_version_deprecation(&api_version) {
177        warnings.push(w);
178        DEPRECATION_WARNING_COUNT.inc();
179    }
180
181    // This endpoint handles v1 receipts for legacy aggregation
182    // V2 receipts are handled through the aggregate_receipts_v2 endpoint
183    if let Err(e) = validate_v1_receipt_batch(&receipts) {
184        return Err(jsonrpsee::types::ErrorObject::owned(
185            JsonRpcErrorCode::Aggregation as i32,
186            e.to_string(),
187            None::<()>,
188        ));
189    }
190
191    log::debug!("Processing V1 receipts");
192
193    // Execute v1 aggregation
194    let res = aggregator::v1::check_and_aggregate_receipts(
195        domain_separator,
196        &receipts,
197        previous_rav,
198        wallet,
199        accepted_addresses,
200    );
201
202    // Handle aggregation error
203    match res {
204        Ok(res) => Ok(JsonRpcResponse::warn(res, warnings)),
205        Err(e) => Err(jsonrpsee::types::ErrorObject::owned(
206            JsonRpcErrorCode::Aggregation as i32,
207            e.to_string(),
208            None::<()>,
209        )),
210    }
211}
212
/// Validates and aggregates a batch of v2 (Horizon) receipts into a signed v2 RAV
/// (JSON-RPC path).
///
/// The API version is parsed first; failures bump `VERSION_ERROR_COUNT`. A deprecation
/// warning is attached to the response when applicable, and each one bumps
/// `DEPRECATION_WARNING_COUNT`. The batch is then checked with `validate_v2_receipt_batch`
/// before being handed to `aggregator::v2::check_and_aggregate_receipts`.
///
/// # Errors
/// - `InvalidVersion` when `api_version` is not recognised.
/// - `Aggregation`, carrying the underlying error message, when batch validation or
///   aggregation fails.
#[cfg(feature = "v2")]
fn aggregate_receipts_v2_(
    api_version: String,
    wallet: &PrivateKeySigner,
    accepted_addresses: &HashSet<Address>,
    domain_separator: &Eip712Domain,
    receipts: Vec<Eip712SignedMessage<tap_graph::v2::Receipt>>,
    previous_rav: Option<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>>,
) -> JsonRpcResult<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>> {
    use crate::receipt_classifier::validate_v2_receipt_batch;

    // Validation and aggregation failures are both reported with the `Aggregation` code.
    let to_aggregation_error = |message: String| {
        jsonrpsee::types::ErrorObject::owned(
            JsonRpcErrorCode::Aggregation as i32,
            message,
            None::<()>,
        )
    };

    let api_version = parse_api_version(&api_version).map_err(|err| {
        VERSION_ERROR_COUNT.inc();
        err
    })?;

    let warnings: Vec<JsonRpcWarning> = check_api_version_deprecation(&api_version)
        .map(|warning| {
            DEPRECATION_WARNING_COUNT.inc();
            warning
        })
        .into_iter()
        .collect();

    validate_v2_receipt_batch(&receipts).map_err(|e| to_aggregation_error(e.to_string()))?;

    log::debug!("Processing V2 receipts with Horizon protocol");

    aggregator::v2::check_and_aggregate_receipts(
        domain_separator,
        &receipts,
        previous_rav,
        wallet,
        accepted_addresses,
    )
    .map(|rav| JsonRpcResponse::warn(rav, warnings))
    .map_err(|e| to_aggregation_error(e.to_string()))
}
270
271#[tonic::async_trait]
272impl v1::tap_aggregator_server::TapAggregator for RpcImpl {
273    async fn aggregate_receipts(
274        &self,
275        request: Request<v1::RavRequest>,
276    ) -> Result<Response<v1::RavResponse>, Status> {
277        let rav_request = request.into_inner();
278        let receipts: Vec<SignedReceipt> = rav_request
279            .receipts
280            .into_iter()
281            .map(TryFrom::try_from)
282            .collect::<Result<_, _>>()
283            .map_err(|_| Status::invalid_argument("Error while getting list of signed_receipts"))?;
284
285        let previous_rav = rav_request
286            .previous_rav
287            .map(TryFrom::try_from)
288            .transpose()
289            .map_err(|_| Status::invalid_argument("Error while getting previous rav"))?;
290
291        let receipts_grt: u128 = receipts.iter().map(|r| r.message.value).sum();
292        let receipts_count: u64 = receipts.len() as u64;
293
294        match aggregator::v1::check_and_aggregate_receipts(
295            &self.domain_separator,
296            receipts.as_slice(),
297            previous_rav,
298            &self.wallet,
299            &self.accepted_addresses,
300        ) {
301            Ok(res) => {
302                TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64);
303                TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count);
304                AGGREGATION_SUCCESS_COUNTER.inc();
305                if let Some(kafka) = &self.kafka {
306                    produce_kafka_records(
307                        kafka,
308                        &self.wallet.address(),
309                        &res.message.allocationId,
310                        res.message.valueAggregate,
311                    );
312                }
313
314                let response = v1::RavResponse {
315                    rav: Some(res.into()),
316                };
317                Ok(Response::new(response))
318            }
319            Err(e) => {
320                AGGREGATION_FAILURE_COUNTER.inc();
321                Err(Status::failed_precondition(e.to_string()))
322            }
323        }
324    }
325}
326
327#[tonic::async_trait]
328impl v2::tap_aggregator_server::TapAggregator for RpcImpl {
329    async fn aggregate_receipts(
330        &self,
331        request: Request<v2::RavRequest>,
332    ) -> Result<Response<v2::RavResponse>, Status> {
333        let rav_request = request.into_inner();
334        let receipts: Vec<tap_graph::v2::SignedReceipt> = rav_request
335            .receipts
336            .into_iter()
337            .map(TryFrom::try_from)
338            .collect::<Result<_, _>>()
339            .map_err(|_| Status::invalid_argument("Error while getting list of signed_receipts"))?;
340
341        let previous_rav = rav_request
342            .previous_rav
343            .map(TryFrom::try_from)
344            .transpose()
345            .map_err(|_| Status::invalid_argument("Error while getting previous rav"))?;
346
347        let receipts_grt: u128 = receipts.iter().map(|r| r.message.value).sum();
348        let receipts_count: u64 = receipts.len() as u64;
349
350        match aggregator::v2::check_and_aggregate_receipts(
351            &self.domain_separator_v2,
352            receipts.as_slice(),
353            previous_rav,
354            &self.wallet,
355            &self.accepted_addresses,
356        ) {
357            Ok(res) => {
358                TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64);
359                TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count);
360                AGGREGATION_SUCCESS_COUNTER.inc();
361                if let Some(kafka) = &self.kafka {
362                    produce_kafka_records(
363                        kafka,
364                        &res.message.payer,
365                        &res.message.collectionId,
366                        res.message.valueAggregate,
367                    );
368                }
369
370                let response = v2::RavResponse {
371                    rav: Some(res.into()),
372                };
373                Ok(Response::new(response))
374            }
375            Err(e) => {
376                AGGREGATION_FAILURE_COUNTER.inc();
377                Err(Status::failed_precondition(e.to_string()))
378            }
379        }
380    }
381}
382
383impl RpcServer for RpcImpl {
384    fn api_versions(&self) -> JsonRpcResult<TapRpcApiVersionsInfo> {
385        Ok(JsonRpcResponse::ok(tap_rpc_api_versions_info()))
386    }
387
388    fn eip712_domain_info(&self) -> JsonRpcResult<Eip712Domain> {
389        Ok(JsonRpcResponse::ok(self.domain_separator.clone()))
390    }
391
392    #[cfg(feature = "v2")]
393    fn eip712_domain_info_v2(&self) -> JsonRpcResult<Eip712Domain> {
394        Ok(JsonRpcResponse::ok(self.domain_separator_v2.clone()))
395    }
396
397    fn aggregate_receipts(
398        &self,
399        api_version: String,
400        receipts: Vec<Eip712SignedMessage<Receipt>>,
401        previous_rav: Option<Eip712SignedMessage<ReceiptAggregateVoucher>>,
402    ) -> JsonRpcResult<Eip712SignedMessage<ReceiptAggregateVoucher>> {
403        // Values for Prometheus metrics
404        let receipts_grt: u128 = receipts.iter().map(|r| r.message.value).sum();
405        let receipts_count: u64 = receipts.len() as u64;
406
407        match aggregate_receipts_(
408            api_version,
409            &self.wallet,
410            &self.accepted_addresses,
411            &self.domain_separator,
412            receipts,
413            previous_rav,
414        ) {
415            Ok(res) => {
416                TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64);
417                TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count);
418                AGGREGATION_SUCCESS_COUNTER.inc();
419                if let Some(kafka) = &self.kafka {
420                    produce_kafka_records(
421                        kafka,
422                        &self.wallet.address(),
423                        &res.data.message.allocationId,
424                        res.data.message.valueAggregate,
425                    );
426                }
427                Ok(res)
428            }
429            Err(e) => {
430                AGGREGATION_FAILURE_COUNTER.inc();
431                Err(e)
432            }
433        }
434    }
435
436    #[cfg(feature = "v2")]
437    fn aggregate_receipts_v2(
438        &self,
439        api_version: String,
440        receipts: Vec<Eip712SignedMessage<tap_graph::v2::Receipt>>,
441        previous_rav: Option<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>>,
442    ) -> JsonRpcResult<Eip712SignedMessage<tap_graph::v2::ReceiptAggregateVoucher>> {
443        // Values for Prometheus metrics
444        let receipts_grt: u128 = receipts.iter().map(|r| r.message.value).sum();
445        let receipts_count: u64 = receipts.len() as u64;
446
447        match aggregate_receipts_v2_(
448            api_version,
449            &self.wallet,
450            &self.accepted_addresses,
451            &self.domain_separator_v2,
452            receipts,
453            previous_rav,
454        ) {
455            Ok(res) => {
456                TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64);
457                TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count);
458                AGGREGATION_SUCCESS_COUNTER.inc();
459                if let Some(kafka) = &self.kafka {
460                    // V2 RAVs use collectionId instead of allocationId
461                    produce_kafka_records(
462                        kafka,
463                        &self.wallet.address(),
464                        &res.data.message.collectionId,
465                        res.data.message.valueAggregate,
466                    );
467                }
468                Ok(res)
469            }
470            Err(e) => {
471                AGGREGATION_FAILURE_COUNTER.inc();
472                Err(e)
473            }
474        }
475    }
476}
477
478#[allow(clippy::too_many_arguments)]
479pub async fn run_server(
480    port: u16,
481    wallet: PrivateKeySigner,
482    accepted_addresses: HashSet<Address>,
483    domain_separator: Eip712Domain,
484    domain_separator_v2: Eip712Domain,
485    max_request_body_size: u32,
486    max_response_body_size: u32,
487    max_concurrent_connections: u32,
488    kafka: Option<rdkafka::producer::ThreadedProducer<rdkafka::producer::DefaultProducerContext>>,
489) -> Result<(JoinHandle<()>, std::net::SocketAddr)> {
490    // Setting up the JSON RPC server
491    let rpc_impl = RpcImpl {
492        wallet,
493        accepted_addresses,
494        domain_separator,
495        domain_separator_v2,
496        kafka,
497    };
498    let (json_rpc_service, _) = create_json_rpc_service(
499        rpc_impl.clone(),
500        max_request_body_size,
501        max_response_body_size,
502        max_concurrent_connections,
503    )?;
504
505    async fn handle_anyhow_error(err: BoxError) -> (StatusCode, String) {
506        (
507            StatusCode::INTERNAL_SERVER_ERROR,
508            format!("Something went wrong: {err}"),
509        )
510    }
511    let json_rpc_router = Router::new().route_service(
512        "/",
513        HandleError::new(post_service(json_rpc_service), handle_anyhow_error),
514    );
515
516    let grpc_service = create_grpc_service(rpc_impl)?;
517
518    let grpc_router = Router::new()
519        .layer(tower::limit::ConcurrencyLimitLayer::new(
520            max_concurrent_connections as usize,
521        ))
522        .merge(grpc_service.into_axum_router());
523
524    let service = tower::steer::Steer::new(
525        [json_rpc_router, grpc_router],
526        |req: &hyper::Request<_>, _services: &[_]| {
527            if req
528                .headers()
529                .get(hyper::header::CONTENT_TYPE)
530                .map(|content_type| content_type.as_bytes())
531                .filter(|content_type| content_type.starts_with(b"application/grpc"))
532                .is_some()
533            {
534                // route to the gRPC service (second service element) when the
535                // header is set
536                1
537            } else {
538                // otherwise route to the REST service
539                0
540            }
541        },
542    );
543
544    // Create a `TcpListener` using tokio.
545    let listener = TcpListener::bind(&format!("0.0.0.0:{port}"))
546        .await
547        .expect("Failed to bind to tap-aggregator port");
548
549    let addr = listener.local_addr()?;
550    let handle = tokio::spawn(async move {
551        if let Err(e) = axum::serve(listener, Shared::new(service))
552            .with_graceful_shutdown(shutdown_handler())
553            .await
554        {
555            log::error!("Tap Aggregator error: {e}");
556        }
557    });
558
559    Ok((handle, addr))
560}
561
562/// Graceful shutdown handler
563async fn shutdown_handler() {
564    let ctrl_c = async {
565        signal::ctrl_c()
566            .await
567            .expect("Failed to install Ctrl+C handler");
568    };
569
570    let terminate = async {
571        signal::unix::signal(signal::unix::SignalKind::terminate())
572            .expect("Failed to install signal handler")
573            .recv()
574            .await;
575    };
576
577    tokio::select! {
578        _ = ctrl_c => {},
579        _ = terminate => {},
580    }
581
582    info!("Signal received, starting graceful shutdown");
583}
584
585fn create_grpc_service(rpc_impl: RpcImpl) -> Result<Routes> {
586    let grpc_service = Routes::new(
587        v1::tap_aggregator_server::TapAggregatorServer::new(rpc_impl.clone())
588            .accept_compressed(CompressionEncoding::Zstd),
589    )
590    .add_service(
591        v2::tap_aggregator_server::TapAggregatorServer::new(rpc_impl)
592            .accept_compressed(CompressionEncoding::Zstd),
593    )
594    .prepare();
595
596    Ok(grpc_service)
597}
598
599fn create_json_rpc_service(
600    rpc_impl: RpcImpl,
601    max_request_body_size: u32,
602    max_response_body_size: u32,
603    max_concurrent_connections: u32,
604) -> Result<(TowerService<Identity, Identity>, ServerHandle)> {
605    let config = ServerConfig::builder()
606        .max_request_body_size(max_request_body_size)
607        .max_response_body_size(max_response_body_size)
608        .max_connections(max_concurrent_connections)
609        .http_only()
610        .build();
611
612    let service_builder = ServerBuilder::new().set_config(config).to_service_builder();
613    use jsonrpsee::server::stop_channel;
614    let (stop_handle, server_handle) = stop_channel();
615    let handle = service_builder.build(rpc_impl.into_rpc(), stop_handle);
616    Ok((handle, server_handle))
617}
618
619fn produce_kafka_records<K: Debug>(
620    kafka: &rdkafka::producer::ThreadedProducer<rdkafka::producer::DefaultProducerContext>,
621    sender: &Address,
622    key_fragment: &K,
623    aggregated_value: u128,
624) {
625    let topic = "gateway_ravs";
626    let key = format!("{sender:?}:{key_fragment:?}");
627    let payload = aggregated_value.to_string();
628    let result = kafka.send(
629        rdkafka::producer::BaseRecord::to(topic)
630            .key(&key)
631            .payload(&payload),
632    );
633    if let Err((err, _)) = result {
634        error!("error producing to {topic}: {err}");
635    }
636}
637
638#[cfg(test)]
639#[allow(clippy::too_many_arguments)]
640mod tests {
641    use std::{collections::HashSet, str::FromStr};
642
643    use jsonrpsee::{core::client::ClientT, http_client::HttpClientBuilder, rpc_params};
644    use rstest::*;
645    use tap_core::{signed_message::Eip712SignedMessage, tap_eip712_domain, TapVersion};
646    use tap_graph::{Receipt, ReceiptAggregateVoucher};
647    use thegraph_core::alloy::{
648        dyn_abi::Eip712Domain, primitives::Address, signers::local::PrivateKeySigner,
649    };
650
651    use crate::server;
652
653    #[derive(Clone)]
654    struct Keys {
655        wallet: PrivateKeySigner,
656        address: Address,
657    }
658
659    fn keys() -> Keys {
660        let wallet = PrivateKeySigner::random();
661        let address = wallet.address();
662        Keys { wallet, address }
663    }
664
    /// Fixed allocation ids used when building test receipts.
    #[fixture]
    fn allocation_ids() -> Vec<Address> {
        vec![
            Address::from_str("0xabababababababababababababababababababab").unwrap(),
            Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(),
            Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(),
            Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(),
        ]
    }

    /// v1 EIP-712 domain for tests: `tap_eip712_domain` with `1` and a dummy `0x11…11` address.
    #[fixture]
    fn domain_separator() -> Eip712Domain {
        tap_eip712_domain(1, Address::from([0x11u8; 20]), TapVersion::V1)
    }
    /// v2 EIP-712 domain for tests. It uses a different dummy address (`0x22…22`) from the
    /// v1 fixture.
    #[fixture]
    fn domain_separator_v2() -> Eip712Domain {
        tap_eip712_domain(1, Address::from([0x22u8; 20]), TapVersion::V2)
    }

    /// Maximum JSON-RPC request body size passed to `run_server` (100 KiB).
    #[fixture]
    fn http_request_size_limit() -> u32 {
        100 * 1024
    }

    /// Maximum JSON-RPC response body size passed to `run_server` (100 KiB).
    #[fixture]
    fn http_response_size_limit() -> u32 {
        100 * 1024
    }

    /// Connection/concurrency limit passed to `run_server`.
    #[fixture]
    fn http_max_concurrent_connections() -> u32 {
        1
    }
698
    /// Smoke test: the server starts on an OS-assigned port, and the `api_versions` call
    /// returns a response that deserializes as `TapRpcApiVersionsInfo`.
    #[rstest]
    #[tokio::test]
    async fn protocol_version(
        domain_separator: Eip712Domain,
        domain_separator_v2: Eip712Domain,
        http_request_size_limit: u32,
        http_response_size_limit: u32,
        http_max_concurrent_connections: u32,
    ) {
        // The keys that will be used to sign the new RAVs
        let keys_main = keys();

        // Start the JSON-RPC server on port 0 so the OS picks a free port.
        let (handle, local_addr) = server::run_server(
            0,
            keys_main.wallet,
            HashSet::from([keys_main.address]),
            domain_separator,
            domain_separator_v2,
            http_request_size_limit,
            http_response_size_limit,
            http_max_concurrent_connections,
            None,
        )
        .await
        .unwrap();

        // Start the JSON-RPC client.
        let client = HttpClientBuilder::default()
            .build(format!("http://127.0.0.1:{}", local_addr.port()))
            .unwrap();
        // Only successful deserialization is checked; the payload itself is not inspected.
        let _: server::JsonRpcResponse<server::TapRpcApiVersionsInfo> = client
            .request("api_versions", rpc_params!(None::<()>))
            .await
            .unwrap();

        handle.abort();
    }
737
    /// Aggregating receipts, each signed by a randomly chosen accepted signer, through the
    /// JSON-RPC server with no previous RAV must match a locally computed RAV (allocation id,
    /// timestamp, value). The result must also be signed by the server's own wallet.
    #[rstest]
    #[case::basic_rav_test (vec![45,56,34,23])]
    #[case::rav_from_zero_valued_receipts (vec![0,0,0,0])]
    #[tokio::test]
    async fn signed_rav_is_valid_with_no_previous_rav(
        domain_separator: Eip712Domain,
        domain_separator_v2: Eip712Domain,
        http_request_size_limit: u32,
        http_response_size_limit: u32,
        http_max_concurrent_connections: u32,
        allocation_ids: Vec<Address>,
        #[case] values: Vec<u128>,
        #[values("0.0")] api_version: &str,
        #[values(0, 1, 2)] random_seed: u64,
    ) {
        use rand::{rngs::StdRng, seq::IndexedRandom, SeedableRng};
        // The keys that will be used to sign the new RAVs
        let keys_main = keys();
        // Extra keys to test the server's ability to accept multiple signers as input
        let keys_0 = keys();
        let keys_1 = keys();
        // Vector of all wallets to make it easier to select one randomly
        let all_wallets = vec![keys_main.clone(), keys_0.clone(), keys_1.clone()];
        // PRNG for selecting a random wallet (seeded, so each case is reproducible)
        let mut rng = StdRng::seed_from_u64(random_seed);

        // Start the JSON-RPC server; all three signers are accepted.
        let (handle, local_addr) = server::run_server(
            0,
            keys_main.wallet.clone(),
            HashSet::from([keys_main.address, keys_0.address, keys_1.address]),
            domain_separator.clone(),
            domain_separator_v2.clone(),
            http_request_size_limit,
            http_response_size_limit,
            http_max_concurrent_connections,
            None,
        )
        .await
        .unwrap();

        // Start the JSON-RPC client.
        let client = HttpClientBuilder::default()
            .build(format!("http://127.0.0.1:{}", local_addr.port()))
            .unwrap();

        // Create receipts for the first allocation, each signed by a randomly chosen wallet.
        let mut receipts = Vec::new();
        for value in values {
            receipts.push(
                Eip712SignedMessage::new(
                    &domain_separator,
                    Receipt::new(allocation_ids[0], value).unwrap(),
                    &all_wallets.choose(&mut rng).unwrap().wallet,
                )
                .unwrap(),
            );
        }

        // Skipping receipts validation in this test, aggregate_receipts assumes receipts are valid.
        // Create RAV through the JSON-RPC server.
        let res: server::JsonRpcResponse<Eip712SignedMessage<ReceiptAggregateVoucher>> = client
            .request(
                "aggregate_receipts",
                rpc_params!(api_version, &receipts, None::<()>),
            )
            .await
            .unwrap();

        let remote_rav = res.data;

        // Reference RAV computed locally from the same receipts.
        let local_rav =
            ReceiptAggregateVoucher::aggregate_receipts(allocation_ids[0], &receipts, None)
                .unwrap();

        assert!(remote_rav.message.allocationId == local_rav.allocationId);
        assert!(remote_rav.message.timestampNs == local_rav.timestampNs);
        assert!(remote_rav.message.valueAggregate == local_rav.valueAggregate);

        // The RAV must be signed by the server's own wallet, not by a receipt signer.
        assert!(remote_rav.recover_signer(&domain_separator).unwrap() == keys_main.address);

        handle.abort();
    }
822
823    #[rstest]
824    #[case::basic_rav_test (vec![45,56,34,23])]
825    #[case::rav_from_zero_valued_receipts (vec![0,0,0,0])]
826    #[tokio::test]
827    async fn signed_rav_is_valid_with_previous_rav(
828        domain_separator: Eip712Domain,
829        domain_separator_v2: Eip712Domain,
830        http_request_size_limit: u32,
831        http_response_size_limit: u32,
832        http_max_concurrent_connections: u32,
833        allocation_ids: Vec<Address>,
834        #[case] values: Vec<u128>,
835        #[values("0.0")] api_version: &str,
836        #[values(0, 1, 2, 3, 4)] random_seed: u64,
837    ) {
838        // The keys that will be used to sign the new RAVs
839
840        use rand::{rngs::StdRng, seq::IndexedRandom, SeedableRng};
841        let keys_main = keys();
842        // Extra keys to test the server's ability to accept multiple signers as input
843        let keys_0 = keys();
844        let keys_1 = keys();
845        // Vector of all wallets to make it easier to select one randomly
846        let all_wallets = vec![keys_main.clone(), keys_0.clone(), keys_1.clone()];
847        // PRNG for selecting a random wallet
848        let mut rng = StdRng::seed_from_u64(random_seed);
849
850        // Start the JSON-RPC server.
851        let (handle, local_addr) = server::run_server(
852            0,
853            keys_main.wallet.clone(),
854            HashSet::from([keys_main.address, keys_0.address, keys_1.address]),
855            domain_separator.clone(),
856            domain_separator_v2.clone(),
857            http_request_size_limit,
858            http_response_size_limit,
859            http_max_concurrent_connections,
860            None,
861        )
862        .await
863        .unwrap();
864
865        // Start the JSON-RPC client.
866        let client = HttpClientBuilder::default()
867            .build(format!("http://127.0.0.1:{}", local_addr.port()))
868            .unwrap();
869
870        // Create receipts
871        let mut receipts = Vec::new();
872        for value in values {
873            receipts.push(
874                Eip712SignedMessage::new(
875                    &domain_separator,
876                    Receipt::new(allocation_ids[0], value).unwrap(),
877                    &all_wallets.choose(&mut rng).unwrap().wallet,
878                )
879                .unwrap(),
880            );
881        }
882
883        // Create previous RAV from first half of receipts locally
884        let prev_rav = ReceiptAggregateVoucher::aggregate_receipts(
885            allocation_ids[0],
886            &receipts[0..receipts.len() / 2],
887            None,
888        )
889        .unwrap();
890        let signed_prev_rav = Eip712SignedMessage::new(
891            &domain_separator,
892            prev_rav,
893            &all_wallets.choose(&mut rng).unwrap().wallet,
894        )
895        .unwrap();
896
897        // Create new RAV from last half of receipts and prev_rav through the JSON-RPC server
898        let res: server::JsonRpcResponse<Eip712SignedMessage<ReceiptAggregateVoucher>> = client
899            .request(
900                "aggregate_receipts",
901                rpc_params!(
902                    api_version,
903                    &receipts[receipts.len() / 2..receipts.len()],
904                    Some(signed_prev_rav)
905                ),
906            )
907            .await
908            .unwrap();
909
910        let rav = res.data;
911
912        assert!(rav.recover_signer(&domain_separator).unwrap() == keys_main.address);
913
914        handle.abort();
915    }
916
917    #[rstest]
918    #[tokio::test]
919    async fn invalid_api_version(
920        domain_separator: Eip712Domain,
921        domain_separator_v2: Eip712Domain,
922        http_request_size_limit: u32,
923        http_response_size_limit: u32,
924        http_max_concurrent_connections: u32,
925        allocation_ids: Vec<Address>,
926    ) {
927        // The keys that will be used to sign the new RAVs
928        let keys_main = keys();
929
930        // Start the JSON-RPC server.
931        let (handle, local_addr) = server::run_server(
932            0,
933            keys_main.wallet.clone(),
934            HashSet::from([keys_main.address]),
935            domain_separator.clone(),
936            domain_separator_v2.clone(),
937            http_request_size_limit,
938            http_response_size_limit,
939            http_max_concurrent_connections,
940            None,
941        )
942        .await
943        .unwrap();
944
945        // Start the JSON-RPC client.
946        let client = HttpClientBuilder::default()
947            .build(format!("http://127.0.0.1:{}", local_addr.port()))
948            .unwrap();
949
950        // Create receipts
951        let receipts = vec![Eip712SignedMessage::new(
952            &domain_separator,
953            Receipt::new(allocation_ids[0], 42).unwrap(),
954            &keys_main.wallet,
955        )
956        .unwrap()];
957
958        // Skipping receipts validation in this test, aggregate_receipts assumes receipts are valid.
959        // Create RAV through the JSON-RPC server.
960        let res: Result<
961            server::JsonRpcResponse<Eip712SignedMessage<ReceiptAggregateVoucher>>,
962            jsonrpsee::core::ClientError,
963        > = client
964            .request(
965                "aggregate_receipts",
966                rpc_params!("invalid version string", &receipts, None::<()>),
967            )
968            .await;
969
970        assert!(res.is_err());
971
972        // Make sure the JSON-RPC error is "invalid version"
973        assert!(res
974            .as_ref()
975            .unwrap_err()
976            .to_string()
977            .contains("Unsupported API version"));
978
979        // Check the API versions returned by the server
980        match res.expect_err("Expected an error") {
981            jsonrpsee::core::ClientError::Call(err) => {
982                let versions: server::TapRpcApiVersionsInfo =
983                    serde_json::from_str(err.data().unwrap().get()).unwrap();
984                assert!(versions
985                    .versions_supported
986                    .contains(&server::TapRpcApiVersion::V0_0));
987            }
988            _ => panic!("Expected data in error"),
989        }
990
991        handle.abort();
992    }
993
994    /// Test that the server returns an error when the request size exceeds the limit.
995    /// The server should return HTTP 413 (Request Entity Too Large).
996    /// In this test, the request size limit is set to 100 kB, and we are expecting
997    /// that to fit about 250 receipts. We also test with 300 receipts, which should
998    /// exceed the limit.
999    /// We conclude that a limit of 10MB should fit about 25k receipts, and thus
1000    /// the TAP spec will require that the aggregator supports up to 15k receipts
1001    /// per aggregation request as a safe limit.
1002    #[rstest]
1003    #[tokio::test]
1004    async fn request_size_limit(
1005        domain_separator: Eip712Domain,
1006        domain_separator_v2: Eip712Domain,
1007        http_response_size_limit: u32,
1008        http_max_concurrent_connections: u32,
1009        allocation_ids: Vec<Address>,
1010        #[values("0.0")] api_version: &str,
1011    ) {
1012        // The keys that will be used to sign the new RAVs
1013        let keys_main = keys();
1014
1015        // Set the request byte size limit to a value that easily triggers the HTTP 413
1016        // error.
1017        let http_request_size_limit = 100 * 1024;
1018
1019        // Number of receipts that is just above the number that would fit within the
1020        // request size limit. This value is hard-coded here because it supports the
1021        // maximum number of receipts per aggregate value we wrote in the spec / docs.
1022        let number_of_receipts_to_exceed_limit = 300;
1023
1024        // Start the JSON-RPC server.
1025        let (handle, local_addr) = server::run_server(
1026            0,
1027            keys_main.wallet.clone(),
1028            HashSet::from([keys_main.address]),
1029            domain_separator.clone(),
1030            domain_separator_v2.clone(),
1031            http_request_size_limit,
1032            http_response_size_limit,
1033            http_max_concurrent_connections,
1034            None,
1035        )
1036        .await
1037        .unwrap();
1038
1039        // Start the JSON-RPC client.
1040        let client = HttpClientBuilder::default()
1041            .build(format!("http://127.0.0.1:{}", local_addr.port()))
1042            .unwrap();
1043
1044        // Create receipts
1045        let mut receipts = Vec::new();
1046        for _ in 1..number_of_receipts_to_exceed_limit {
1047            receipts.push(
1048                Eip712SignedMessage::new(
1049                    &domain_separator,
1050                    Receipt::new(allocation_ids[0], u128::MAX / 1000).unwrap(),
1051                    &keys_main.wallet,
1052                )
1053                .unwrap(),
1054            );
1055        }
1056
1057        // Skipping receipts validation in this test, aggregate_receipts assumes receipts are valid.
1058        // Create RAV through the JSON-RPC server.
1059        // Test with a number of receipts that stays within request size limit
1060        let res: Result<
1061            server::JsonRpcResponse<Eip712SignedMessage<ReceiptAggregateVoucher>>,
1062            jsonrpsee::core::ClientError,
1063        > = client
1064            .request(
1065                "aggregate_receipts",
1066                rpc_params!(
1067                    api_version,
1068                    &receipts[..number_of_receipts_to_exceed_limit - 50],
1069                    None::<()>
1070                ),
1071            )
1072            .await;
1073        assert!(res.is_ok());
1074
1075        // Create RAV through the JSON-RPC server.
1076        // Test with all receipts to exceed request size limit
1077        let res: Result<
1078            server::JsonRpcResponse<Eip712SignedMessage<ReceiptAggregateVoucher>>,
1079            jsonrpsee::core::ClientError,
1080        > = client
1081            .request(
1082                "aggregate_receipts",
1083                rpc_params!(api_version, &receipts, None::<()>),
1084            )
1085            .await;
1086
1087        assert!(res.is_err());
1088        // Make sure the error is a HTTP 413 Content Too Large
1089        assert!(res.unwrap_err().to_string().contains("413"));
1090
1091        handle.abort();
1092    }
1093}