1
use std::sync::Arc;
2

            
3
use glam::Vec3;
4
use log::debug;
5
use lts::LabelledTransitionSystem;
6
use rand::Rng;
7
use unsafety::index_edge;
8
use unsafety::Edge;
9

            
10
pub struct GraphLayout {
11
    // Store the underlying LTS to get the edges.
12
    pub lts: Arc<LabelledTransitionSystem>,
13

            
14
    // For every state store layout information.
15
    pub layout_states: Vec<StateLayout>,
16
}
17

            
18
#[derive(Clone, Default)]
19
pub struct StateLayout {
20
    pub position: Vec3,
21
    pub force: Vec3,
22
}
23

            
24
impl GraphLayout {
25
    /// Construct a new layout for the given LTS.
26
1
    pub fn new(lts: &Arc<LabelledTransitionSystem>) -> GraphLayout {
27
1
        // Keep track of state layout information.
28
1
        let mut states_simulation = vec![StateLayout::default(); lts.num_of_states()];
29
1

            
30
1
        // Place the states at a random position within some bound based on the number of states.
31
1
        let mut rng = rand::rng();
32
1
        let bound = (lts.num_of_states() as f32).sqrt().ceil();
33
1

            
34
1
        debug!("Placing states within bound {bound}");
35
75
        for layout_state in &mut states_simulation {
36
74
            layout_state.position.x = rng.random_range(-bound..bound);
37
74
            layout_state.position.y = rng.random_range(-bound..bound);
38
74
        }
39

            
40
1
        GraphLayout {
41
1
            lts: lts.clone(),
42
1
            layout_states: states_simulation,
43
1
        }
44
1
    }
45

            
46
    /// Update the layout one step using spring forces for transitions and repulsion between states.
47
    ///
48
    /// Returns true iff the layout is stable.
49
5
    pub fn update(&mut self, handle_length: f32, repulsion_strength: f32, delta: f32) -> bool {
50
370
        for state_index in self.lts.iter_states() {
51
            // Ignore the last state since it cannot repulse with any other state.
52
370
            if state_index < self.layout_states.len() {
53
                // Use split_at_mut to get two mutable slices at every split point.
54
370
                let (left_layout, right_layout) = self.layout_states.split_at_mut(state_index + 1);
55
370
                let state_layout = left_layout.last_mut().unwrap();
56

            
57
                // Accumulate repulsion forces between vertices.
58
13875
                for other_state_layout in right_layout {
59
13505
                    let force = compute_repulsion_force(
60
13505
                        &state_layout.position,
61
13505
                        &other_state_layout.position,
62
13505
                        repulsion_strength,
63
13505
                    );
64
13505

            
65
13505
                    state_layout.force += force;
66
13505
                    other_state_layout.force -= force;
67
13505
                }
68
            }
69

            
70
            // Accumulate forces over all connected edges.
71
460
            for (_, to_index) in self.lts.outgoing_transitions(state_index) {
72
                // Index an edge in the graph.
73
460
                match index_edge(&mut self.layout_states, state_index, *to_index) {
74
                    Edge::Selfloop(_) => {
75
                        // Handle self loop, but we apply no forces in this case.
76
                    }
77
460
                    Edge::Regular(from_layout, to_layout) => {
78
460
                        let force = compute_spring_force(&from_layout.position, &to_layout.position, handle_length);
79
460

            
80
460
                        from_layout.force += force;
81
460
                        to_layout.force -= force;
82
460
                    }
83
                }
84
            }
85
        }
86

            
87
        // Keep track of the total displacement of the system, to determine stablity
88
5
        let mut displacement = 0.0;
89

            
90
375
        for state_layout in &mut self.layout_states {
91
            // Integrate the forces.
92
370
            state_layout.position += state_layout.force * delta;
93
370
            displacement += (state_layout.force * delta).length_squared();
94
370

            
95
370
            // Reset the force.
96
370
            state_layout.force = Vec3::default();
97
370

            
98
370
            // A safety check for when the layout exploded.
99
370
            assert!(
100
370
                state_layout.position.is_finite(),
101
                "Invalid position {} obtained",
102
                state_layout.position
103
            );
104
        }
105

            
106
5
        (displacement / self.layout_states.len() as f32) < 0.01
107
5
    }
108
}
109

            
110
/// Compute a sping force between two points with a desired rest length.
111
460
fn compute_spring_force(p1: &Vec3, p2: &Vec3, rest_length: f32) -> Vec3 {
112
460
    let dist = p1.distance(*p2);
113
460

            
114
460
    if dist < 0.1 {
115
        // Give it some offset force.
116
        Vec3::new(0.0, 0.2, 0.0)
117
    } else {
118
        // This is already multiplied by -1.0, i.e. (p2 - p1) == (p1 - p2) * -1.0
119
460
        (*p2 - *p1) / dist * f32::log2(dist / rest_length)
120
    }
121
460
}
122

            
123
/// Computes a repulsion force between two points with a given strength.
124
13505
fn compute_repulsion_force(p1: &Vec3, p2: &Vec3, repulsion_strength: f32) -> Vec3 {
125
13505
    let dist = p1.distance_squared(*p2);
126
13505

            
127
13505
    if dist < 1.0 {
128
        // Give it some offset force.
129
113
        Vec3::new(0.0, 0.0, 0.0)
130
    } else {
131
13392
        (*p1 - *p2) * repulsion_strength / dist
132
    }
133
13505
}
134

            
135
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use io::io_aut::read_aut;

    use super::GraphLayout;

    /// Lays out the ABP example and checks that the layout stays well-formed
    /// over a number of update steps.
    #[test]
    fn test_graph_layout() {
        let file = include_str!("../../../../examples/lts/abp.aut");
        let lts = Arc::new(read_aut(file.as_bytes(), vec![]).unwrap());

        let mut layout = GraphLayout::new(&lts);

        // Every state must receive exactly one layout entry.
        assert_eq!(layout.layout_states.len(), lts.num_of_states());

        // Perform a number of updates; `update` itself panics when a position
        // becomes non-finite.
        for _ in 0..5 {
            layout.update(5.0, 1.0, 0.01);
        }

        assert!(layout
            .layout_states
            .iter()
            .all(|state| state.position.is_finite()));
    }
}