wash_lib/
spier.rs

1use std::task::Poll;
2
3use anyhow::Result;
4use chrono::{DateTime, Local};
5use futures::{Stream, StreamExt};
6use tracing::debug;
7
8/// A struct that represents an invocation that was observed by the spier.
9#[derive(Debug)]
10pub struct ObservedInvocation {
11    /// The timestamp when this was received
12    pub timestamp: DateTime<Local>,
13    /// The name or id of the entity that sent this invocation
14    pub from: String,
15    /// The name or id of the entity that received this invocation
16    pub to: String,
17    /// The operation that was invoked
18    pub operation: String,
19    /// The inner message that was received. We will attempt to parse the inner message from CBOR
20    /// and JSON into a JSON string and fall back to the raw bytes if we are unable to do so
21    pub message: ObservedMessage,
22}
23
/// An inner message that we've seen in an invocation message. This is either the raw bytes or a
/// parsed value if it was a format we recognized.
///
/// Please note that this type is meant for debugging, so its `Display` implementation does some
/// heavier lifting like constructing strings from the raw bytes.
#[derive(Debug)]
pub enum ObservedMessage {
    /// Payload bytes that could not be decoded as UTF-8 text, kept verbatim
    Raw(Vec<u8>),
    /// Payload that was successfully decoded as UTF-8 text
    Parsed(String),
}

impl std::fmt::Display for ObservedMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Invalid UTF-8 sequences are rendered as U+FFFD so the payload is still printable
            ObservedMessage::Raw(bytes) => write!(f, "{}", String::from_utf8_lossy(bytes)),
            ObservedMessage::Parsed(v) => write!(f, "{v}"),
        }
    }
}

impl ObservedMessage {
    /// Interprets an invocation payload for display.
    ///
    /// Valid UTF-8 becomes [`ObservedMessage::Parsed`], reusing the buffer without copying.
    /// Anything else is kept verbatim as [`ObservedMessage::Raw`] so no bytes are lost. In both
    /// cases the `Display` output equals a lossy UTF-8 conversion of `data`.
    #[must_use]
    pub fn parse(data: Vec<u8>) -> Self {
        match String::from_utf8(data) {
            Ok(text) => Self::Parsed(text),
            // `into_bytes` hands back the original, untouched buffer
            Err(err) => Self::Raw(err.into_bytes()),
        }
    }
}
52
53/// A struct that can spy on the RPC messages sent to and from an component, consumable as a stream
54pub struct Spier {
55    stream: futures::stream::SelectAll<async_nats::Subscriber>,
56    component_id: String,
57    friendly_name: Option<String>,
58}
59
60impl Spier {
61    /// Creates a new Spier instance for the given component. Will return an error if the component cannot
62    /// be found or if there are connection issues
63    pub async fn new(
64        component_id: &str,
65        ctl_client: &wasmcloud_control_interface::Client,
66        nats_client: &async_nats::Client,
67    ) -> Result<Self> {
68        let linked_component = get_linked_components(component_id, ctl_client).await?;
69
70        let lattice = ctl_client.lattice();
71        let rpc_topic = format!("{lattice}.{component_id}.wrpc.>");
72        let component_stream = nats_client.subscribe(rpc_topic).await?;
73
74        let mut subs = futures::future::join_all(linked_component.iter().map(|prov| {
75            let topic = format!("{lattice}.{}.wrpc.>", &prov.id);
76            nats_client.subscribe(topic)
77        }))
78        .await
79        .into_iter()
80        .collect::<Result<Vec<_>, _>>()?;
81        subs.push(component_stream);
82
83        let stream = futures::stream::select_all(subs);
84
85        Ok(Self {
86            stream,
87            component_id: component_id.to_string(),
88            friendly_name: None,
89        })
90    }
91
92    /// Returns the component name, or id if no name is set, that this spier is spying on
93    pub fn component_id(&self) -> &str {
94        self.friendly_name
95            .as_deref()
96            .unwrap_or_else(|| self.component_id.as_ref())
97    }
98}
99
100impl Stream for Spier {
101    type Item = ObservedInvocation;
102    fn poll_next(
103        mut self: std::pin::Pin<&mut Self>,
104        cx: &mut std::task::Context<'_>,
105    ) -> std::task::Poll<Option<Self::Item>> {
106        match self.stream.poll_next_unpin(cx) {
107            Poll::Ready(None) => Poll::Ready(None),
108            Poll::Ready(Some(msg)) => {
109                // <lattice>.<component>.wrpc.0.0.1.<operation>@<versionX.Y.Z>.<function>
110                let mut subject_parts = msg.subject.split('.');
111                subject_parts.next(); // Skip the lattice
112                let component_id = subject_parts.next();
113                // Skip "wrpc.0.0.1", collect the rest
114                let operation = subject_parts.skip(4).collect::<Vec<_>>();
115
116                // The length assertion is to ensure that at least the `operation.function` is present since the
117                // version is technically optional.
118                if component_id.is_none() || operation.len() < 2 {
119                    debug!("Received invocation with invalid subject: {}", msg.subject);
120                    cx.waker().wake_by_ref();
121                    return Poll::Pending;
122                }
123                let component_id = component_id.unwrap();
124
125                let (from, to) = if component_id == self.component_id {
126                    // Attempt to get the source from the message header
127                    let from = msg
128                        .headers
129                        .and_then(|headers| headers.get("source-id").map(ToString::to_string))
130                        .unwrap_or_else(|| "linked component".to_string());
131                    (from, (*component_id).to_string())
132                } else {
133                    (self.component_id.to_string(), (*component_id).to_string())
134                };
135
136                // NOTE(thomastaylor312): Ideally we'd consume `msg.payload` above with a
137                // `Cursor` and `from_reader` and then manually reconstruct the acking using the
138                // message context, but I didn't want to waste time optimizing yet
139                Poll::Ready(Some(ObservedInvocation {
140                    timestamp: Local::now(),
141                    from,
142                    to,
143                    operation: operation.join("."),
144                    message: ObservedMessage::parse(msg.payload.to_vec()),
145                }))
146            }
147            Poll::Pending => Poll::Pending,
148        }
149    }
150}
151
/// Identifies one component that is linked to the component being spied on.
///
/// Despite the name, the entry is not necessarily a provider. It is whichever side of a link sits
/// opposite the spied-on component.
#[derive(Debug)]
struct ProviderDetails {
    /// Id of the linked component
    id: String,
}
156
157/// Fetches all components linked to the given component
158async fn get_linked_components(
159    component_id: &str,
160    ctl_client: &wasmcloud_control_interface::Client,
161) -> Result<Vec<ProviderDetails>> {
162    let details = ctl_client
163        .get_links()
164        .await
165        .map_err(|e| anyhow::anyhow!("Unable to get links: {e:?}"))
166        .map(|response| response.into_data())?
167        .map(|linkdefs| {
168            linkdefs
169                .into_iter()
170                .filter_map(|link| {
171                    if link.source_id() == component_id {
172                        Some(ProviderDetails {
173                            id: link.target().to_string(),
174                        })
175                    } else if link.target() == component_id {
176                        Some(ProviderDetails {
177                            id: link.source_id().to_string(),
178                        })
179                    } else {
180                        None
181                    }
182                })
183                .collect::<Vec<_>>()
184        })
185        .unwrap_or_default();
186
187    Ok(details)
188}