1
use std::error::Error;
2

            
3
use log::debug;
4
use log::trace;
5

            
6
use crate::LabelledTransitionSystem;
7

            
8
/// Returns a topological ordering of the states of the given LTS.
9
///
10
/// An error is returned if the LTS contains a cycle.
11
///     - reverse: If true, the topological ordering is reversed, i.e. successors before the incoming state.
12
116
pub fn sort_topological<F>(
13
116
    lts: &LabelledTransitionSystem,
14
116
    filter: F,
15
116
    reverse: bool,
16
116
) -> Result<Vec<usize>, Box<dyn Error>>
17
116
where
18
116
    F: Fn(usize, usize) -> bool,
19
116
{
20
116
    let start = std::time::Instant::now();
21
116
    trace!("{:?}", lts);
22

            
23
    // The resulting order of states.
24
116
    let mut stack = Vec::new();
25
116

            
26
116
    let mut visited = vec![false; lts.num_of_states()];
27
116
    let mut depth_stack = Vec::new();
28
116
    let mut marks = vec![None; lts.num_of_states()];
29

            
30
565007
    for state_index in lts.iter_states() {
31
565007
        if marks[state_index].is_none()
32
565007
            && !sort_topological_visit(
33
565007
                lts,
34
565007
                &filter,
35
565007
                state_index,
36
565007
                &mut depth_stack,
37
565007
                &mut marks,
38
565007
                &mut visited,
39
565007
                &mut stack,
40
565007
            )
41
        {
42
            trace!("There is a cycle from state {state_index} on path {stack:?}");
43
            return Err("Labelled transition system contains a cycle".into());
44
565007
        }
45
    }
46

            
47
116
    if !reverse {
48
60
        stack.reverse();
49
60
    }
50
116
    trace!("Topological order: {stack:?}");
51

            
52
    // Turn the stack into a permutation.
53
116
    let mut reorder = vec![0; lts.num_of_states()];
54
565007
    for (i, &state_index) in stack.iter().enumerate() {
55
565007
        reorder[state_index] = i;
56
565007
    }
57

            
58
116
    debug_assert!(
59
2603714
        is_topologically_sorted(lts, filter, |i| reorder[i], reverse),
60
        "The permutation {reorder:?} is not a valid topological ordering of the states of the given LTS: {lts:?}"
61
    );
62
116
    debug!("Time sort_topological: {:.3}s", start.elapsed().as_secs_f64());
63

            
64
116
    Ok(reorder)
65
116
}
66

            
67
/// Reorders the states of the given LTS according to the given permutation.
68
57
pub fn reorder_states<P>(lts: &LabelledTransitionSystem, permutation: P) -> LabelledTransitionSystem
69
57
where
70
57
    P: Fn(usize) -> usize,
71
57
{
72
57
    let start = std::time::Instant::now();
73
57

            
74
57
    // We know that it is a permutation, so there won't be any duplicated transitions.
75
57
    let mut transitions: Vec<(usize, usize, usize)> = Vec::default();
76

            
77
282497
    for state_index in lts.iter_states() {
78
282497
        let new_state_index = permutation(state_index);
79

            
80
490487
        for (label, to_index) in lts.outgoing_transitions(state_index) {
81
490481
            let new_to_index = permutation(*to_index);
82
490481
            transitions.push((new_state_index, *label, new_to_index));
83
490481
        }
84
    }
85

            
86
57
    debug!("Time reorder_states: {:.3}s", start.elapsed().as_secs_f64());
87
57
    LabelledTransitionSystem::new(
88
57
        permutation(lts.initial_state_index()),
89
57
        Some(lts.num_of_states()),
90
114
        || transitions.iter().cloned(),
91
57
        lts.labels().into(),
92
57
        lts.hidden_labels().into(),
93
57
    )
94
57
}
95

            
96
/// Depth-first-search status of a state while computing a topological order.
///
/// A state whose mark is `None` has not been entered by the search yet.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Mark {
    /// The state has been entered, but it has not been finished yet.
    Temporary,
    /// The state has been finished and appended to the resulting order.
    Permanent,
}
102

            
103
/// Visits the given state in a depth first search.
104
///
105
/// Returns false if a cycle is detected.
106
565007
fn sort_topological_visit<F>(
107
565007
    lts: &LabelledTransitionSystem,
108
565007
    filter: &F,
109
565007
    state_index: usize,
110
565007
    depth_stack: &mut Vec<usize>,
111
565007
    marks: &mut [Option<Mark>],
112
565007
    visited: &mut [bool],
113
565007
    stack: &mut Vec<usize>,
114
565007
) -> bool
115
565007
where
116
565007
    F: Fn(usize, usize) -> bool,
