Skip to main content

trellis_core/
transaction.rs

1use crate::input::{StoredInput, boxed_input};
2use crate::transaction_trace_build::{scope_events, stable_node_union};
3use crate::{
4    Graph, GraphError, GraphResult, InputNode, NodeId, OutputKey, RebaselineReason, TransactionId,
5    transaction_types::{
6        AuditEntry, AuditEvent, StagedInputChange, StagedInputOutcome, TransactionOptions,
7        TransactionPhase, TransactionResult,
8    },
9};
10use std::collections::{BTreeMap, BTreeSet};
11
/// Staged canonical input transaction.
///
/// Holds an exclusive borrow of the live graph (`graph`) plus a private
/// working copy (`working`) cloned from it when the transaction opens.
/// Staged input values are applied to the working copy during `commit`,
/// which then swaps the working copy into the live graph. If the transaction
/// fails or is dropped without committing, the working copy is discarded.
pub struct Transaction<'graph, C = ()> {
    /// The live graph this transaction was opened on.
    pub(crate) graph: &'graph mut Graph<C>,
    /// Clone of `graph` taken at open time; `commit` recomputes on this copy
    /// and swaps it into `graph` on success.
    pub(crate) working: Graph<C>,
    /// Identifier stamped onto the commit result, audit entries and output frames.
    id: TransactionId,
    /// Commit options (`commit` reads `skip_equal_inputs` and `audit_explanations`).
    options: TransactionOptions,
    /// Input values staged by `set_input`/`set_input_by_id`, keyed by node;
    /// staging the same node again replaces the earlier value.
    staged_inputs: BTreeMap<NodeId, Box<dyn StoredInput>>,
    /// Audit events recorded while staging; `commit` copies them into the
    /// audit log and scans them for `NodeCreated` and `ScopeClosed`.
    pub(crate) staged_events: Vec<AuditEvent>,
    /// Collection nodes passed to `baseline_collection_diffs` during `commit`.
    pub(crate) staged_resource_planner_collections: Vec<NodeId>,
    /// Output rebaseline requests forwarded to output-frame production.
    pub(crate) staged_output_rebaselines: BTreeMap<OutputKey, RebaselineReason>,
    /// When true, `commit` advances the revision even if no input changed.
    /// NOTE(review): presumably set by structural staging operations defined
    /// elsewhere — the assignment is not visible in this file.
    pub(crate) graph_mutated: bool,
    /// First error recorded while staging; `commit` returns it instead of committing.
    pub(crate) failed: Option<GraphError>,
    /// Set once `commit` has run (successfully or not); after that,
    /// operations guarded by `ensure_open` fail with `TransactionClosed`.
    closed: bool,
}
26
27impl<'graph, C> Transaction<'graph, C> {
28    pub(crate) fn new(
29        graph: &'graph mut Graph<C>,
30        id: TransactionId,
31        options: TransactionOptions,
32    ) -> Self
33    where
34        C: Clone,
35    {
36        let mut working = graph.clone();
37        working.transaction_open = false;
38        Self {
39            graph,
40            working,
41            id,
42            options,
43            staged_inputs: BTreeMap::new(),
44            staged_events: Vec::new(),
45            staged_resource_planner_collections: Vec::new(),
46            staged_output_rebaselines: BTreeMap::new(),
47            graph_mutated: false,
48            failed: None,
49            closed: false,
50        }
51    }
52
53    /// Returns this transaction's id.
54    pub fn id(&self) -> TransactionId {
55        self.id
56    }
57
58    /// Stages a typed canonical input change.
59    pub fn set_input<T>(&mut self, input: InputNode<T>, value: T) -> GraphResult<()>
60    where
61        T: Clone + PartialEq + Send + Sync + 'static,
62    {
63        self.set_input_by_id(input.id(), value)
64    }
65
66    /// Stages a canonical input change by node id.
67    pub fn set_input_by_id<T>(&mut self, node: NodeId, value: T) -> GraphResult<()>
68    where
69        T: Clone + PartialEq + Send + Sync + 'static,
70    {
71        self.ensure_open()?;
72        if let Err(error) = self.working.validate_input_write::<T>(node) {
73            self.failed.get_or_insert_with(|| error.clone());
74            return Err(error);
75        }
76        self.staged_inputs.insert(node, boxed_input(value));
77        Ok(())
78    }
79
80    /// Commits staged input changes atomically.
81    pub fn commit(&mut self) -> GraphResult<TransactionResult<C>>
82    where
83        C: Clone + PartialEq,
84    {
85        self.ensure_open()?;
86        let mut phase_trace = vec![TransactionPhase::StageOperations];
87        phase_trace.push(TransactionPhase::ValidateTransaction);
88        if let Some(error) = self.failed.clone() {
89            self.close();
90            return Err(error);
91        }
92
93        let mut changed_inputs = Vec::new();
94        for (node, staged) in &self.staged_inputs {
95            let changed = self
96                .working
97                .input_values
98                .get(node)
99                .is_none_or(|current| !current.equals(staged.as_ref()));
100            if changed || !self.options.skip_equal_inputs {
101                changed_inputs.push(*node);
102            }
103        }
104        let changed_input_set = changed_inputs.iter().copied().collect::<BTreeSet<_>>();
105        let staged_input_changes = self
106            .staged_inputs
107            .keys()
108            .map(|node| StagedInputChange {
109                node: *node,
110                outcome: if changed_input_set.contains(node) {
111                    StagedInputOutcome::Changed
112                } else {
113                    StagedInputOutcome::Unchanged
114                },
115            })
116            .collect::<Vec<_>>();
117
118        let next_revision = if changed_inputs.is_empty() && !self.graph_mutated {
119            self.graph.revision
120        } else {
121            self.graph.revision.next()
122        };
123
124        let mut audit_events = self.staged_events.clone();
125        for node in self.staged_inputs.keys() {
126            let event = if changed_input_set.contains(node) {
127                AuditEvent::InputChanged(*node)
128            } else {
129                AuditEvent::InputUnchanged(*node)
130            };
131            audit_events.push(event);
132        }
133
134        phase_trace.push(TransactionPhase::CommitCanonicalInputs);
135        for node in &changed_inputs {
136            if let Some(staged) = self.staged_inputs.get(node) {
137                self.working.input_values.insert(*node, staged.clone());
138                if let Some(meta) = self.working.nodes.get_mut(node) {
139                    meta.mark_changed(next_revision);
140                }
141            }
142        }
143        for event in &self.staged_events {
144            if let AuditEvent::NodeCreated(node) = event
145                && let Some(meta) = self.working.nodes.get_mut(node)
146            {
147                meta.mark_created(next_revision);
148            }
149        }
150        let created_nodes: Vec<NodeId> = self
151            .staged_events
152            .iter()
153            .filter_map(|event| match event {
154                AuditEvent::NodeCreated(node) => Some(*node),
155                _ => None,
156            })
157            .collect();
158        let dirty_roots = stable_node_union(changed_inputs.iter().copied().chain(created_nodes));
159        let mut initial_changed = dirty_roots.clone();
160        phase_trace.push(TransactionPhase::MarkDirtyNodes);
161        phase_trace.push(TransactionPhase::RecomputeDerivedNodes);
162        let derived_trace = match self.working.recompute_dirty_derived(&initial_changed) {
163            Ok(trace) => trace,
164            Err(error) => {
165                self.close();
166                return Err(error);
167            }
168        };
169        let recomputed_derived_nodes = derived_trace.recomputed;
170        let changed_derived_nodes = derived_trace.changed;
171        for node in &changed_derived_nodes {
172            if let Some(meta) = self.working.nodes.get_mut(node) {
173                meta.mark_changed(next_revision);
174            }
175            audit_events.push(AuditEvent::DerivedChanged(*node));
176        }
177        initial_changed.extend(changed_derived_nodes.iter().copied());
178        phase_trace.push(TransactionPhase::RecomputeCollectionNodes);
179        let collection_recompute = match self.working.recompute_dirty_collections(&initial_changed)
180        {
181            Ok(trace) => trace,
182            Err(error) => {
183                self.close();
184                return Err(error);
185            }
186        };
187        let recomputed_collection_nodes = collection_recompute.recomputed;
188        let changed_collection_nodes = collection_recompute.changed;
189        for node in &changed_collection_nodes {
190            if let Some(meta) = self.working.nodes.get_mut(node) {
191                meta.mark_changed(next_revision);
192            }
193            audit_events.push(AuditEvent::CollectionChanged(*node));
194        }
195        phase_trace.push(TransactionPhase::ComputeStructuralDiffs);
196        self.working
197            .baseline_collection_diffs(&self.staged_resource_planner_collections);
198        let collection_diffs = self
199            .working
200            .collection_diffs
201            .iter()
202            .map(|(node, diff)| diff.trace(*node))
203            .collect::<Vec<_>>();
204        phase_trace.push(TransactionPhase::ResolveScopeLifecycle);
205        let closed_scopes: Vec<_> = self
206            .staged_events
207            .iter()
208            .filter_map(|event| match event {
209                AuditEvent::ScopeClosed(scope) => Some(*scope),
210                _ => None,
211            })
212            .collect();
213        let scope_events = scope_events(&audit_events);
214        phase_trace.push(TransactionPhase::ProduceResourcePlans);
215        let resource_plan = match self.working.produce_resource_plan(&closed_scopes) {
216            Ok(plan) => plan,
217            Err(error) => {
218                self.close();
219                return Err(error);
220            }
221        };
222        let resource_coalescences = self.working.take_pending_resource_coalescences();
223        audit_events.extend(resource_coalescences.iter().map(|coalesced| {
224            AuditEvent::ResourceOpenCoalesced {
225                key: coalesced.key.clone(),
226                scope: coalesced.scope,
227                existing_owner_count: coalesced.existing_owner_count,
228            }
229        }));
230        let mut output_changed = initial_changed.clone();
231        output_changed.extend(changed_collection_nodes.iter().copied());
232        phase_trace.push(TransactionPhase::ProduceOutputFrames);
233        let output_frames = match self.working.produce_output_frames(
234            &output_changed,
235            &closed_scopes,
236            &self.staged_output_rebaselines,
237            self.id,
238            next_revision,
239        ) {
240            Ok(frames) => frames,
241            Err(error) => {
242                self.close();
243                return Err(error);
244            }
245        };
246        let audit_log = audit_events
247            .into_iter()
248            .map(|event| AuditEntry {
249                transaction_id: self.id,
250                revision: next_revision,
251                event,
252            })
253            .collect();
254        phase_trace.push(TransactionPhase::CommitGraphRevision);
255        self.working.revision = next_revision;
256        self.working.next_node_id = self.graph.next_node_id;
257        self.working.next_scope_id = self.graph.next_scope_id;
258        self.working.next_output_key = self.graph.next_output_key;
259
260        phase_trace.push(TransactionPhase::ReturnTransactionResult);
261        let result = TransactionResult {
262            transaction_id: self.id,
263            revision: next_revision,
264            staged_input_changes,
265            changed_inputs,
266            dirty_roots,
267            recomputed_derived_nodes,
268            changed_derived_nodes,
269            recomputed_collection_nodes,
270            changed_collection_nodes,
271            collection_diffs,
272            resource_plan,
273            resource_coalescences,
274            output_frames,
275            scope_events,
276            audit_log,
277            phase_trace,
278            invariant_results: Vec::new(),
279        };
280        self.working
281            .record_transaction_audit(&result, self.options.audit_explanations);
282        self.working.reclaim_closed_scopes(&closed_scopes);
283        std::mem::swap(self.graph, &mut self.working);
284        self.close();
285        Ok(result)
286    }
287
288    pub(crate) fn ensure_open(&self) -> GraphResult<()> {
289        if self.closed {
290            Err(GraphError::TransactionClosed(self.id))
291        } else {
292            Ok(())
293        }
294    }
295
296    fn close(&mut self) {
297        self.closed = true;
298        self.graph.transaction_open = false;
299    }
300}
301
302impl<C> Drop for Transaction<'_, C> {
303    fn drop(&mut self) {
304        if !self.closed {
305            self.graph.transaction_open = false;
306        }
307    }
308}