dfir_lang/graph/graph_algorithms.rs
1//! General graph algorithm utility functions
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4use std::hash::Hash;
5
6use slotmap::{Key, SecondaryMap, SparseSecondaryMap};
7
/// Topologically sorts a set of nodes with a depth-first search over predecessor edges.
///
/// Every node in `node_ids` is used as a DFS root, in iteration order. Predecessors reported by
/// `preds_fn` are visited too, even if they never appear in `node_ids`. On success the returned
/// list contains each visited node exactly once, and every node is placed after all of its
/// predecessors, so the order agrees with any path through the graph.
///
/// The traversal is recursive, one stack frame per node on the current DFS path.
///
/// # Errors
///
/// If the input is not a directed acyclic graph (DAG), `Err` is returned containing one cycle.
/// Each node of that cycle is listed exactly once.
///
/// <https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search>
pub fn topo_sort<Id, PredsIter>(
    node_ids: impl IntoIterator<Item = Id>,
    mut preds_fn: impl FnMut(Id) -> PredsIter,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Hash,
    PredsIter: IntoIterator<Item = Id>,
{
    // Post-order DFS along predecessor edges.
    //
    // `marks` maps a node to `false` while it is on the current DFS path (temporary mark) and to
    // `true` once it has been emitted into `order` (permanent mark).
    //
    // On `Err(())`, `order` no longer holds the sort; it is reused to carry the cycle trace.
    fn visit<Id, PredsIter>(
        current: Id,
        preds_fn: &mut impl FnMut(Id) -> PredsIter,
        marks: &mut HashMap<Id, bool>,
        order: &mut Vec<Id>,
    ) -> Result<(), ()>
    where
        Id: Copy + Eq + Hash,
        PredsIter: IntoIterator<Item = Id>,
    {
        if let Some(&emitted) = marks.get(&current) {
            if emitted {
                // Already fully processed via another path.
                return Ok(());
            }
            // `current` is still on the DFS path: we just closed a cycle.
            // Restart `order` as the cycle trace, seeded with the node where the cycle closes.
            order.clear();
            order.push(current);
            return Err(());
        }

        marks.insert(current, false);
        for pred in preds_fn(current) {
            if visit(pred, preds_fn, marks, order).is_err() {
                // While unwinding, append ourselves to the trace until it reads
                // `[c, ..., c]`, i.e. until the seed node has been appended a second time.
                // A trace of length one is just the seed, so it is never "closed" yet
                // (this is what makes self-loops work).
                let trace_closed = order.len() > 1 && order.first() == order.last();
                if !trace_closed {
                    order.push(current);
                }
                return Err(());
            }
        }
        order.push(current);
        marks.insert(current, true);
        Ok(())
    }

    let mut marks: HashMap<Id, bool> = HashMap::new();
    let mut order: Vec<Id> = Vec::new();

    for root in node_ids {
        if visit(root, &mut preds_fn, &mut marks, &mut order).is_ok() {
            continue;
        }
        // `order` now holds the trace `[c, x1, ..., xn, c]`. Drop everything up to and including
        // the first occurrence of the closing node so that each cycle member is listed once.
        let closing = *order.last().unwrap();
        let first_occurrence = order.iter().position(|&n| n == closing).unwrap();
        order.drain(..=first_occurrence);
        return Err(order);
    }

    Ok(order)
}
73
74/// Datastructure for merging subgraphs while maintaining topological sort order.
75///
76/// Maintains a global topo-sorted Vec of all operators. Each subgraph (merged group)
77/// occupies a contiguous range in this Vec. Merging two groups combines their ranges
78/// and re-sorts the affected window so groups remain contiguous and correctly ordered.
79pub struct SubgraphMerge<K>
80where
81 K: Key,
82{
83 /// Predecessor edges in the quotient DAG (per representative).
84 subgraph_preds: SecondaryMap<K, Vec<K>>,
85 /// All operators in global topo-sort order (fixed length, reshuffled in windows).
86 /// Invariant: subgraphs are contiguous & non-overlapping ranges in this vec.
87 toposort_node: Vec<K>,
88 /// Reverse index: SG representative node -> index (in toposort_node).
89 /// Invariant: `K` is both the representative node and the first node in the SG.
90 sg_idx: SparseSecondaryMap<K, usize>,
91 /// SG representative node -> SG len.
92 /// The subgraph's nodes are `toposort_node[index..index+len]`.
93 /// Invariant: the subgraph ranges are complete and non-overlapping.
94 sg_len: SparseSecondaryMap<K, usize>,
95
96 /// Union-find for subgraph membership.
97 subgraph_unionfind: crate::union_find::UnionFind<K>,
98
99 /// Per-representative: set of representatives that this subgraph must not merge with.
100 /// Maintained symmetrically: if `enemies[a]` contains `b`, then `enemies[b]` contains `a`.
101 enemies: SecondaryMap<K, HashSet<K>>,
102}
103
104impl<K> SubgraphMerge<K>
105where
106 K: Key,
107{
108 /// Creates a new `SubgraphMerge` from nodes and their predecessor edges.
109 ///
110 /// `enemies` specifies pairs of nodes that must never be placed in the same subgraph.
111 /// These are checked in O(1) during [`Self::try_merge`] and maintained as representatives
112 /// change.
113 ///
114 /// Returns `Err` with a cycle if the input graph is not a DAG.
115 pub fn new<PredsIter>(
116 keys: impl IntoIterator<Item = K>,
117 mut preds_fn: impl FnMut(K) -> PredsIter,
118 enemies_iter: impl IntoIterator<Item = (K, K)>,
119 ) -> Result<Self, Vec<K>>
120 where
121 PredsIter: IntoIterator<Item = K>,
122 {
123 let subgraph_preds = keys
124 .into_iter()
125 .map(|k| (k, (preds_fn)(k).into_iter().collect()))
126 .collect::<SecondaryMap<K, Vec<K>>>();
127 let toposort_node =
128 topo_sort(subgraph_preds.keys(), |k| subgraph_preds[k].iter().copied())?;
129 let sg_idx = toposort_node
130 .iter()
131 .enumerate()
132 .map(|(i, &k)| (k, i))
133 .collect();
134 let sg_len = toposort_node.iter().map(|&k| (k, 1)).collect();
135 let subgraph_unionfind = crate::union_find::UnionFind::with_capacity(toposort_node.len());
136
137 let mut enemies = SecondaryMap::<K, HashSet<K>>::new();
138 for (a, b) in enemies_iter {
139 assert_ne!(a, b, "no-merge pair must not contain the same node twice");
140 enemies.entry(a).unwrap().or_default().insert(b);
141 enemies.entry(b).unwrap().or_default().insert(a);
142 }
143
144 Ok(Self {
145 subgraph_preds,
146 toposort_node,
147 sg_idx,
148 sg_len,
149 subgraph_unionfind,
150 enemies,
151 })
152 }
153
154 /// Find the representative of the subgraph containing `k`.
155 pub fn find(&mut self, k: K) -> K {
156 self.subgraph_unionfind.find(k)
157 }
158
159 /// Returns true if `u` and `v` are in the same subgraph.
160 pub fn same_set(&mut self, u: K, v: K) -> bool {
161 self.subgraph_unionfind.same_set(u, v)
162 }
163
164 /// Iterates all subgraph representatives with their topo-sorted operator slices,
165 /// in topological order (by position in `toposort_node`).
166 pub fn subgraphs(&self) -> impl Iterator<Item = &[K]> {
167 let mut i = 0;
168 std::iter::from_fn(move || {
169 let Some(&sg_node) = self.toposort_node.get(i) else {
170 debug_assert_eq!(i, self.toposort_node.len());
171 return None;
172 };
173 debug_assert_eq!(i, self.sg_idx[sg_node]);
174 let sg_len = self.sg_len[sg_node];
175 let sg_slice = &self.toposort_node[i..i + sg_len];
176 i += sg_len;
177 Some(sg_slice)
178 })
179 }
180
181 /// Attempts to merge the subgraphs containing `u` and `v`.
182 /// Returns `false` if merging would create a cycle in the subgraph DAG,
183 /// or if the merge is forbidden by a no-merge constraint.
184 pub fn try_merge(&mut self, u: K, v: K) -> bool {
185 // 0. Set up `u` and `v` to be in order, and subgraph representatives.
186
187 // Ensure `u` and `v` are subgraph representatives.
188 let u = self.subgraph_unionfind.find(u);
189 let v = self.subgraph_unionfind.find(v);
190 if u == v {
191 // Short circuit no-op case. Guards against weird `u == v` aliasing.
192 return true;
193 }
194
195 // O(1) no-merge constraint check.
196 if self
197 .enemies
198 .get(u)
199 .is_some_and(|enemy_set| enemy_set.contains(&v))
200 {
201 return false;
202 }
203
204 // Ensure `u` is before `v` in topo order.
205 let (u, v) = if self.sg_idx[u] < self.sg_idx[v] {
206 (u, v)
207 } else {
208 (v, u)
209 };
210 // Get the member nodes of `u` and `v`, and the `window`. Pulling references here does ensure that
211 // `toposort_node` remains unchanged until we properly merge `u_nodes` and `v_nodes`.
212 let (u_nodes, v_nodes, window) = {
213 let (u_idx, u_len) = (self.sg_idx[u], self.sg_len[u]);
214 let (v_idx, v_len) = (self.sg_idx[v], self.sg_len[v]);
215 (
216 &self.toposort_node[u_idx..u_idx + u_len],
217 &self.toposort_node[v_idx..v_idx + v_len],
218 u_idx..v_idx + v_len,
219 )
220 };
221
222 // 1. Cycle check: can `v` reach `u` via predecessor edges?
223 // Only groups within `window` can be on such a path. Direct predecessor edges from `v` to `u` become
224 // self-loops after merge and are not real cycles, so we skip direct `u -> v` edges.
225
226 let mut stack = vec![v];
227 let mut visited = HashSet::<_>::from_iter([v]);
228
229 while let Some(x) = stack.pop() {
230 for &p in self.subgraph_preds[x].iter() {
231 let root_p = self.subgraph_unionfind.find(p);
232
233 if root_p == u {
234 if x == v {
235 // Ignore `u -> v` direct edge, not a real cycle.
236 continue;
237 }
238 // Cycle found, return false.
239 return false;
240 }
241
242 // Prune: group must be within the `window`.
243 if window.contains(&self.sg_idx[root_p]) && visited.insert(root_p) {
244 stack.push(root_p);
245 }
246 }
247 }
248
249 // 2. Perform merge in union-find and append predecessors.
250 // `u` will be the new representative.
251 {
252 // `UnionFind::union` ensures the first arg's representative will represent the new merged group. `u` is before
253 // `v` in the topo order, and `u` is already its own representative. This ensures that `u` stays at the *start*
254 // of its subgraph group, so the `idx..idx+len` slice is the whole subgraph.
255 let _new_root = self.subgraph_unionfind.union(u, v);
256 debug_assert_eq!(u, _new_root);
257 let v_preds = &mut self.subgraph_preds.remove(v).unwrap();
258 let u_preds = &mut self.subgraph_preds[u];
259 u_preds.append(v_preds);
260 // Update all preds to be representatives (from past unioning). Delete any self-edges.
261 u_preds.retain_mut(|x| {
262 *x = self.subgraph_unionfind.find(*x);
263 *x != u // Retain only non-self edges.
264 });
265 // Remove any duplicates (may have be created from past unioning).
266 u_preds.sort_unstable();
267 u_preds.dedup();
268 }
269 // Remove subsumed `v` and grow `u`'s length.
270 {
271 self.sg_idx.remove(v).unwrap();
272 let v_len = self.sg_len.remove(v).unwrap();
273 // Set `u`'s len to the combined size. (Note: `sg_idx[u]` still needs updating, below after re-sort).
274 self.sg_len[u] += v_len;
275 }
276 // Merge enemies: remap v's enemies to point to u.
277 for w in self.enemies.remove(v).into_iter().flatten() {
278 debug_assert_ne!(
279 w, u,
280 "`w` in an enemy of `v`, so it can't be `w == u`, since we are merging `u` and `v`"
281 );
282 // Add `w`` to `u`'s enemies.
283 self.enemies.entry(u).unwrap().or_default().insert(w);
284 // Add `u` to `w`'s enemies. Remove `v`.
285 // `w` enemies guaranteed to exist by the symmetric invariant: if `v`'s enemies contain `w``, then `w`'s
286 // enemies contain `v`.
287 let w_enemies = self.enemies.get_mut(w).unwrap();
288 let _removed = w_enemies.remove(&v);
289 debug_assert!(_removed);
290 w_enemies.insert(u);
291 }
292
293 // 3. Re-sort groups in `window`.
294 // Topo-sort groups in the window by their quotient edges.
295 {
296 let sorted_groups = {
297 let reps_in_window = self.toposort_node[window.clone()]
298 .iter()
299 .map(|&k| self.subgraph_unionfind.find(k))
300 .collect::<BTreeSet<_>>();
301
302 // We borrow fields separately to allow the closure to call `find()` (which needs `&mut`) while also reading
303 // `subgraph_preds` and `sg_idx` (via `&`).
304 // Only predecessor groups whose range overlaps the window are included - groups entirely outside the window
305 // have their ordering already satisfied.
306 let subgraph_preds = &self.subgraph_preds;
307 let subgraph_unionfind = &mut self.subgraph_unionfind;
308 let sg_idx = &self.sg_idx;
309 topo_sort(reps_in_window, |k| {
310 subgraph_preds[k]
311 .iter()
312 .map(|&p| subgraph_unionfind.find(p))
313 .filter(|&p| window.contains(&sg_idx[p])) // Prune to window.
314 .collect::<Vec<_>>()
315 .into_iter()
316 })
317 .expect("bug: cycle check passed but re-toposort found cycle")
318 };
319
320 // Rebuild the window: lay out each group's operators in sorted group order.
321 // All groups except `u` (new root) have contiguous operators at their current range. `u`'s operators will be
322 // `u_nodes` *and* `v_nodes`.
323 let mut buf = Vec::with_capacity(window.len());
324 for &group in &sorted_groups {
325 if group == u {
326 buf.extend_from_slice(u_nodes);
327 buf.extend_from_slice(v_nodes);
328 } else {
329 let g_idx = self.sg_idx[group];
330 let g_len = self.sg_len[group];
331 buf.extend_from_slice(&self.toposort_node[g_idx..g_idx + g_len]);
332 }
333 }
334 self.toposort_node[window.clone()].copy_from_slice(&buf);
335
336 // Update reverse index `sg_idx` start positions (`sg_len` already correct).
337 let mut pos = window.start;
338 for &group in &sorted_groups {
339 self.sg_idx[group] = pos;
340 pos += self.sg_len[group];
341 }
342 debug_assert_eq!(window.end, pos);
343 }
344
345 true
346 }
347}
348
349#[cfg(test)]
350mod test {
351 use std::collections::{BTreeMap, BTreeSet};
352
353 use itertools::Itertools;
354 use slotmap::SlotMap;
355
356 use super::*;
357
/// Sorting a known DAG must succeed and must place every edge's source before its destination.
#[test]
pub fn test_toposort() {
    // `(src, dst)` edges of the example DAG from
    // https://commons.wikimedia.org/wiki/File:Directed_acyclic_graph_2.svg
    let edges = [
        (5, 11),
        (11, 2),
        (11, 9),
        (11, 10),
        (7, 11),
        (7, 8),
        (8, 9),
        (3, 8),
        (3, 10),
    ];

    // The predecessors of `v` are the sources of all edges ending in `v`.
    let result = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v| {
        edges
            .iter()
            .filter(move |&&(_, dst)| v == dst)
            .map(|&(src, _)| src)
    });
    let sort = match result {
        Ok(sort) => sort,
        Err(cycle) => panic!("Did not expect cycle: {:?}", cycle),
    };
    println!("{:?}", sort);

    // Map each node to its position in the output, then check every edge against it.
    let position = sort
        .iter()
        .copied()
        .zip(0usize..)
        .collect::<BTreeMap<_, _>>();
    for &(src, dst) in &edges {
        assert!(position[&src] < position[&dst]);
    }
}
393
/// Sorting a cyclic graph must fail and report the cycle, whatever order the nodes are given in.
#[test]
pub fn test_toposort_cycle() {
    // https://commons.wikimedia.org/wiki/File:Directed_graph,_cyclic.svg
    //      ┌────►C──────┐
    //      │            │
    //      │            ▼
    // A───►B            E─────►F
    //      ▲            │
    //      │            │
    //      └─────D◄─────┘
    let edges = [
        ('A', 'B'),
        ('B', 'C'),
        ('C', 'E'),
        ('D', 'B'),
        ('E', 'F'),
        ('E', 'D'),
    ];
    let ids: BTreeSet<char> = edges.iter().flat_map(|&(src, dst)| [src, dst]).collect();

    // The only cycle is B -> C -> E -> D -> B; it may be reported starting at any of its nodes.
    let cycle_rotations = BTreeSet::from([
        ['B', 'C', 'E', 'D'],
        ['C', 'E', 'D', 'B'],
        ['E', 'D', 'B', 'C'],
        ['D', 'B', 'C', 'E'],
    ]);

    // Exhaustively try every order in which the nodes can be handed to `topo_sort`.
    for permutation in ids.iter().copied().permutations(ids.len()) {
        let result = topo_sort(permutation.iter().copied(), |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        });
        assert!(result.is_err());
        let cycle = result.unwrap_err();
        assert!(
            cycle_rotations.contains(cycle.as_slice()),
            "cycle: {:?}, vertex order: {:?}",
            cycle,
            permutation
        );
    }
}
441
/// Merges are accepted unless they would put a third subgraph on a cycle with the merged one.
#[test]
pub fn test_subgraph_merge_basic() {
    let mut preds = SlotMap::new();

    // Six nodes, inserted in the order A..F, each starting without predecessors.
    let [a, b, c, d, e, f] = std::array::from_fn(|_| preds.insert(Vec::new()));

    // `(src, dst)` edges; each one records `src` as a predecessor of `dst`.
    for (src, dst) in [(a, b), (b, c), (b, d), (c, e), (d, e), (e, f)] {
        preds[dst].push(src);
    }

    let mut merge = SubgraphMerge::new(
        preds.keys(),
        |v| preds[v].iter().copied(),
        std::iter::empty(),
    )
    .unwrap();

    assert!(merge.try_merge(a, a)); // No-op.

    //            ┌──► C ──┐
    //            │        ▼
    // A ───► B ──┤        E ───► F
    //            │        ▲
    //            └──► D ──┘
    assert!(merge.try_merge(b, c));
    assert!(merge.try_merge(b, c)); // No-op.

    // A ───► BC ─────────► E ───► F
    //         │            ▲
    //         └────► D ────┘
    assert!(!merge.try_merge(c, e)); // Rejected due to `D` outside-cycle.

    assert!(merge.try_merge(d, e));
    assert!(merge.try_merge(c, e)); // Now valid since `D` is no longer outside.
}
483
/// Enemy (no-merge) constraints must be honored and must follow subgraphs as they get merged.
#[test]
pub fn test_subgraph_merge_enemies() {
    let mut preds = SlotMap::new();

    // A ───► B ───► C ───► D
    let [a, b, c, d] = std::array::from_fn(|_| preds.insert(Vec::new()));
    for (src, dst) in [(a, b), (b, c), (c, d)] {
        preds[dst].push(src);
    }

    // B and C are enemies (must not merge).
    let mut merge =
        SubgraphMerge::new(preds.keys(), |v| preds[v].iter().copied(), [(b, c)]).unwrap();

    // The enemy pair itself is rejected.
    assert!(!merge.try_merge(b, c));

    // A pair that is not an enemy pair is allowed.
    assert!(merge.try_merge(a, b));

    // With A and B merged, C is an enemy of the whole AB group.
    assert!(!merge.try_merge(a, c));
    assert!(!merge.try_merge(b, c));

    // D has no enemies, so it may join C.
    assert!(merge.try_merge(c, d));

    // The CD group inherits C's enmity with AB.
    assert!(!merge.try_merge(a, d));
    assert!(!merge.try_merge(b, d));
}
519}