tc_scalar/reference/
mod.rs

//! Utilities to reference a `State` within a transaction, and to resolve the resulting [`TCRef`].
2
3use std::collections::HashSet;
4use std::convert::TryFrom;
5use std::fmt;
6use std::ops::Deref;
7
8use async_hash::{Digest, Hash, Output};
9use async_trait::async_trait;
10use destream::de::{self, Decoder, FromStream};
11use destream::en::{EncodeMap, Encoder, IntoStream, ToStream};
12use futures::TryFutureExt;
13use get_size::GetSize;
14use get_size_derive::*;
15use log::debug;
16use safecast::TryCastFrom;
17
18use tc_error::*;
19use tc_transact::public::{Public, StateInstance, ToState};
20use tcgeneric::*;
21
22use super::{OpDef, Scalar, Scope, Value};
23
24pub use after::After;
25pub use case::Case;
26pub use id::*;
27pub use op::*;
28pub use r#if::IfRef;
29pub use r#while::While;
30pub use with::With;
31
32mod after;
33mod case;
34mod r#if;
35mod r#while;
36mod with;
37
38pub mod id;
39pub mod op;
40
/// Path prefix (segments `state`, `scalar`, `ref`) under which every non-`Op` [`RefType`]
/// is addressed; `Op` ref types supply their own paths.
const PREFIX: PathLabel = path_label(&["state", "scalar", "ref"]);
42
/// Trait defining dependencies and a resolution method for a [`TCRef`].
///
/// `State` is the type of value that [`Refer::resolve`] produces.
#[async_trait]
pub trait Refer<State: StateInstance> {
    /// Replace references to "$self" with the given relative path.
    ///
    /// This is used to control whether or not an OpDef will be replicated.
    /// Consumes `self` and returns the rewritten value.
    fn dereference_self(self, path: &TCPathBuf) -> Self;

    /// Return `true` if this is a conditional reference (e.g. `If` or `Case`).
    fn is_conditional(&self) -> bool;

    /// Return `true` if this references a write operation to a cluster other than the path given.
    ///
    /// Writes addressed to `cluster_path` itself do not count as inter-service writes.
    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool;

    /// Return `true` if this state is a resolvable reference.
    fn is_ref(&self) -> bool;

    /// Replace the given relative path with "$self".
    ///
    /// This is used to control whether or not an OpDef will be replicated.
    /// It is the counterpart of [`Refer::dereference_self`].
    fn reference_self(self, path: &TCPathBuf) -> Self;

    /// Add the dependency [`Id`]s of this reference to the given set.
    ///
    /// NOTE(review): callers presumably use `deps` to decide which scope entries must be
    /// resolved before this reference — confirm against the callers.
    fn requires(&self, deps: &mut HashSet<Id>);