117
565007
{
118
565007
    // Perform a depth first search.
119
565007
    depth_stack.push(state_index);
120

            
121
1695021
    while let Some(state) = depth_stack.pop() {
122
1130014
        match marks[state] {
123
            None => {
124
565007
                marks[state] = Some(Mark::Temporary);
125
565007
                depth_stack.push(state); // Re-add to stack to mark as permanent later
126
565007
                for (_, next_state) in lts
127
565007
                    .outgoing_transitions(state)
128
980987
                    .filter(|(label, to)| filter(*label, *to))
129
                {
130
                    // If it was marked temporary, then a cycle is detected.
131
343686
                    if marks[*next_state] == Some(Mark::Temporary) {
132
                        return false;
133
343686
                    }
134
343686
                    if marks[*next_state].is_none() {
135
                        depth_stack.push(*next_state);
136
343686
                    }
137
                }
138
            }
139
565007
            Some(Mark::Temporary) => {
140
565007
                marks[state] = Some(Mark::Permanent);
141
565007
                visited[state] = true;
142
565007
                stack.push(state);
143
565007
            }
144
            Some(Mark::Permanent) => {}
145
        }
146
    }
147

            
148
565007
    true
149
565007
}
150

            
151
/// Returns true if the given permutation is a topological ordering of the states of the given LTS.
152
117
fn is_topologically_sorted<F, P>(lts: &LabelledTransitionSystem, filter: F, permutation: P, reverse: bool) -> bool
153
117
where
154
117
    F: Fn(usize, usize) -> bool,
155
117
    P: Fn(usize) -> usize,
156
117
{
157
117
    debug_assert!(is_valid_permutation(&permutation, lts.num_of_states()));
158

            
159
    // Check that each vertex appears before its successors.
160
565017
    for state_index in lts.iter_states() {
161
565017
        let state_order = permutation(state_index);
162
565017
        for (_, successor) in lts
163
565017
            .outgoing_transitions(state_index)
164
980997
            .filter(|(label, to)| filter(*label, *to))
165
        {
166
343689
            if reverse {
167
171839
                if state_order <= permutation(*successor) {
168
                    return false;
169
171839
                }
170
171850
            } else if state_order >= permutation(*successor) {
171
                return false;
172
171850
            }
173
        }
174
    }
175

            
176
117
    true
177
117
}
178

            
179
/// Returns true if `permutation` is a bijection on `0..max`.
///
/// That is, every index in `0..max` must be mapped to a distinct index that is
/// also in `0..max`. The function is evaluated exactly once per index.
fn is_valid_permutation<P>(permutation: &P, max: usize) -> bool
where
    P: Fn(usize) -> usize,
{
    let mut seen = vec![false; max];

    for i in 0..max {
        match seen.get_mut(permutation(i)) {
            // In range and not hit before: claim the slot.
            Some(slot) if !*slot => *slot = true,
            // Out of bounds, or a second index mapped to the same target.
            _ => return false,
        }
    }

    // `max` distinct targets that all lie in `0..max` necessarily cover every
    // slot (pigeonhole), so no final scan over `seen` is needed.
    true
}
202

            
203
#[cfg(test)]
mod tests {

    use rand::seq::SliceRandom;
    use test_log::test;

    use crate::random_lts;

    use super::*;

    #[test]
    fn test_sort_topological_with_cycles() {
        let lts = random_lts(10, 3, 2);

        // The random LTS may or may not be acyclic. Whenever an ordering is
        // returned it must be valid, for both directions.
        for reverse in [false, true] {
            if let Ok(order) = sort_topological(&lts, |_, _| true, reverse) {
                assert!(is_topologically_sorted(&lts, |_, _| true, |i| order[i], reverse));
            }
        }
    }

    #[test]
    fn test_reorder_states() {
        let lts = random_lts(10, 3, 2);

        // Generate a random permutation.
        let mut rng = rand::rng();
        let order: Vec<usize> = {
            let mut order: Vec<usize> = (0..lts.num_of_states()).collect();
            order.shuffle(&mut rng);
            order
        };

        let new_lts = reorder_states(&lts, |i| order[i]);

        trace!("{:?}", lts);
        trace!("{:?}", new_lts);

        //assert_eq!(new_lts.num_of_states(), lts.num_of_states());
        assert_eq!(new_lts.num_of_labels(), lts.num_of_labels());
        assert_eq!(new_lts.initial_state_index(), order[lts.initial_state_index()]);

        for from in lts.iter_states() {
            let new_from = order[from];

            // Every original transition must reappear between the renamed states.
            for &(label, to) in lts.outgoing_transitions(from) {
                let new_to = order[to];
                assert!(new_lts
                    .outgoing_transitions(new_from)
                    .any(|trans| *trans == (label, new_to)));
            }
        }
    }

    #[test]
    fn test_is_valid_permutation() {
        let lts = random_lts(10, 15, 2);

        // Generate a valid permutation.
        let mut rng = rand::rng();
        let valid_permutation: Vec<usize> = {
            let mut order: Vec<usize> = (0..lts.num_of_states()).collect();
            order.shuffle(&mut rng);
            order
        };

        assert!(is_valid_permutation(&|i| valid_permutation[i], valid_permutation.len()));

        // Generate an invalid permutation (duplicate entries).
        let invalid_permutation = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 8];
        assert!(!is_valid_permutation(
            &|i| invalid_permutation[i],
            invalid_permutation.len()
        ));

        // Generate an invalid permutation (missing entries).
        let invalid_permutation = vec![0, 1, 3, 4, 5, 6, 7, 8];
        assert!(!is_valid_permutation(
            &|i| invalid_permutation[i],
            invalid_permutation.len()
        ));
    }
}