1
use std::collections::HashMap;
2
use std::collections::HashSet;
3
use std::env;
4
use std::error::Error;
5
use std::fs::File;
6
use std::fs::{self};
7
use std::io::BufRead;
8
use std::io::Write;
9
use std::path::Path;
10

            
11
use duct::cmd;
12
use regex::Regex;
13
use serde::Deserialize;
14
use serde::Serialize;
15
use strum::Display;
16
use strum::EnumString;
17

            
18
#[derive(Deserialize, Serialize)]
19
struct MeasurementEntry {
20
    rewriter: String,
21
    benchmark_name: String,
22
    timings: Vec<f32>,
23
}
24

            
25
#[derive(EnumString, Display, PartialEq)]
26
pub enum Rewriter {
27
    #[strum(serialize = "innermost")]
28
    Innermost,
29

            
30
    #[strum(serialize = "sabre")]
31
    Sabre,
32

            
33
    #[strum(serialize = "jitty")]
34
    Jitty,
35

            
36
    #[strum(serialize = "jittyc")]
37
    JittyCompiling,
38
}
39

            
40
/// Benchmarks all the REC specifications in the example folder.
41
///
42
/// - mcrl2 This enables benchmarking the upstream mcrl2rewrite tool
43
pub fn benchmark(output_path: impl AsRef<Path>, rewriter: Rewriter) -> Result<(), Box<dyn Error>> {
44
    // Find the mcrl2rewrite tool based on which rewriter we want to benchmark
45
    let cwd = env::current_dir()?;
46

            
47
    let mcrl2_rewrite_path = if rewriter == Rewriter::Innermost || rewriter == Rewriter::Sabre {
48
        // Build the tool with the correct settings
49
        cmd!("cargo", "build", "--profile", "bench", "--bin", "mcrl2rewrite").run()?;
50

            
51
        // Using which is a bit unnecessary, but it deals nicely with .exe on Windows and can also be used to do other searching.
52
        which::which_in("mcrl2rewrite", Some("target/release/"), cwd)?
53
    } else {
54
        which::which("mcrl2rewrite")?
55
    };
56

            
57
    let mcrl2_rewrite_timing = match rewriter {
58
        Rewriter::Innermost => Regex::new(r#"Innermost rewrite took ([0-9]*) ms"#)?,
59
        Rewriter::Sabre => Regex::new(r#"Sabre rewrite took ([0-9]*) ms"#)?,
60
        Rewriter::Jitty | Rewriter::JittyCompiling => Regex::new(r#"rewriting: ([0-9]*) milliseconds."#)?,
61
    };
62

            
63
    // Create the output directory before creating the file.
64
    if let Some(parent) = output_path.as_ref().parent() {
65
        fs::create_dir_all(parent)?;
66
    }
67

            
68
    let mut result_file = File::create(output_path)?;
69

            
70
    // Consider all the specifications in the example directory.
71
    for file in fs::read_dir("examples/REC/mcrl2")? {
72
        let path = file?.path();
73

            
74
        // We take the dataspec file, and append the expressions ourselves.
75
        if path.extension().is_some_and(|ext| ext == "dataspec") {
76
            let data_spec = path.clone();
77
            let expressions = path.with_extension("expressions");
78

            
79
            let benchmark_name = path.file_stem().unwrap().to_string_lossy();
80
            println!("Benchmarking {}", benchmark_name);
81

            
82
            let mut arguments = vec!["600".to_string(), mcrl2_rewrite_path.to_string_lossy().to_string()];
83

            
84
            match rewriter {
85
                Rewriter::Innermost => {
86
                    arguments.push("rewrite".to_string());
87
                    arguments.push("innermost".to_string());
88
                }
89
                Rewriter::Sabre => {
90
                    arguments.push("rewrite".to_string());
91
                    arguments.push("sabre".to_string());
92
                }
93
                Rewriter::Jitty => {
94
                    arguments.push("-rjitty".to_string());
95
                    arguments.push("--timings".to_string());
96
                }
97
                Rewriter::JittyCompiling => {
98
                    arguments.push("-rjittyc".to_string());
99
                    arguments.push("--timings".to_string());
100
                }
101
            }
102

            
103
            arguments.push(data_spec.to_string_lossy().to_string());
104
            arguments.push(expressions.to_string_lossy().to_string());
105

            
106
            let mut measurements = MeasurementEntry {
107
                rewriter: rewriter.to_string(),
108
                benchmark_name: benchmark_name.to_string(),
109
                timings: Vec::new(),
110
            };
111

            
112
            // Run the benchmarks several times until one of them fails
113
            for _ in 0..5 {
114
                match cmd("timeout", &arguments).stdout_capture().stderr_capture().run() {
115
                    Ok(result) => {
116
                        // Parse the standard output to read the rewriting time and insert it into results.
117
                        for line in result.stdout.lines().chain(result.stderr.lines()) {
118
                            let line = line?;
119

            
120
                            if let Some(result) = mcrl2_rewrite_timing.captures(&line) {
121
                                let (_, [grp1]) = result.extract();
122
                                let timing: f32 = grp1.parse()?;
123

            
124
                                println!("Benchmark {} timing {} milliseconds", benchmark_name, timing);
125

            
126
                                // Write the output to the file and include a newline.
127
                                measurements.timings.push(timing / 1000.0);
128
                            }
129
                        }
130
                    }
131
                    Err(err) => {
132
                        println!("Benchmark {} timed out or crashed", benchmark_name);
133
                        println!("Command failed {:?}", err);
134
                        break;
135
                    }
136
                };
137
            }
138

            
139
            serde_json::to_writer(&mut result_file, &measurements)?;
140

            
141
            writeln!(&result_file)?;
142
        }
143
    }
144

            
145
    Ok(())
146
}
147

            
148
/// Returns the arithmetic mean of `values`.
///
/// An empty slice yields NaN, since the sum is then divided by zero.
fn average(values: &[f32]) -> f32 {
    let count = values.len() as f32;
    let total: f32 = values.iter().copied().sum();
    total / count
}
151

            
152
/// Formats `value` with exactly one decimal, e.g. `1.234` becomes `"1.2"`.
///
/// The result is a `String`, so the caller can apply its own width/alignment when printing it.
fn print_float(value: f32) -> String {
    format!("{:.1}", value)
}
156

            
157
pub fn create_table(json_path: impl AsRef<Path>) -> Result<(), Box<dyn Error>> {
158
    let output = fs::read_to_string(json_path)?;
159

            
160
    // Keep track of all the read results.
161
    let mut results: HashMap<String, HashMap<String, f32>> = HashMap::new();
162

            
163
    // Figure out the list of rewriters used to print '-' values.
164
    let mut rewriters: HashSet<String> = HashSet::new();
165

            
166
    for line in output.lines() {
167
        let timing = serde_json::from_str::<MeasurementEntry>(line)?;
168

            
169
        rewriters.insert(timing.rewriter.clone());
170

            
171
        results
172
            .entry(timing.benchmark_name)
173
            .and_modify(|e| {
174
                e.insert(timing.rewriter.clone(), average(&timing.timings));
175
            })
176
            .or_insert_with(|| {
177
                let mut table = HashMap::new();
178
                table.insert(timing.rewriter.clone(), average(&timing.timings));
179
                table
180
            });
181
    }
182

            
183
    // Print the header of the table.
184
    let mut first = true;
185
    for rewriter in &rewriters {
186
        if first {
187
            print!("{: >30}", rewriter);
188
            first = false;
189
        } else {
190
            print!("{: >10} |", rewriter);
191
        }
192
    }
193

            
194
    // Print the entries in the table.
195
    for (benchmark, result) in &results {
196
        print!("{: >30}", benchmark);
197

            
198
        for rewriter in &rewriters {
199
            if let Some(timing) = result.get(rewriter) {
200
                print!("| {: >10}", print_float(*timing));
201
            } else {
202
                print!("| {: >10}", "-");
203
            }
204
        }
205

            
206
        println!();
207
    }
208

            
209
    Ok(())
210
}