    /// Resolve this reference with respect to the given context.
    ///
    /// `context` is the [`Scope`] the reference is resolved against, and `txn` is the
    /// transaction it is resolved within. On success, returns the resolved `State`.
    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
        self,
        context: &'a Scope<'a, State, T>,
        txn: &'a State::Txn,
    ) -> TCResult<State>;
}
75
76/// The [`Class`] of a [`TCRef`].
77#[derive(Clone, Copy, Eq, PartialEq)]
78pub enum RefType {
79    After,
80    Case,
81    Id,
82    If,
83    Op(OpRefType),
84    While,
85    With,
86}
87
88impl Class for RefType {}
89
90impl NativeClass for RefType {
91    fn from_path(path: &[PathSegment]) -> Option<Self> {
92        if path.len() == 4 && &path[0..3] == &PREFIX[..] {
93            match path[3].as_str() {
94                "after" => Some(Self::After),
95                "case" => Some(Self::Case),
96                "id" => Some(Self::Id),
97                "if" => Some(Self::If),
98                "while" => Some(Self::While),
99                "with" => Some(Self::With),
100                _ => None,
101            }
102        } else if let Some(ort) = OpRefType::from_path(path) {
103            Some(RefType::Op(ort))
104        } else {
105            None
106        }
107    }
108
109    fn path(&self) -> TCPathBuf {
110        let suffix = match self {
111            Self::After => "after",
112            Self::Case => "case",
113            Self::Id => "id",
114            Self::If => "if",
115            Self::Op(ort) => return ort.path(),
116            Self::While => "while",
117            Self::With => "with",
118        };
119
120        TCPathBuf::from(PREFIX).append(label(suffix))
121    }
122}
123
124impl fmt::Debug for RefType {
125    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126        match self {
127            Self::After => f.write_str("After"),
128            Self::Case => f.write_str("Case"),
129            Self::Id => f.write_str("Id"),
130            Self::If => f.write_str("If"),
131            Self::Op(ort) => fmt::Debug::fmt(ort, f),
132            Self::While => f.write_str("While"),
133            Self::With => f.write_str("With"),
134        }
135    }
136}
137
138/// A reference to a `State`.
139#[derive(Clone, Eq, PartialEq, GetSize)]
140pub enum TCRef {
141    After(Box<After>),
142    Case(Box<Case>),
143    Id(IdRef),
144    If(Box<IfRef>),
145    Op(OpRef),
146    While(Box<While>),
147    With(Box<With>),
148}
149
150impl Instance for TCRef {
151    type Class = RefType;
152
153    fn class(&self) -> Self::Class {
154        match self {
155            Self::After(_) => RefType::After,
156            Self::Case(_) => RefType::Case,
157            Self::Id(_) => RefType::Id,
158            Self::If(_) => RefType::If,
159            Self::Op(op_ref) => RefType::Op(op_ref.class()),
160            Self::While(_) => RefType::While,
161            Self::With(_) => RefType::With,
162        }
163    }
164}
165
166#[async_trait]
167impl<State> Refer<State> for TCRef
168where
169    State: StateInstance + Refer<State> + From<Scalar>,
170    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
171    Map<State>: TryFrom<State, Error = TCError>,
172    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
173    bool: TryCastFrom<State>,
174{
175    fn dereference_self(self, path: &TCPathBuf) -> Self {
176        match self {
177            Self::After(after) => {
178                let after = after.dereference_self(path);
179                Self::After(Box::new(after))
180            }
181            Self::Case(case) => {
182                let case = case.dereference_self(path);
183                Self::Case(Box::new(case))
184            }
185            Self::Id(id_ref) => Self::Id(Refer::<State>::dereference_self(id_ref, path)),
186            Self::If(if_ref) => {
187                let if_ref = if_ref.dereference_self(path);
188                Self::If(Box::new(if_ref))
189            }
190            Self::Op(op_ref) => Self::Op(op_ref.dereference_self(path)),
191            Self::While(while_ref) => {
192                let while_ref = while_ref.dereference_self(path);
193                Self::While(Box::new(while_ref))
194            }
195            Self::With(with) => {
196                let with = with.dereference_self(path);
197                Self::With(Box::new(with))
198            }
199        }
200    }
201
202    fn is_conditional(&self) -> bool {
203        match self {
204            Self::After(after) => after.is_conditional(),
205            Self::Case(case) => case.is_conditional(),
206            Self::Id(id_ref) => Refer::<State>::is_conditional(id_ref),
207            Self::If(if_ref) => if_ref.is_conditional(),
208            Self::Op(op_ref) => op_ref.is_conditional(),
209            Self::While(while_ref) => while_ref.is_conditional(),
210            Self::With(with) => with.is_conditional(),
211        }
212    }
213
214    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
215        match self {
216            Self::After(after) => after.is_inter_service_write(cluster_path),
217            Self::Case(case) => case.is_inter_service_write(cluster_path),
218            Self::Id(id_ref) => Refer::<State>::is_inter_service_write(id_ref, cluster_path),
219            Self::If(if_ref) => if_ref.is_inter_service_write(cluster_path),
220            Self::Op(op_ref) => op_ref.is_inter_service_write(cluster_path),
221            Self::While(while_ref) => while_ref.is_inter_service_write(cluster_path),
222            Self::With(with) => with.is_inter_service_write(cluster_path),
223        }
224    }
225
226    fn is_ref(&self) -> bool {
227        true
228    }
229
230    fn reference_self(self, path: &TCPathBuf) -> Self {
231        match self {
232            Self::After(after) => {
233                let after = after.reference_self(path);
234                Self::After(Box::new(after))
235            }
236            Self::Case(case) => {
237                let case = case.reference_self(path);
238                Self::Case(Box::new(case))
239            }
240            Self::Id(id_ref) => Self::Id(Refer::<State>::reference_self(id_ref, path)),
241            Self::If(if_ref) => {
242                let if_ref = if_ref.reference_self(path);
243                Self::If(Box::new(if_ref))
244            }
245            Self::Op(op_ref) => Self::Op(op_ref.reference_self(path)),
246            Self::While(while_ref) => {
247                let while_ref = while_ref.reference_self(path);
248                Self::While(Box::new(while_ref))
249            }
250            Self::With(with) => {
251                let with = with.reference_self(path);
252                Self::With(Box::new(with))
253            }
254        }
255    }
256
257    fn requires(&self, deps: &mut HashSet<Id>) {
258        match self {
259            Self::After(after) => after.requires(deps),
260            Self::Case(case) => case.requires(deps),
261            Self::Id(id_ref) => Refer::<State>::requires(id_ref, deps),
262            Self::If(if_ref) => if_ref.requires(deps),
263            Self::Op(op_ref) => op_ref.requires(deps),
264            Self::While(while_ref) => while_ref.requires(deps),
265            Self::With(with) => with.requires(deps),
266        }
267    }
268
269    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
270        self,
271        context: &'a Scope<'a, State, T>,
272        txn: &'a State::Txn,
273    ) -> TCResult<State> {
274        debug!("TCRef::resolve {:?}", self);
275
276        match self {
277            Self::After(after) => after.resolve(context, txn).await,
278            Self::Case(case) => case.resolve(context, txn).await,
279            Self::Id(id_ref) => Refer::<State>::resolve(id_ref, context, txn).await,
280            Self::If(if_ref) => if_ref.resolve(context, txn).await,
281            Self::Op(op_ref) => op_ref.resolve(context, txn).await,
282            Self::While(while_ref) => while_ref.resolve(context, txn).await,
283            Self::With(with) => with.resolve(context, txn).await,
284        }
285    }
286}
287
288impl<'a, D: Digest> Hash<D> for &'a TCRef {
289    fn hash(self) -> Output<D> {
290        match self {
291            TCRef::After(after) => Hash::<D>::hash(after.deref()),
292            TCRef::Case(case) => Hash::<D>::hash(case.deref()),
293            TCRef::Id(id) => Hash::<D>::hash(id),
294            TCRef::If(if_ref) => Hash::<D>::hash(if_ref.deref()),
295            TCRef::Op(op) => Hash::<D>::hash(op),
296            TCRef::While(while_ref) => Hash::<D>::hash(while_ref.deref()),
297            TCRef::With(with) => Hash::<D>::hash(with.deref()),
298        }
299    }
300}
301
302impl TryFrom<TCRef> for OpRef {
303    type Error = TCError;
304
305    fn try_from(tc_ref: TCRef) -> TCResult<Self> {
306        match tc_ref {
307            TCRef::Op(op_ref) => Ok(op_ref),
308            other => Err(TCError::unexpected(other, "an OpRef")),
309        }
310    }
311}
312
313impl TryCastFrom<TCRef> for Id {
314    fn can_cast_from(tc_ref: &TCRef) -> bool {
315        match tc_ref {
316            TCRef::Id(_) => true,
317            _ => false,
318        }
319    }
320
321    fn opt_cast_from(tc_ref: TCRef) -> Option<Self> {
322        match tc_ref {
323            TCRef::Id(id_ref) => Some(id_ref.into_id()),
324            _ => None,
325        }
326    }
327}
328
329impl TryCastFrom<TCRef> for OpRef {
330    fn can_cast_from(tc_ref: &TCRef) -> bool {
331        match tc_ref {
332            TCRef::Op(_) => true,
333            _ => false,
334        }
335    }
336
337    fn opt_cast_from(tc_ref: TCRef) -> Option<Self> {
338        match tc_ref {
339            TCRef::Op(op) => Some(op),
340            _ => None,
341        }
342    }
343}
344
345/// A helper struct used to deserialize a [`TCRef`]
346pub struct RefVisitor;
347
348impl RefVisitor {
349    /// Deserialize a map value, assuming it's an instance of the given [`RefType`].
350    pub async fn visit_map_value<A: de::MapAccess>(
351        class: RefType,
352        access: &mut A,
353    ) -> Result<TCRef, A::Error> {
354        match class {
355            RefType::After => {
356                access
357                    .next_value(())
358                    .map_ok(Box::new)
359                    .map_ok(TCRef::After)
360                    .await
361            }
362            RefType::Case => {
363                access
364                    .next_value(())
365                    .map_ok(Box::new)
366                    .map_ok(TCRef::Case)
367                    .await
368            }
369            RefType::Id => access.next_value(()).map_ok(TCRef::Id).await,
370            RefType::If => {
371                access
372                    .next_value(())
373                    .map_ok(Box::new)
374                    .map_ok(TCRef::If)
375                    .await
376            }
377            RefType::Op(ort) => {
378                OpRefVisitor::visit_map_value(ort, access)
379                    .map_ok(TCRef::Op)
380                    .await
381            }
382            RefType::While => {
383                access
384                    .next_value(())
385                    .map_ok(Box::new)
386                    .map_ok(TCRef::While)
387                    .await
388            }
389            RefType::With => {
390                access
391                    .next_value(())
392                    .map_ok(Box::new)
393                    .map_ok(TCRef::With)
394                    .await
395            }
396        }
397    }
398
399    /// Deserialize a [`TCRef`] with the given `subject`.
400    pub fn visit_ref_value<E: de::Error>(subject: Subject, params: Scalar) -> Result<TCRef, E> {
401        if params.is_none() {
402            match subject {
403                Subject::Link(link) => Err(de::Error::invalid_type(link, &"a Ref")),
404                Subject::Ref(id_ref, path) if path.is_empty() => Ok(TCRef::Id(id_ref)),
405                Subject::Ref(id_ref, path) => Ok(TCRef::Op(OpRef::Get((
406                    Subject::Ref(id_ref, path),
407                    Value::default().into(),
408                )))),
409            }
410        } else {
411            OpRefVisitor::visit_ref_value(subject, params).map(TCRef::Op)
412        }
413    }
414}
415
416#[async_trait]
417impl de::Visitor for RefVisitor {
418    type Value = TCRef;
419
420    fn expecting() -> &'static str {
421        "a Ref, like {\"$subject\": []} or {\"/path/to/op\": [\"key\"]"
422    }
423
424    async fn visit_map<A: de::MapAccess>(self, mut access: A) -> Result<Self::Value, A::Error> {
425        let subject = access.next_key::<Subject>(()).await?;
426
427        let subject =
428            subject.ok_or_else(|| de::Error::custom("expected a Ref or Link, found empty map"))?;
429
430        if let Subject::Link(link) = &subject {
431            if link.host().is_none() {
432                if let Some(class) = RefType::from_path(link.path()) {
433                    debug!("RefVisitor visiting instance of {:?}...", class);
434                    return Self::visit_map_value(class, &mut access).await;
435                }
436            }
437        }
438
439        let params = access.next_value(()).await?;
440        Self::visit_ref_value(subject, params)
441    }
442}
443
444#[async_trait]
445impl FromStream for TCRef {
446    type Context = ();
447
448    async fn from_stream<D: Decoder>(_: (), d: &mut D) -> Result<Self, <D as Decoder>::Error> {
449        d.decode_map(RefVisitor).await
450    }
451}
452
453impl<'en> ToStream<'en> for TCRef {
454    fn to_stream<E: Encoder<'en>>(&'en self, e: E) -> Result<E::Ok, E::Error> {
455        if let Self::Id(id_ref) = self {
456            return id_ref.to_stream(e);
457        } else if let Self::Op(op_ref) = self {
458            return op_ref.to_stream(e);
459        };
460
461        let mut map = e.encode_map(Some(1))?;
462
463        map.encode_key(self.class().path().to_string())?;
464        match self {
465            Self::Id(_) => unreachable!("TCRef::Id to_stream"),
466            Self::Op(_) => unreachable!("TCRef::Op to_stream"),
467
468            Self::After(after) => map.encode_value(after),
469            Self::Case(case) => map.encode_value(case),
470            Self::If(if_ref) => map.encode_value(if_ref),
471            Self::While(while_ref) => map.encode_value(while_ref),
472            Self::With(with) => map.encode_value(with),
473        }?;
474
475        map.end()
476    }
477}
478
479impl<'en> IntoStream<'en> for TCRef {
480    fn into_stream<E: Encoder<'en>>(self, e: E) -> Result<E::Ok, E::Error> {
481        if let Self::Id(id_ref) = self {
482            return id_ref.into_stream(e);
483        } else if let Self::Op(op_ref) = self {
484            return op_ref.into_stream(e);
485        };
486
487        let mut map = e.encode_map(Some(1))?;
488
489        map.encode_key(self.class().path().to_string())?;
490        match self {
491            Self::Id(_) => unreachable!("TCRef::Id into_stream"),
492            Self::Op(_) => unreachable!("TCRef::Op into_stream"),
493
494            Self::After(after) => map.encode_value(after),
495            Self::Case(case) => map.encode_value(case),
496            Self::If(if_ref) => map.encode_value(if_ref),
497            Self::While(while_ref) => map.encode_value(while_ref),
498            Self::With(with) => map.encode_value(with),
499        }?;
500
501        map.end()
502    }
503}
504
505impl fmt::Debug for TCRef {
506    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
507        match self {
508            Self::After(after) => fmt::Debug::fmt(after, f),
509            Self::Case(case) => fmt::Debug::fmt(case, f),
510            Self::Id(id_ref) => fmt::Debug::fmt(id_ref, f),
511            Self::If(if_ref) => fmt::Debug::fmt(if_ref, f),
512            Self::Op(op_ref) => fmt::Debug::fmt(op_ref, f),
513            Self::While(while_ref) => fmt::Debug::fmt(while_ref, f),
514            Self::With(with) => fmt::Debug::fmt(with, f),
515        }
516    }
517}