tikv_client/request/plan.rs

// Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0.

use std::marker::PhantomData;
use std::sync::Arc;

use async_recursion::async_recursion;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::prelude::*;
use log::debug;
use log::info;
use tokio::sync::Semaphore;
use tokio::time::sleep;

use crate::backoff::Backoff;
use crate::pd::PdClient;
use crate::proto::errorpb;
use crate::proto::errorpb::EpochNotMatch;
use crate::proto::kvrpcpb;
use crate::request::shard::HasNextBatch;
use crate::request::NextBatch;
use crate::request::Shardable;
use crate::request::{KvRequest, StoreRequest};
use crate::stats::tikv_stats;
use crate::store::HasRegionError;
use crate::store::HasRegionErrors;
use crate::store::KvClient;
use crate::store::RegionStore;
use crate::store::{HasKeyErrors, Store};
use crate::transaction::resolve_locks;
use crate::transaction::HasLocks;
use crate::transaction::ResolveLocksContext;
use crate::transaction::ResolveLocksOptions;
use crate::util::iter::FlatMapOkIterExt;
use crate::Error;
use crate::Result;

/// A plan for how to execute a request. A user builds up a plan with various
/// options, then executes it.
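///
/// A minimal sketch of a hand-written plan, just to illustrate the trait shape
/// (`ConstPlan` is illustrative and not part of this crate):
///
/// ```ignore
/// #[derive(Clone)]
/// struct ConstPlan;
///
/// #[async_trait]
/// impl Plan for ConstPlan {
///     type Result = u32;
///
///     async fn execute(&self) -> Result<Self::Result> {
///         Ok(42)
///     }
/// }
/// ```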
#[async_trait]
pub trait Plan: Sized + Clone + Sync + Send + 'static {
    /// The ultimate result of executing the plan (should be a high-level type, not a GRPC response).
    type Result: Send;

    /// Execute the plan.
    async fn execute(&self) -> Result<Self::Result>;
}

/// The simplest plan which just dispatches a request to a specific kv server.
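///
/// In practice a `Dispatch` sits at the bottom of a plan assembled by the crate's
/// plan builder; a rough sketch of its construction (assuming some value `req`
/// implementing `KvRequest`) is:
///
/// ```ignore
/// // `kv_client` starts out empty; it is filled in (e.g. via `apply_store`)
/// // once the target store for the request is known.
/// let plan = Dispatch { request: req, kv_client: None };
/// ```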
#[derive(Clone)]
pub struct Dispatch<Req: KvRequest> {
    pub request: Req,
    pub kv_client: Option<Arc<dyn KvClient + Send + Sync>>,
}

#[async_trait]
impl<Req: KvRequest> Plan for Dispatch<Req> {
    type Result = Req::Response;

    async fn execute(&self) -> Result<Self::Result> {
        let stats = tikv_stats(self.request.label());
        let result = self
            .kv_client
            .as_ref()
            .expect("Unreachable: kv_client has not been initialised in Dispatch")
            .dispatch(&self.request)
            .await;
        let result = stats.done(result);
        result.map(|r| {
            *r.downcast()
                .expect("Downcast failed: request and response type mismatch")
        })
    }
}

impl<Req: KvRequest + StoreRequest> StoreRequest for Dispatch<Req> {
    fn apply_store(&mut self, store: &Store) {
        self.kv_client = Some(store.client.clone());
        self.request.apply_store(store);
    }
}

const MULTI_REGION_CONCURRENCY: usize = 16;
const MULTI_STORES_CONCURRENCY: usize = 16;

fn is_grpc_error(e: &Error) -> bool {
    matches!(e, Error::GrpcAPI(_) | Error::Grpc(_))
}

pub struct RetryableMultiRegion<P: Plan, PdC: PdClient> {
    pub(super) inner: P,
    pub pd_client: Arc<PdC>,
    pub backoff: Backoff,

    /// Preserve all regions' results for other downstream plans to handle.
    /// If true, return Ok and preserve all regions' results, even if some of them are Err.
    /// Otherwise, return the first Err if there is any.
    pub preserve_region_results: bool,
}

impl<P: Plan + Shardable, PdC: PdClient> RetryableMultiRegion<P, PdC>
where
    P::Result: HasKeyErrors + HasRegionError,
{
    // A plan may involve multiple shards
    #[async_recursion]
    async fn single_plan_handler(
        pd_client: Arc<PdC>,
        current_plan: P,
        backoff: Backoff,
        permits: Arc<Semaphore>,
        preserve_region_results: bool,
    ) -> Result<<Self as Plan>::Result> {
        let shards = current_plan.shards(&pd_client).collect::<Vec<_>>().await;
        let mut handles = Vec::new();
        for shard in shards {
            let (shard, region_store) = shard?;
            let mut clone = current_plan.clone();
            clone.apply_shard(shard, &region_store)?;
            let handle = tokio::spawn(Self::single_shard_handler(
                pd_client.clone(),
                clone,
                region_store,
                backoff.clone(),
                permits.clone(),
                preserve_region_results,
            ));
            handles.push(handle);
        }

        let results = try_join_all(handles).await?;
        if preserve_region_results {
            Ok(results
                .into_iter()
                .flat_map_ok(|x| x)
                .map(|x| match x {
                    Ok(r) => r,
                    Err(e) => Err(e),
                })
                .collect())
        } else {
            Ok(results
                .into_iter()
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .flatten()
                .collect())
        }
    }

    #[async_recursion]
    async fn single_shard_handler(
        pd_client: Arc<PdC>,
        plan: P,
        region_store: RegionStore,
        mut backoff: Backoff,
        permits: Arc<Semaphore>,
        preserve_region_results: bool,
    ) -> Result<<Self as Plan>::Result> {
        // limit concurrent requests
        let permit = permits.acquire().await.unwrap();
        let res = plan.execute().await;
        drop(permit);

        let mut resp = match res {
            Ok(resp) => resp,
            Err(e) if is_grpc_error(&e) => {
                return Self::handle_grpc_error(
                    pd_client,
                    plan,
                    region_store,
                    backoff,
                    permits,
                    preserve_region_results,
                    e,
                )
                .await;
            }
            Err(e) => return Err(e),
        };

        if let Some(e) = resp.key_errors() {
            Ok(vec![Err(Error::MultipleKeyErrors(e))])
        } else if let Some(e) = resp.region_error() {
            match backoff.next_delay_duration() {
                Some(duration) => {
                    let region_error_resolved =
                        Self::handle_region_error(pd_client.clone(), e, region_store).await?;
                    // don't sleep if we have resolved the region error
                    if !region_error_resolved {
                        sleep(duration).await;
                    }
                    Self::single_plan_handler(
                        pd_client,
                        plan,
                        backoff,
                        permits,
                        preserve_region_results,
                    )
                    .await
                }
                None => Err(Error::RegionError(Box::new(e))),
            }
        } else {
            Ok(vec![Ok(resp)])
        }
    }

    // Returns
    // 1. Ok(true): error has been resolved, retry immediately
    // 2. Ok(false): backoff, and then retry
    // 3. Err(Error): can't be resolved, return the error to upper level
    async fn handle_region_error(
        pd_client: Arc<PdC>,
        e: errorpb::Error,
        region_store: RegionStore,
    ) -> Result<bool> {
        let ver_id = region_store.region_with_leader.ver_id();
        if let Some(not_leader) = e.not_leader {
            if let Some(leader) = not_leader.leader {
                match pd_client
                    .update_leader(region_store.region_with_leader.ver_id(), leader)
                    .await
                {
                    Ok(_) => Ok(true),
                    Err(e) => {
                        pd_client.invalidate_region_cache(ver_id).await;
                        Err(e)
                    }
                }
            } else {
                // The peer doesn't know who is the current leader. Generally it's because
                // the Raft group is in an election, but it's possible that the peer is
                // isolated and removed from the Raft group. So it's necessary to reload
                // the region from PD.
                pd_client.invalidate_region_cache(ver_id).await;
                Ok(false)
            }
        } else if e.store_not_match.is_some() {
            pd_client.invalidate_region_cache(ver_id).await;
            Ok(false)
        } else if e.epoch_not_match.is_some() {
            Self::on_region_epoch_not_match(
                pd_client.clone(),
                region_store,
                e.epoch_not_match.unwrap(),
            )
            .await
        } else if e.stale_command.is_some() || e.region_not_found.is_some() {
            pd_client.invalidate_region_cache(ver_id).await;
            Ok(false)
        } else if e.server_is_busy.is_some()
            || e.raft_entry_too_large.is_some()
            || e.max_timestamp_not_synced.is_some()
        {
            Err(Error::RegionError(Box::new(e)))
        } else {
            // TODO: pass the logger around
            // info!("unknown region error: {:?}", e);
            pd_client.invalidate_region_cache(ver_id).await;
            Ok(false)
        }
    }

    // Returns
    // 1. Ok(true): error has been resolved, retry immediately
    // 2. Ok(false): backoff, and then retry
    // 3. Err(Error): can't be resolved, return the error to upper level
    async fn on_region_epoch_not_match(
        pd_client: Arc<PdC>,
        region_store: RegionStore,
        error: EpochNotMatch,
    ) -> Result<bool> {
        let ver_id = region_store.region_with_leader.ver_id();
        if error.current_regions.is_empty() {
            pd_client.invalidate_region_cache(ver_id).await;
            return Ok(true);
        }

        for r in error.current_regions {
            if r.id == region_store.region_with_leader.id() {
                let region_epoch = r.region_epoch.unwrap();
                let returned_conf_ver = region_epoch.conf_ver;
                let returned_version = region_epoch.version;
                let current_region_epoch = region_store
                    .region_with_leader
                    .region
                    .region_epoch
                    .clone()
                    .unwrap();
                let current_conf_ver = current_region_epoch.conf_ver;
                let current_version = current_region_epoch.version;

                // Find whether the current region is ahead of TiKV's. If so, backoff.
                if returned_conf_ver < current_conf_ver || returned_version < current_version {
                    return Ok(false);
                }
            }
        }
        // TODO: finer grained processing
        pd_client.invalidate_region_cache(ver_id).await;
        Ok(false)
    }

    async fn handle_grpc_error(
        pd_client: Arc<PdC>,
        plan: P,
        region_store: RegionStore,
        mut backoff: Backoff,
        permits: Arc<Semaphore>,
        preserve_region_results: bool,
        e: Error,
    ) -> Result<<Self as Plan>::Result> {
        debug!("handle grpc error: {:?}", e);
        let ver_id = region_store.region_with_leader.ver_id();
        pd_client.invalidate_region_cache(ver_id).await;
        match backoff.next_delay_duration() {
            Some(duration) => {
                sleep(duration).await;
                Self::single_plan_handler(
                    pd_client,
                    plan,
                    backoff,
                    permits,
                    preserve_region_results,
                )
                .await
            }
            None => Err(e),
        }
    }
}

impl<P: Plan, PdC: PdClient> Clone for RetryableMultiRegion<P, PdC> {
    fn clone(&self) -> Self {
        RetryableMultiRegion {
            inner: self.inner.clone(),
            pd_client: self.pd_client.clone(),
            backoff: self.backoff.clone(),
            preserve_region_results: self.preserve_region_results,
        }
    }
}

#[async_trait]
impl<P: Plan + Shardable, PdC: PdClient> Plan for RetryableMultiRegion<P, PdC>
where
    P::Result: HasKeyErrors + HasRegionError,
{
    type Result = Vec<Result<P::Result>>;

    async fn execute(&self) -> Result<Self::Result> {
        // Limit the maximum concurrency of multi-region request. If there are
        // too many concurrent requests, TiKV is more likely to return a "TiKV
        // is busy" error
        let concurrency_permits = Arc::new(Semaphore::new(MULTI_REGION_CONCURRENCY));
        Self::single_plan_handler(
            self.pd_client.clone(),
            self.inner.clone(),
            self.backoff.clone(),
            concurrency_permits.clone(),
            self.preserve_region_results,
        )
        .await
    }
}

pub struct RetryableAllStores<P: Plan, PdC: PdClient> {
    pub(super) inner: P,
    pub pd_client: Arc<PdC>,
    pub backoff: Backoff,
}

impl<P: Plan, PdC: PdClient> Clone for RetryableAllStores<P, PdC> {
    fn clone(&self) -> Self {
        RetryableAllStores {
            inner: self.inner.clone(),
            pd_client: self.pd_client.clone(),
            backoff: self.backoff.clone(),
        }
    }
}

// About `HasRegionError`:
// Store requests should not return region errors.
// But since the response of the only store request so far (UnsafeDestroyRangeResponse) has the `region_error` field,
// we require `HasRegionError` so we can check whether TiKV returned a region error.
#[async_trait]
impl<P: Plan + StoreRequest, PdC: PdClient> Plan for RetryableAllStores<P, PdC>
where
    P::Result: HasKeyErrors + HasRegionError,
{
    type Result = Vec<Result<P::Result>>;

    async fn execute(&self) -> Result<Self::Result> {
        let concurrency_permits = Arc::new(Semaphore::new(MULTI_STORES_CONCURRENCY));
        let stores = self.pd_client.clone().all_stores().await?;
        let mut handles = Vec::with_capacity(stores.len());
        for store in stores {
            let mut clone = self.inner.clone();
            clone.apply_store(&store);
            let handle = tokio::spawn(Self::single_store_handler(
                clone,
                self.backoff.clone(),
                concurrency_permits.clone(),
            ));
            handles.push(handle);
        }
        let results = try_join_all(handles).await?;
        Ok(results.into_iter().collect::<Vec<_>>())
    }
}

impl<P: Plan, PdC: PdClient> RetryableAllStores<P, PdC>
where
    P::Result: HasKeyErrors + HasRegionError,
{
    async fn single_store_handler(
        plan: P,
        mut backoff: Backoff,
        permits: Arc<Semaphore>,
    ) -> Result<P::Result> {
        loop {
            let permit = permits.acquire().await.unwrap();
            let res = plan.execute().await;
            drop(permit);

            match res {
                Ok(mut resp) => {
                    if let Some(e) = resp.key_errors() {
                        return Err(Error::MultipleKeyErrors(e));
                    } else if let Some(e) = resp.region_error() {
                        // Store request should not return region error.
                        return Err(Error::RegionError(Box::new(e)));
                    } else {
                        return Ok(resp);
                    }
                }
                Err(e) if is_grpc_error(&e) => match backoff.next_delay_duration() {
                    Some(duration) => {
                        sleep(duration).await;
                        continue;
                    }
                    None => return Err(e),
                },
                Err(e) => return Err(e),
            }
        }
    }
}

/// A technique for merging responses into a single result (with type `Out`).
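///
/// A sketch of a custom merge strategy that keeps only the successful responses
/// and silently drops errors (the `CollectOk` name is illustrative; this crate
/// does not ship such a strategy):
///
/// ```ignore
/// #[derive(Clone, Copy)]
/// struct CollectOk;
///
/// impl<T: Send> Merge<T> for CollectOk {
///     type Out = Vec<T>;
///
///     fn merge(&self, input: Vec<Result<T>>) -> Result<Self::Out> {
///         Ok(input.into_iter().filter_map(|r| r.ok()).collect())
///     }
/// }
/// ```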
pub trait Merge<In>: Sized + Clone + Send + Sync + 'static {
    type Out: Send;

    fn merge(&self, input: Vec<Result<In>>) -> Result<Self::Out>;
}

#[derive(Clone)]
pub struct MergeResponse<P: Plan, In, M: Merge<In>> {
    pub inner: P,
    pub merge: M,
    pub phantom: PhantomData<In>,
}

#[async_trait]
impl<In: Clone + Send + Sync + 'static, P: Plan<Result = Vec<Result<In>>>, M: Merge<In>> Plan
    for MergeResponse<P, In, M>
{
    type Result = M::Out;

    async fn execute(&self) -> Result<Self::Result> {
        self.merge.merge(self.inner.execute().await?)
    }
}

/// A merge strategy which collects data from a response into a single type.
#[derive(Clone, Copy)]
pub struct Collect;

/// A merge strategy that only takes the first element. It's used for requests
/// that should have exactly one response, e.g. a get request.
#[derive(Clone, Copy)]
pub struct CollectSingle;

#[doc(hidden)]
#[macro_export]
macro_rules! collect_first {
    ($type_: ty) => {
        impl Merge<$type_> for CollectSingle {
            type Out = $type_;

            fn merge(&self, mut input: Vec<Result<$type_>>) -> Result<Self::Out> {
                assert!(input.len() == 1);
                input.pop().unwrap()
            }
        }
    };
}
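
// Example invocation (a sketch; actual invocations live next to the request
// definitions, e.g. for `kvrpcpb::GetResponse`):
//
//     collect_first!(kvrpcpb::GetResponse);
//
// which expands to an `impl Merge<kvrpcpb::GetResponse> for CollectSingle` that
// returns the single response in the input vector.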

/// A merge strategy to be used with
/// [`preserve_shard`](super::plan_builder::PlanBuilder::preserve_shard).
/// It matches the shards preserved before and the values returned in the response.
#[derive(Clone, Debug)]
pub struct CollectWithShard;

/// A merge strategy which returns an error if any response is an error and
/// otherwise returns a Vec of the results.
#[derive(Clone, Copy)]
pub struct CollectError;

impl<T: Send> Merge<T> for CollectError {
    type Out = Vec<T>;

    fn merge(&self, input: Vec<Result<T>>) -> Result<Self::Out> {
        input.into_iter().collect()
    }
}

/// Process data into another kind of data.
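///
/// A minimal sketch of a processor (`BytesLen` is illustrative, not crate code)
/// that turns a byte buffer into its length:
///
/// ```ignore
/// #[derive(Clone, Copy)]
/// struct BytesLen;
///
/// impl Process<Vec<u8>> for BytesLen {
///     type Out = usize;
///
///     fn process(&self, input: Result<Vec<u8>>) -> Result<Self::Out> {
///         input.map(|bytes| bytes.len())
///     }
/// }
/// ```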
pub trait Process<In>: Sized + Clone + Send + Sync + 'static {
    type Out: Send;

    fn process(&self, input: Result<In>) -> Result<Self::Out>;
}

#[derive(Clone)]
pub struct ProcessResponse<P: Plan, Pr: Process<P::Result>> {
    pub inner: P,
    pub processor: Pr,
}

#[async_trait]
impl<P: Plan, Pr: Process<P::Result>> Plan for ProcessResponse<P, Pr> {
    type Result = Pr::Out;

    async fn execute(&self) -> Result<Self::Result> {
        self.processor.process(self.inner.execute().await)
    }
}

#[derive(Clone, Copy, Debug)]
pub struct DefaultProcessor;

pub struct ResolveLock<P: Plan, PdC: PdClient> {
    pub inner: P,
    pub pd_client: Arc<PdC>,
    pub backoff: Backoff,
}

impl<P: Plan, PdC: PdClient> Clone for ResolveLock<P, PdC> {
    fn clone(&self) -> Self {
        ResolveLock {
            inner: self.inner.clone(),
            pd_client: self.pd_client.clone(),
            backoff: self.backoff.clone(),
        }
    }
}

#[async_trait]
impl<P: Plan, PdC: PdClient> Plan for ResolveLock<P, PdC>
where
    P::Result: HasLocks,
{
    type Result = P::Result;

    async fn execute(&self) -> Result<Self::Result> {
        let mut result = self.inner.execute().await?;
        let mut clone = self.clone();
        loop {
            let locks = result.take_locks();
            if locks.is_empty() {
                return Ok(result);
            }

            if self.backoff.is_none() {
                return Err(Error::ResolveLockError(locks));
            }

            let pd_client = self.pd_client.clone();
            let live_locks = resolve_locks(locks, pd_client.clone()).await?;
            if live_locks.is_empty() {
                result = self.inner.execute().await?;
            } else {
                match clone.backoff.next_delay_duration() {
                    None => return Err(Error::ResolveLockError(live_locks)),
                    Some(delay_duration) => {
                        sleep(delay_duration).await;
                        result = clone.inner.execute().await?;
                    }
                }
            }
        }
    }
}

#[derive(Default)]
pub struct CleanupLocksResult {
    pub region_error: Option<errorpb::Error>,
    pub key_error: Option<Vec<Error>>,
    pub resolved_locks: usize,
}

impl Clone for CleanupLocksResult {
    fn clone(&self) -> Self {
        Self {
            resolved_locks: self.resolved_locks,
            ..Default::default() // Ignore errors, which should be extracted by `extract_error()`.
        }
    }
}

impl HasRegionError for CleanupLocksResult {
    fn region_error(&mut self) -> Option<errorpb::Error> {
        self.region_error.take()
    }
}

impl HasKeyErrors for CleanupLocksResult {
    fn key_errors(&mut self) -> Option<Vec<Error>> {
        self.key_error.take()
    }
}

impl Merge<CleanupLocksResult> for Collect {
    type Out = CleanupLocksResult;

    fn merge(&self, input: Vec<Result<CleanupLocksResult>>) -> Result<Self::Out> {
        input
            .into_iter()
            .try_fold(CleanupLocksResult::default(), |acc, x| {
                Ok(CleanupLocksResult {
                    resolved_locks: acc.resolved_locks + x?.resolved_locks,
                    ..Default::default()
                })
            })
    }
}

pub struct CleanupLocks<P: Plan, PdC: PdClient> {
    pub inner: P,
    pub ctx: ResolveLocksContext,
    pub options: ResolveLocksOptions,
    pub store: Option<RegionStore>,
    pub pd_client: Arc<PdC>,
    pub backoff: Backoff,
}

impl<P: Plan, PdC: PdClient> Clone for CleanupLocks<P, PdC> {
    fn clone(&self) -> Self {
        CleanupLocks {
            inner: self.inner.clone(),
            ctx: self.ctx.clone(),
            options: self.options,
            store: None,
            pd_client: self.pd_client.clone(),
            backoff: self.backoff.clone(),
        }
    }
}

#[async_trait]
impl<P: Plan + Shardable + NextBatch, PdC: PdClient> Plan for CleanupLocks<P, PdC>
where
    P::Result: HasLocks + HasNextBatch + HasKeyErrors + HasRegionError,
{
    type Result = CleanupLocksResult;

    async fn execute(&self) -> Result<Self::Result> {
        let mut result = CleanupLocksResult::default();
        let mut inner = self.inner.clone();
        let mut lock_resolver = crate::transaction::LockResolver::new(self.ctx.clone());
        let region = &self.store.as_ref().unwrap().region_with_leader;
        let mut has_more_batch = true;

        while has_more_batch {
            let mut scan_lock_resp = inner.execute().await?;

            // Propagate errors to `retry_multi_region` for retry.
            if let Some(e) = scan_lock_resp.key_errors() {
                info!("CleanupLocks::execute, inner key errors:{:?}", e);
                result.key_error = Some(e);
                return Ok(result);
            } else if let Some(e) = scan_lock_resp.region_error() {
                info!("CleanupLocks::execute, inner region error:{}", e.message);
                result.region_error = Some(e);
                return Ok(result);
            }

            // Iterate to next batch of inner.
            match scan_lock_resp.has_next_batch() {
                Some(range) if region.contains(range.0.as_ref()) => {
                    debug!("CleanupLocks::execute, next range:{:?}", range);
                    inner.next_batch(range);
                }
                _ => has_more_batch = false,
            }

            let mut locks = scan_lock_resp.take_locks();
            if locks.is_empty() {
                break;
            }
            if locks.len() < self.options.batch_size as usize {
                has_more_batch = false;
            }

            if self.options.async_commit_only {
                locks = locks
                    .into_iter()
                    .filter(|l| l.use_async_commit)
                    .collect::<Vec<_>>();
            }
            debug!("CleanupLocks::execute, meet locks:{}", locks.len());

            let lock_size = locks.len();
            match lock_resolver
                .cleanup_locks(self.store.clone().unwrap(), locks, self.pd_client.clone())
                .await
            {
                Ok(()) => {
                    result.resolved_locks += lock_size;
                }
                Err(Error::ExtractedErrors(mut errors)) => {
                    // Propagate errors to `retry_multi_region` for retry.
                    if let Error::RegionError(e) = errors.pop().unwrap() {
                        result.region_error = Some(*e);
                    } else {
                        result.key_error = Some(errors);
                    }
                    return Ok(result);
                }
                Err(e) => {
                    return Err(e);
                }
            }

            // TODO: improve backoff
            // if self.backoff.is_none() {
            //     return Err(Error::ResolveLockError);
            // }
        }

        Ok(result)
    }
}

/// When executed, the plan extracts errors from its inner plan, and returns an
/// `Err` wrapping the error.
///
/// We usually need to apply this plan if (and only if) the output of the inner
/// plan is of a response type.
///
/// The errors come from two places: `Err` from inner plans, and `Ok(response)`
/// where `response` contains unresolved errors (`error` and `region_error`).
pub struct ExtractError<P: Plan> {
    pub inner: P,
}

impl<P: Plan> Clone for ExtractError<P> {
    fn clone(&self) -> Self {
        ExtractError {
            inner: self.inner.clone(),
        }
    }
}

#[async_trait]
impl<P: Plan> Plan for ExtractError<P>
where
    P::Result: HasKeyErrors + HasRegionErrors,
{
    type Result = P::Result;

    async fn execute(&self) -> Result<Self::Result> {
        let mut result = self.inner.execute().await?;
        if let Some(errors) = result.key_errors() {
            Err(Error::ExtractedErrors(errors))
        } else if let Some(errors) = result.region_errors() {
            Err(Error::ExtractedErrors(
                errors
                    .into_iter()
                    .map(|e| Error::RegionError(Box::new(e)))
                    .collect(),
            ))
        } else {
            Ok(result)
        }
    }
}

/// When executed, the plan clones the shard, executes its inner plan, and then
/// returns `(response, shard)`.
///
/// It's useful when the shard information is lost in the response but is needed
/// for processing.
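///
/// A sketch of what a downstream consumer sees (the variable names are
/// illustrative only):
///
/// ```ignore
/// let ResponseWithShard(response, shard) = preserve_shard_plan.execute().await?;
/// // `shard` holds the request data that was sent; `response` is what TiKV returned.
/// ```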
pub struct PreserveShard<P: Plan + Shardable> {
    pub inner: P,
    pub shard: Option<P::Shard>,
}

impl<P: Plan + Shardable> Clone for PreserveShard<P> {
    fn clone(&self) -> Self {
        PreserveShard {
            inner: self.inner.clone(),
            shard: None,
        }
    }
}

#[async_trait]
impl<P> Plan for PreserveShard<P>
where
    P: Plan + Shardable,
{
    type Result = ResponseWithShard<P::Result, P::Shard>;

    async fn execute(&self) -> Result<Self::Result> {
        let res = self.inner.execute().await?;
        let shard = self
            .shard
            .as_ref()
            .expect("Unreachable: Shardable::apply_shard() is not called before executing PreserveShard")
            .clone();
        Ok(ResponseWithShard(res, shard))
    }
}

// contains a response and the corresponding shard
#[derive(Debug, Clone)]
pub struct ResponseWithShard<Resp, Shard>(pub Resp, pub Shard);

impl<Resp: HasKeyErrors, Shard> HasKeyErrors for ResponseWithShard<Resp, Shard> {
    fn key_errors(&mut self) -> Option<Vec<Error>> {
        self.0.key_errors()
    }
}

impl<Resp: HasLocks, Shard> HasLocks for ResponseWithShard<Resp, Shard> {
    fn take_locks(&mut self) -> Vec<kvrpcpb::LockInfo> {
        self.0.take_locks()
    }
}

impl<Resp: HasRegionError, Shard> HasRegionError for ResponseWithShard<Resp, Shard> {
    fn region_error(&mut self) -> Option<errorpb::Error> {
        self.0.region_error()
    }
}

#[cfg(test)]
mod test {
    use futures::stream::BoxStream;
    use futures::stream::{self};

    use super::*;
    use crate::mock::MockPdClient;
    use crate::proto::kvrpcpb::BatchGetResponse;

    #[derive(Clone)]
    struct ErrPlan;

    #[async_trait]
    impl Plan for ErrPlan {
        type Result = BatchGetResponse;

        async fn execute(&self) -> Result<Self::Result> {
            Err(Error::Unimplemented)
        }
    }

    impl Shardable for ErrPlan {
        type Shard = ();

        fn shards(
            &self,
            _: &Arc<impl crate::pd::PdClient>,
        ) -> BoxStream<'static, crate::Result<(Self::Shard, crate::store::RegionStore)>> {
            Box::pin(stream::iter(1..=3).map(|_| Err(Error::Unimplemented))).boxed()
        }

        fn apply_shard(&mut self, _: Self::Shard, _: &crate::store::RegionStore) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_err() {
        let plan = RetryableMultiRegion {
            inner: ResolveLock {
                inner: ErrPlan,
                backoff: Backoff::no_backoff(),
                pd_client: Arc::new(MockPdClient::default()),
            },
            pd_client: Arc::new(MockPdClient::default()),
            backoff: Backoff::no_backoff(),
            preserve_region_results: false,
        };
        assert!(plan.execute().await.is_err())
    }
}