diff --git a/.gitignore b/.gitignore index ab653d5..16ea2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ Cargo.lock *\# \#* rmi_data +binary_models/ +rmi_models/ + diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..c2098a2 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,18 @@ +{ + "configurations": [ + { + "name": "linux-gcc-x64", + "includePath": [ + "${workspaceFolder}/**" + ], + "compilerPath": "/usr/bin/gcc", + "cStandard": "${default}", + "cppStandard": "${default}", + "intelliSenseMode": "linux-gcc-x64", + "compilerArgs": [ + "" + ] + } + ], + "version": 4 +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..f77b2ec --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,48 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Test the correctness of RMI loading", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/tests/verify_rmi", + "args": [], + "stopAtEntry": false, + "cwd": "${fileDirname}", + "environment": [], + "externalConsole": false, + "MIMode": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for gdb", + "text": "-enable-pretty-printing", + "ignoreFailures": true + }, + { + "description": "Set Disassembly Flavor to Intel", + "text": "-gdb-set disassembly-flavor intel", + "ignoreFailures": true + } + ] + }, + { + "name": "C/C++ Runner: Debug Session", + "type": "cppdbg", + "request": "launch", + "args": [], + "stopAtEntry": false, + "externalConsole": false, + "cwd": "/home/andy/Projects/RMI", + "program": "/home/andy/Projects/RMI/build/Debug/outDebug", + "MIMode": "gdb", + "miDebuggerPath": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for gdb", + "text": "-enable-pretty-printing", + "ignoreFailures": true + } + ] + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json
b/.vscode/settings.json new file mode 100644 index 0000000..3e5eb95 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,59 @@ +{ + "C_Cpp_Runner.cCompilerPath": "gcc", + "C_Cpp_Runner.cppCompilerPath": "g++", + "C_Cpp_Runner.debuggerPath": "gdb", + "C_Cpp_Runner.cStandard": "", + "C_Cpp_Runner.cppStandard": "", + "C_Cpp_Runner.msvcBatchPath": "", + "C_Cpp_Runner.useMsvc": false, + "C_Cpp_Runner.warnings": [ + "-Wall", + "-Wextra", + "-Wpedantic", + "-Wshadow", + "-Wformat=2", + "-Wcast-align", + "-Wconversion", + "-Wsign-conversion", + "-Wnull-dereference" + ], + "C_Cpp_Runner.msvcWarnings": [ + "/W4", + "/permissive-", + "/w14242", + "/w14287", + "/w14296", + "/w14311", + "/w14826", + "/w44062", + "/w44242", + "/w14905", + "/w14906", + "/w14263", + "/w44265", + "/w14928" + ], + "C_Cpp_Runner.enableWarnings": true, + "C_Cpp_Runner.warningsAsError": false, + "C_Cpp_Runner.compilerArgs": [], + "C_Cpp_Runner.linkerArgs": [], + "C_Cpp_Runner.includePaths": [], + "C_Cpp_Runner.includeSearch": [ + "*", + "**/*" + ], + "C_Cpp_Runner.excludeSearch": [ + "**/build", + "**/build/**", + "**/.*", + "**/.*/**", + "**/.vscode", + "**/.vscode/**" + ], + "C_Cpp_Runner.useAddressSanitizer": false, + "C_Cpp_Runner.useUndefinedSanitizer": false, + "C_Cpp_Runner.useLeakSanitizer": false, + "C_Cpp_Runner.showCompilationTime": false, + "C_Cpp_Runner.useLinkTimeOptimization": false, + "C_Cpp_Runner.msvcSecureNoWarnings": false +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..08d9005 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,28 @@ +{ + "tasks": [ + { + "type": "cppbuild", + "label": "C/C++: gcc build active file", + "command": "/usr/bin/gcc", + "args": [ + "-fdiagnostics-color=always", + "-g", + "${file}", + "-o", + "${fileDirname}/${fileBasenameNoExtension}" + ], + "options": { + "cwd": "${fileDirname}" + }, + "problemMatcher": [ + "$gcc" + ], + "group": { + "kind": "build", + "isDefault": true + }, + 
"detail": "Task generated by Debugger." + } + ], + "version": "2.0.0" +} \ No newline at end of file diff --git a/README.md b/README.md index de1c6ec..c96a996 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ Currently, the following types of RMI layers are supported: * `radix`, eliminates common prefixes and returns a fixed number of significant bits based on the branching factor * `bradix`, same as radix, but attempts to choose the number of bits based on balancing the dataset * `histogram`, partitions the data into several even-sized blocks (based on the branching factor) +* `optimal_pla`, a piecewise linear approximation leaf that stitches together short linear segments with a bounded per-segment error Tuning an RMI is critical to getting good performance. A good place to start is a `cubic` layer followed by a large linear layer, for example: `cubic,linear 262144`. For automatic tuning, try the RMI optimizer using the `--optimize` flag: diff --git a/rmi_lib/Cargo.toml b/rmi_lib/Cargo.toml index df21e0c..36be380 100644 --- a/rmi_lib/Cargo.toml +++ b/rmi_lib/Cargo.toml @@ -21,3 +21,5 @@ superslice = "1.0.0" json = "0.12.0" indicatif = "0.13.0" tabular = "0.1.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/rmi_lib/src/binary.rs b/rmi_lib/src/binary.rs new file mode 100644 index 0000000..b3ec476 --- /dev/null +++ b/rmi_lib/src/binary.rs @@ -0,0 +1,273 @@ +use crate::models::{KeyType, ModelDataType, ModelParam}; +use crate::train::TrainedRMI; + +#[derive(Debug, Clone)] +pub enum BinaryParamKind { + Int, + Float, + ShortArray, + IntArray, + Int32Array, + FloatArray, +} + +#[derive(Debug, Clone)] +pub struct BinaryModelParam { + pub kind: BinaryParamKind, + pub len: usize, + pub value: ModelParam, +} + +impl From for BinaryModelParam { + fn from(param: ModelParam) -> Self { + let kind = match ¶m { + ModelParam::Int(_) => BinaryParamKind::Int, + ModelParam::Float(_) => BinaryParamKind::Float, + ModelParam::ShortArray(_) => 
BinaryParamKind::ShortArray, + ModelParam::IntArray(_) => BinaryParamKind::IntArray, + ModelParam::Int32Array(_) => BinaryParamKind::Int32Array, + ModelParam::FloatArray(_) => BinaryParamKind::FloatArray, + }; + + BinaryModelParam { + kind, + len: param.len(), + value: param, + } + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub model_type: String, + pub input_type: ModelDataType, + pub output_type: ModelDataType, + pub params: Vec, + pub error: Option, +} + +#[derive(Debug, Clone)] +pub struct Stage { + pub models: Vec, +} + +#[derive(Debug, Clone)] +pub struct CacheFix { + pub line_size: usize, + pub spline_points: Vec<(u64, usize)>, +} + +#[derive(Debug, Clone)] +pub struct RMIModel { + + pub branching_factor: u64, + pub models: String, + pub last_layer_reports_error: bool, + pub num_rmi_rows: usize, + pub num_data_rows: usize, + pub model_avg_error: f64, + pub model_avg_l2_error: f64, + pub model_avg_log2_error: f64, + pub model_max_error: u64, + pub model_max_error_idx: usize, + pub model_max_log2_error: f64, + pub build_time: u128, + pub last_layer_max_l1s: Vec, + pub stages: Vec, + pub cache_fix: Option, +} + +impl RMIModel { + pub fn from_trained( + rmi: &TrainedRMI, + key_type: KeyType, + last_layer_reports_error: bool, + ) -> RMIModel { + let mut stages = Vec::new(); + + for stage in &rmi.rmi { + let mut models = Vec::new(); + for m in stage { + let params = m.params().into_iter().map(Into::into).collect(); + + models.push(Model { + model_type: m.model_name().to_string(), + input_type: m.input_type(), + output_type: m.output_type(), + params, + error: m.error_bound(), + }); + } + stages.push(Stage { models }); + } + + let cache_fix = rmi.cache_fix.as_ref().map(|(line_size, points)| CacheFix { + line_size: *line_size, + spline_points: points.clone(), + }); + + RMIModel { + branching_factor: rmi.branching_factor, + models: rmi.models.clone(), + last_layer_reports_error, + num_rmi_rows: rmi.num_rmi_rows, + num_data_rows: rmi.num_data_rows, + 
model_avg_error: rmi.model_avg_error, + model_avg_l2_error: rmi.model_avg_l2_error, + model_avg_log2_error: rmi.model_avg_log2_error, + model_max_error: rmi.model_max_error, + model_max_error_idx: rmi.model_max_error_idx, + model_max_log2_error: rmi.model_max_log2_error, + build_time: rmi.build_time, + last_layer_max_l1s: rmi.last_layer_max_l1s.clone(), + stages, + cache_fix, + } + } +} + +use std::io::{Result, Write}; + +fn write_string(w: &mut W, s: &str) -> Result<()> { + let bytes = s.as_bytes(); + w.write_all(&(bytes.len() as u64).to_le_bytes())?; + w.write_all(bytes) +} + +fn write_model_data_type(w: &mut W, t: &ModelDataType) -> Result<()> { + let code = match t { + ModelDataType::Int => 0u8, + ModelDataType::Int128 => 1u8, + ModelDataType::Float => 2u8, + }; + + w.write_all(&[code]) +} + +fn write_param(w: &mut W, p: &BinaryModelParam) -> Result<()> { + let kind_code = match p.kind { + BinaryParamKind::Int => 0u8, + BinaryParamKind::Float => 1u8, + BinaryParamKind::ShortArray => 2u8, + BinaryParamKind::IntArray => 3u8, + BinaryParamKind::Int32Array => 4u8, + BinaryParamKind::FloatArray => 5u8, + }; + + w.write_all(&[kind_code])?; + w.write_all(&(p.len as u64).to_le_bytes())?; + + match &p.value { + ModelParam::Int(v) => w.write_all(&v.to_le_bytes()), + ModelParam::Float(v) => w.write_all(&v.to_le_bytes()), + ModelParam::ShortArray(arr) => { + for v in arr { + w.write_all(&v.to_le_bytes())?; + } + Ok(()) + } + ModelParam::IntArray(arr) => { + for v in arr { + w.write_all(&v.to_le_bytes())?; + } + Ok(()) + } + ModelParam::Int32Array(arr) => { + for v in arr { + w.write_all(&v.to_le_bytes())?; + } + Ok(()) + } + ModelParam::FloatArray(arr) => { + for v in arr { + w.write_all(&v.to_le_bytes())?; + } + Ok(()) + } + } +} + +fn write_cache_fix(w: &mut W, cf: &CacheFix) -> Result<()> { + w.write_all(&(cf.line_size as u64).to_le_bytes())?; + w.write_all(&(cf.spline_points.len() as u64).to_le_bytes())?; + for (key, offset) in &cf.spline_points { + 
w.write_all(&key.to_le_bytes())?; + w.write_all(&(*offset as u64).to_le_bytes())?; + } + Ok(()) +} + +fn write_key_type(w: &mut W, k: KeyType) -> Result<()> { + let code = match k { + KeyType::U32 => 0u8, + KeyType::U64 => 1u8, + KeyType::F64 => 2u8, + KeyType::U128 => 3u8, + }; + + w.write_all(&[code]) +} + +impl RMIModel { + pub fn save_binary(&self, path: &str) -> Result<()> { + let mut f = std::fs::File::create(path)?; + + f.write_all(b"RMIB")?; + f.write_all(&1u32.to_le_bytes())?; + + f.write_all(&self.branching_factor.to_le_bytes())?; + f.write_all(&(self.num_rmi_rows as u64).to_le_bytes())?; + f.write_all(&(self.num_data_rows as u64).to_le_bytes())?; + f.write_all(&self.build_time.to_le_bytes())?; + f.write_all(&self.model_avg_error.to_le_bytes())?; + f.write_all(&self.model_avg_l2_error.to_le_bytes())?; + f.write_all(&self.model_avg_log2_error.to_le_bytes())?; + f.write_all(&self.model_max_error.to_le_bytes())?; + f.write_all(&(self.model_max_error_idx as u64).to_le_bytes())?; + f.write_all(&self.model_max_log2_error.to_le_bytes())?; + f.write_all(&[self.last_layer_reports_error as u8])?; + + write_string(&mut f, &self.models)?; + + f.write_all(&(self.last_layer_max_l1s.len() as u64).to_le_bytes())?; + for v in &self.last_layer_max_l1s { + f.write_all(&v.to_le_bytes())?; + } + + if let Some(cf) = &self.cache_fix { + f.write_all(&[1u8])?; + write_cache_fix(&mut f, cf)?; + } else { + f.write_all(&[0u8])?; + } + + let stage_count = self.stages.len() as u64; + f.write_all(&stage_count.to_le_bytes())?; + + for stage in &self.stages { + let model_count = stage.models.len() as u64; + f.write_all(&model_count.to_le_bytes())?; + + for m in &stage.models { + write_string(&mut f, &m.model_type)?; + write_model_data_type(&mut f, &m.input_type)?; + write_model_data_type(&mut f, &m.output_type)?; + + match m.error { + Some(err) => { + f.write_all(&[1u8])?; + f.write_all(&err.to_le_bytes())?; + } + None => f.write_all(&[0u8])?, + }; + + f.write_all(&(m.params.len() as 
u64).to_le_bytes())?; + for p in &m.params { + write_param(&mut f, p)?; + } + } + } + + Ok(()) + } +} diff --git a/rmi_lib/src/codegen.rs b/rmi_lib/src/codegen.rs index 6193b39..cf38bb6 100644 --- a/rmi_lib/src/codegen.rs +++ b/rmi_lib/src/codegen.rs @@ -9,13 +9,14 @@ use crate::models::Model; use crate::models::*; +use crate::manifest::{self, CacheFixMetadata, LayerMetadata, LayerStorage, ParameterDescriptor, ParamValue}; use bytesize::ByteSize; use log::*; use std::collections::HashSet; use std::io::Write; use std::str; use crate::train::TrainedRMI; -use std::fs::File; +use std::fs::{self, File}; use std::io::BufWriter; use std::path::Path; use std::fmt; @@ -208,6 +209,39 @@ impl LayerParams { }; } + fn num_models(&self) -> usize { + match self { + LayerParams::Constant(_, _) => 1, + LayerParams::Array(_, ppm, params) | + LayerParams::MixedArray(_, ppm, params) => { + assert_eq!(params.len() % ppm, 0); + params.len() / ppm + } + } + } + + fn sample_params(&self) -> &[ModelParam] { + match self { + LayerParams::Constant(_, params) => params.as_slice(), + LayerParams::Array(_, ppm, params) | + LayerParams::MixedArray(_, ppm, params) => ¶ms[0..*ppm], + } + } + + fn storage_descriptor(&self, namespace: &str) -> LayerStorage { + match self { + LayerParams::Constant(_, params) => LayerStorage::Constant { + values: params.iter().map(ParamValue::from).collect(), + }, + LayerParams::Array(idx, _, _) => LayerStorage::Array { + file: format!("{}_{}", namespace, array_name!(*idx)), + }, + LayerParams::MixedArray(idx, _, _) => LayerStorage::MixedArray { + file: format!("{}_{}", namespace, array_name!(*idx)), + }, + } + } + fn size(&self) -> usize { return self.params().iter().map(|p| p.size()).sum(); } @@ -462,9 +496,15 @@ fn generate_code( .enumerate() .map(|(layer_idx, models)| params_for_layer(layer_idx, models)) .collect(); - + let report_last_layer_errors = !rmi.last_layer_max_l1s.is_empty(); + let namespace_dir = Path::new(data_dir).join(namespace); + if 
!namespace_dir.exists() { + fs::create_dir_all(&namespace_dir) + .expect("Unable to create namespace-specific RMI data directory"); + } + let mut report_lle: Vec = Vec::new(); if report_last_layer_errors { let lle = &rmi.last_layer_max_l1s; @@ -512,7 +552,7 @@ fn generate_code( LayerParams::Array(idx, _, _) | LayerParams::MixedArray(idx, _, _) => { - let data_path = Path::new(&data_dir) + let data_path = namespace_dir .join(format!("{}_{}", namespace, array_name!(idx))); let f = File::create(data_path) .expect("Could not write data file to RMI directory"); @@ -522,8 +562,13 @@ fn generate_code( lp.to_decl(data_output)?; // write to source code read_code.push(" {".to_string()); - read_code.push(format!(" std::ifstream infile(std::filesystem::path(dataPath) / \"{ns}_{fn}\", std::ios::in | std::ios::binary);", + read_code.push(format!(" auto primary = std::filesystem::path(dataPath) / \"{ns}\" / \"{ns}_{fn}\";", + ns=namespace, fn=array_name!(idx))); + read_code.push(format!(" std::ifstream infile(primary, std::ios::in | std::ios::binary);")); + read_code.push(" if (!infile.good()) {".to_string()); + read_code.push(format!(" infile.open(std::filesystem::path(dataPath) / \"{ns}_{fn}\", std::ios::in | std::ios::binary);", ns=namespace, fn=array_name!(idx))); + read_code.push(" }".to_string()); read_code.push(" if (!infile.good()) return false;".to_string()); if lp.requires_malloc() { read_code.push(format!(" {} = ({}*) malloc({});", @@ -538,6 +583,41 @@ fn generate_code( } } } + let manifest_layers: Vec = layer_params + .iter() + .take(rmi.rmi.len()) + .zip(rmi.rmi.iter()) + .map(|(lp, models)| LayerMetadata { + index: lp.index(), + model_type: models[0].model_name().to_string(), + num_models: lp.num_models(), + params_per_model: lp.params_per_model(), + parameters: lp.sample_params().iter().map(ParameterDescriptor::from).collect(), + storage: lp.storage_descriptor(namespace), + }) + .collect(); + + let cache_fix_meta = if let Some((line_size, points)) = &rmi.cache_fix 
{ + let idx = layer_params.len() - 1; + Some(CacheFixMetadata { + file: format!("{}_{}", namespace, array_name!(idx)), + line_size: *line_size, + points: points.len(), + }) + } else { + None + }; + + manifest::write_metadata( + namespace, + data_dir, + key_type, + &rmi, + report_last_layer_errors, + manifest_layers, + cache_fix_meta, + ).expect("Unable to write RMI manifest"); + read_code.push(" return true;".to_string()); read_code.push("}".to_string()); @@ -757,17 +837,23 @@ inline size_t FCLAMP(double inp, double bound) {{ pub fn output_rmi(namespace: &str, mut trained_model: TrainedRMI, data_dir: &str, + output_dir: &str, key_type: KeyType, include_errors: bool) -> Result<(), std::io::Error> { - - let f1 = File::create(format!("{}.cpp", namespace)).expect("Could not write RMI CPP file"); + + let output_dir = Path::new(output_dir); + fs::create_dir_all(output_dir)?; + + let f1_path = output_dir.join(format!("{}.cpp", namespace)); + let f1 = File::create(&f1_path).expect("Could not write RMI CPP file"); let mut bw1 = BufWriter::new(f1); - - let f2 = - File::create(format!("{}_data.h", namespace)).expect("Could not write RMI data file"); + + let f2_path = output_dir.join(format!("{}_data.h", namespace)); + let f2 = File::create(&f2_path).expect("Could not write RMI data file"); let mut bw2 = BufWriter::new(f2); - - let f3 = File::create(format!("{}.h", namespace)).expect("Could not write RMI header file"); + + let f3_path = output_dir.join(format!("{}.h", namespace)); + let f3 = File::create(&f3_path).expect("Could not write RMI header file"); let mut bw3 = BufWriter::new(f3); if !include_errors { diff --git a/rmi_lib/src/lib.rs b/rmi_lib/src/lib.rs index 0dcc925..29998fd 100644 --- a/rmi_lib/src/lib.rs +++ b/rmi_lib/src/lib.rs @@ -1,7 +1,9 @@ mod codegen; +mod binary; mod models; mod train; mod cache_fix; +mod manifest; pub mod optimizer; pub use models::{RMITrainingData, RMITrainingDataIteratorProvider, ModelInput}; @@ -10,3 +12,5 @@ pub use 
optimizer::find_pareto_efficient_configs; pub use train::{train, train_for_size, train_bounded}; pub use codegen::rmi_size; pub use codegen::output_rmi; +pub use manifest::{CacheFixMetadata, LayerMetadata, ParameterDescriptor, ParamKind, ParamValue, RmiMetadata}; +pub use binary::{Model, Stage, RMIModel}; diff --git a/rmi_lib/src/manifest.rs b/rmi_lib/src/manifest.rs new file mode 100644 index 0000000..1938d55 --- /dev/null +++ b/rmi_lib/src/manifest.rs @@ -0,0 +1,156 @@ +use crate::models::KeyType; +use crate::models::ModelParam; +use crate::train::TrainedRMI; +use serde::{Deserialize, Serialize}; +use std::fs::{self, File}; +use std::io::Write; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ParamKind { + Int, + Float, + ShortArray, + IntArray, + Int32Array, + FloatArray, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "value")] +pub enum ParamValue { + Int(u64), + Float(f64), + ShortArray(Vec), + IntArray(Vec), + Int32Array(Vec), + FloatArray(Vec), +} + +impl From<&ModelParam> for ParamValue { + fn from(param: &ModelParam) -> Self { + match param { + ModelParam::Int(v) => ParamValue::Int(*v), + ModelParam::Float(v) => ParamValue::Float(*v), + ModelParam::ShortArray(v) => ParamValue::ShortArray(v.clone()), + ModelParam::IntArray(v) => ParamValue::IntArray(v.clone()), + ModelParam::Int32Array(v) => ParamValue::Int32Array(v.clone()), + ModelParam::FloatArray(v) => ParamValue::FloatArray(v.clone()), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParameterDescriptor { + pub kind: ParamKind, + pub len: usize, +} + +impl From<&ModelParam> for ParameterDescriptor { + fn from(param: &ModelParam) -> Self { + let kind = match param { + ModelParam::Int(_) => ParamKind::Int, + ModelParam::Float(_) => ParamKind::Float, + ModelParam::ShortArray(_) => ParamKind::ShortArray, + ModelParam::IntArray(_) => ParamKind::IntArray, + ModelParam::Int32Array(_) => 
ParamKind::Int32Array, + ModelParam::FloatArray(_) => ParamKind::FloatArray, + }; + + ParameterDescriptor { + kind, + len: param.len(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum LayerStorage { + Constant { values: Vec }, + Array { file: String }, + MixedArray { file: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LayerMetadata { + pub index: usize, + pub model_type: String, + pub num_models: usize, + pub params_per_model: usize, + pub parameters: Vec, + pub storage: LayerStorage, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheFixMetadata { + pub file: String, + pub line_size: usize, + pub points: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RmiMetadata { + pub namespace: String, + pub key_type: String, + pub models: String, + pub branching_factor: u64, + pub build_time_ns: u128, + pub num_rmi_rows: usize, + pub num_data_rows: usize, + pub model_avg_error: f64, + pub model_avg_l2_error: f64, + pub model_avg_log2_error: f64, + pub model_max_error: u64, + pub model_max_error_idx: usize, + pub model_max_log2_error: f64, + pub last_layer_reports_error: bool, + pub layers: Vec, + pub cache_fix: Option, +} + +impl RmiMetadata { + pub fn manifest_path>(data_dir: P, namespace: &str) -> PathBuf { + data_dir.as_ref().join(namespace).join("manifest.json") + } +} + +pub fn write_metadata>(namespace: &str, + data_dir: P, + key_type: KeyType, + rmi: &TrainedRMI, + last_layer_reports_error: bool, + layers: Vec, + cache_fix: Option) + -> std::io::Result<()> { + + let metadata = RmiMetadata { + namespace: namespace.to_string(), + key_type: key_type.as_str().to_string(), + models: rmi.models.clone(), + branching_factor: rmi.branching_factor, + build_time_ns: rmi.build_time, + num_rmi_rows: rmi.num_rmi_rows, + num_data_rows: rmi.num_data_rows, + model_avg_error: rmi.model_avg_error, + model_avg_l2_error: rmi.model_avg_l2_error, + model_avg_log2_error: 
rmi.model_avg_log2_error, + model_max_error: rmi.model_max_error, + model_max_error_idx: rmi.model_max_error_idx, + model_max_log2_error: rmi.model_max_log2_error, + last_layer_reports_error, + layers, + cache_fix, + }; + + let manifest_path = RmiMetadata::manifest_path(data_dir, namespace); + if let Some(parent) = manifest_path.parent() { + fs::create_dir_all(parent)?; + } + + let mut file = File::create(manifest_path)?; + let serialized = serde_json::to_vec_pretty(&metadata) + .expect("serialization to JSON should not fail"); + file.write_all(&serialized)?; + Ok(()) +} + diff --git a/rmi_lib/src/models/balanced_radix.rs b/rmi_lib/src/models/balanced_radix.rs index a8e6633..f473037 100644 --- a/rmi_lib/src/models/balanced_radix.rs +++ b/rmi_lib/src/models/balanced_radix.rs @@ -161,6 +161,10 @@ inline uint64_t bradix_clamp_low(uint64_t prefix_length, }; } + fn model_name(&self) -> &'static str { + "bradix" + } + fn needs_bounds_check(&self) -> bool { return false; } diff --git a/rmi_lib/src/models/cubic_spline.rs b/rmi_lib/src/models/cubic_spline.rs index 55cef04..c801964 100644 --- a/rmi_lib/src/models/cubic_spline.rs +++ b/rmi_lib/src/models/cubic_spline.rs @@ -181,6 +181,10 @@ inline double cubic(double a, double b, double c, double d, double x) { fn function_name(&self) -> String { return String::from("cubic"); } + + fn model_name(&self) -> &'static str { + "cubic" + } fn needs_bounds_check(&self) -> bool { return false; } diff --git a/rmi_lib/src/models/histogram.rs b/rmi_lib/src/models/histogram.rs index bc73a48..7d3924d 100644 --- a/rmi_lib/src/models/histogram.rs +++ b/rmi_lib/src/models/histogram.rs @@ -99,6 +99,10 @@ inline uint64_t ed_histogram(const uint64_t length, } fn function_name(&self) -> String { return String::from("ed_histogram"); } + + fn model_name(&self) -> &'static str { + "histogram" + } fn restriction(&self) -> ModelRestriction { return ModelRestriction::MustBeTop; } fn needs_bounds_check(&self) -> bool { return false; } } diff --git 
a/rmi_lib/src/models/linear.rs b/rmi_lib/src/models/linear.rs index a564d13..5065673 100644 --- a/rmi_lib/src/models/linear.rs +++ b/rmi_lib/src/models/linear.rs @@ -113,6 +113,10 @@ inline double linear(double alpha, double beta, double inp) { return String::from("linear"); } + fn model_name(&self) -> &'static str { + "linear" + } + fn set_to_constant_model(&mut self, constant: u64) -> bool { self.params = (constant as f64, 0.0); return true; @@ -202,6 +206,10 @@ inline double loglinear(double alpha, double beta, double inp) { fn function_name(&self) -> String { return String::from("loglinear"); } + + fn model_name(&self) -> &'static str { + "loglinear" + } fn standard_functions(&self) -> HashSet { let mut to_r = HashSet::new(); to_r.insert(StdFunctions::EXP1); @@ -285,11 +293,15 @@ inline double linear(double alpha, double beta, double inp) { }", ); } - + fn function_name(&self) -> String { return String::from("linear"); } + fn model_name(&self) -> &'static str { + "robust_linear" + } + fn set_to_constant_model(&mut self, constant: u64) -> bool { self.params = (constant as f64, 0.0); return true; diff --git a/rmi_lib/src/models/linear_spline.rs b/rmi_lib/src/models/linear_spline.rs index 8a20f1f..377becd 100644 --- a/rmi_lib/src/models/linear_spline.rs +++ b/rmi_lib/src/models/linear_spline.rs @@ -76,6 +76,10 @@ inline double linear(double alpha, double beta, double inp) { return String::from("linear"); } + fn model_name(&self) -> &'static str { + "linear_spline" + } + fn set_to_constant_model(&mut self, constant: u64) -> bool { self.params = (constant as f64, 0.0); return true; diff --git a/rmi_lib/src/models/mod.rs b/rmi_lib/src/models/mod.rs index 1b086d4..124d21a 100644 --- a/rmi_lib/src/models/mod.rs +++ b/rmi_lib/src/models/mod.rs @@ -16,6 +16,7 @@ mod normal; mod radix; mod stdlib; mod utils; +mod optimal_pla; pub use balanced_radix::BalancedRadixModel; pub use cubic_spline::CubicSplineModel; @@ -29,6 +30,7 @@ pub use normal::NormalModel; pub use 
radix::RadixModel; pub use radix::RadixTable; pub use stdlib::StdFunctions; +pub use optimal_pla::OptimalPLAModel; use std::cmp::Ordering; use std::collections::HashSet; @@ -52,6 +54,15 @@ impl KeyType { } } + pub fn as_str(&self) -> &'static str { + match self { + KeyType::U32 => "u32", + KeyType::U64 => "u64", + KeyType::F64 => "f64", + KeyType::U128 => "u128", + } + } + pub fn to_model_data_type(self) -> ModelDataType { match self { KeyType::U32 => ModelDataType::Int, @@ -490,6 +501,7 @@ impl From for ModelInput { ModelInput::Float(f) } } +#[derive(Clone, Copy, Debug)] pub enum ModelDataType { Int, Int128, @@ -743,6 +755,7 @@ pub trait Model: Sync + Send { fn code(&self) -> String; fn function_name(&self) -> String; + fn model_name(&self) -> &'static str; fn standard_functions(&self) -> HashSet { return HashSet::new(); diff --git a/rmi_lib/src/models/normal.rs b/rmi_lib/src/models/normal.rs index 1940e6d..eb4571f 100644 --- a/rmi_lib/src/models/normal.rs +++ b/rmi_lib/src/models/normal.rs @@ -118,6 +118,10 @@ inline double ncdf(double mean, double stdev, double scale, double inp) { fn function_name(&self) -> String { return String::from("ncdf"); } + + fn model_name(&self) -> &'static str { + "normal" + } fn standard_functions(&self) -> HashSet { let mut to_r = HashSet::new(); to_r.insert(StdFunctions::EXP1); @@ -193,6 +197,10 @@ inline double lncdf(double mean, double stdev, double scale, double inp) { fn function_name(&self) -> String { return String::from("lncdf"); } + + fn model_name(&self) -> &'static str { + "lognormal" + } fn standard_functions(&self) -> HashSet { let mut to_r = HashSet::new(); to_r.insert(StdFunctions::EXP1); diff --git a/rmi_lib/src/models/optimal_pla.rs b/rmi_lib/src/models/optimal_pla.rs new file mode 100644 index 0000000..4a72f2a --- /dev/null +++ b/rmi_lib/src/models/optimal_pla.rs @@ -0,0 +1,218 @@ +// < begin copyright > +// Copyright Ryan Marcus 2020 +// +// See root directory of this project for license terms. 
+// +// < end copyright > + +use crate::models::*; +use log::trace; + +const MAX_SEGMENT_ABS_ERROR: f64 = 1.0; + +fn simple_lr>(loc_data: T) -> (f64, f64) { + let mut mean_x = 0.0; + let mut mean_y = 0.0; + let mut c = 0.0; + let mut n: u64 = 0; + let mut m2 = 0.0; + let mut data_size = 0; + + for (x, y) in loc_data { + n += 1; + let dx = x - mean_x; + mean_x += dx / (n as f64); + mean_y += (y - mean_y) / (n as f64); + c += dx * (y - mean_y); + + let dx2 = x - mean_x; + m2 += dx * dx2; + data_size += 1; + } + + if data_size == 0 { + return (0.0, 0.0); + } + + if data_size == 1 { + return (mean_y, 0.0); + } + + let cov = c / ((n - 1) as f64); + let var = m2 / ((n - 1) as f64); + + if var == 0.0 { + return (mean_y, 0.0); + } + + let beta: f64 = cov / var; + let alpha = mean_y - beta * mean_x; + + return (alpha, beta); +} + +fn segment_error(data: &RMITrainingData, + start: usize, + end: usize, + params: (f64, f64)) -> f64 { + let (alpha, beta) = params; + data.iter() + .skip(start) + .take(end - start) + .map(|(key, pos)| { + let prediction = beta.mul_add(key.as_float(), alpha); + (prediction - pos as f64).abs() + }) + .fold(0.0, f64::max) +} + +fn fit_segment(data: &RMITrainingData, start: usize, end: usize) -> (f64, f64) { + simple_lr(data + .iter() + .skip(start) + .take(end - start) + .map(|(inp, offset)| (inp.as_float(), offset as f64))) +} + +fn build_segments(data: &RMITrainingData) -> (Vec, Vec, Vec) { + if data.len() == 0 { + return (vec![0.0], vec![0.0], vec![u64::MAX]); + } + + let mut intercepts = Vec::new(); + let mut slopes = Vec::new(); + let mut boundaries = Vec::new(); + + let mut start = 0; + while start < data.len() { + let mut end = start + 1; + let mut best_params = fit_segment(data, start, end); + let mut best_err = segment_error(data, start, end, best_params); + + while end < data.len() { + let candidate_params = fit_segment(data, start, end + 1); + let candidate_err = segment_error(data, start, end + 1, candidate_params); + + if candidate_err 
<= MAX_SEGMENT_ABS_ERROR { + best_params = candidate_params; + best_err = candidate_err; + end += 1; + } else { + break; + } + } + + trace!("PLA segment {}:{} err {}", start, end, best_err); + intercepts.push(best_params.0); + slopes.push(best_params.1); + boundaries.push(data.get_key(end - 1).as_uint()); + start = end; + } + + return (intercepts, slopes, boundaries); +} + +pub struct OptimalPLAModel { + intercepts: Vec, + slopes: Vec, + boundaries: Vec, +} + +impl OptimalPLAModel { + pub fn new(data: &RMITrainingData) -> OptimalPLAModel { + let (intercepts, slopes, boundaries) = build_segments(data); + OptimalPLAModel { intercepts, slopes, boundaries } + } +} + +impl Model for OptimalPLAModel { + fn predict_to_float(&self, inp: &ModelInput) -> f64 { + if self.boundaries.is_empty() { + return 0.0; + } + + let key = inp.as_int(); + // boundaries[i] is the LAST key of segment i, so an exact hit belongs to segment i (not i + 1) + let idx = match self.boundaries.binary_search(&key) { + Ok(exact) => exact, + Err(insert) => insert, + }; + + let bounded_idx = idx.min(self.intercepts.len() - 1); + self.slopes[bounded_idx].mul_add(inp.as_float(), self.intercepts[bounded_idx]) + } + + fn input_type(&self) -> ModelDataType { ModelDataType::Float } + fn output_type(&self) -> ModelDataType { ModelDataType::Float } + + fn params(&self) -> Vec { + vec![ + ModelParam::Int(self.intercepts.len() as u64), + ModelParam::FloatArray(self.intercepts.clone()), + ModelParam::FloatArray(self.slopes.clone()), + ModelParam::IntArray(self.boundaries.clone()), + ] + } + + fn code(&self) -> String { + String::from( + "inline double optimal_pla(uint64_t length, const double intercepts[], const double slopes[], const uint64_t boundaries[], double inp) {\n uint64_t idx = bs_upper_bound(boundaries, length, (uint64_t)inp);\n if (idx > 0 && boundaries[idx - 1] == (uint64_t)inp) { idx--; }\n if (idx >= length) { idx = length - 1; }\n return std::fma(slopes[idx], inp, intercepts[idx]);\n}" + ) + } + + fn standard_functions(&self) -> HashSet { + let mut to_r = HashSet::new(); + to_r.insert(StdFunctions::BinarySearch); + to_r + } + + fn function_name(&self) -> 
String { String::from("optimal_pla") } + + fn model_name(&self) -> &'static str { "optimal_pla" } + + fn needs_bounds_check(&self) -> bool { false } + + fn set_to_constant_model(&mut self, constant: u64) -> bool { + self.intercepts = vec![constant as f64]; + self.slopes = vec![0.0]; + self.boundaries = vec![u64::MAX]; + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn fits_linear_run() { + let mut pts = Vec::new(); + for i in 0..50u64 { pts.push((i, i * 2)); } + let md = ModelData::IntKeyToIntPos(pts); + let pla = OptimalPLAModel::new(&md); + + assert_eq!(pla.boundaries.len(), 1); + assert_eq!(pla.predict_to_int(10.into()), 20); + assert_eq!(pla.predict_to_int(49.into()), 98); + } + + #[test] + fn makes_multiple_segments_when_needed() { + let md = ModelData::IntKeyToIntPos(vec![ + (0, 0), (1, 1), (2, 2), // first segment + (100, 120), (101, 121), (102, 122), // second far apart + ]); + + let pla = OptimalPLAModel::new(&md); + assert!(pla.boundaries.len() >= 2); + + assert_eq!(pla.predict_to_int(0.into()), 0); + assert_eq!(pla.predict_to_int(101.into()), 121); + } + + #[test] + fn empty_defaults_to_zero() { + let md: ModelData = ModelData::empty(); + let pla = OptimalPLAModel::new(&md); + assert_eq!(pla.predict_to_int(5.into()), 0); + } +} diff --git a/rmi_lib/src/models/radix.rs b/rmi_lib/src/models/radix.rs index 558c419..3cdbaaa 100644 --- a/rmi_lib/src/models/radix.rs +++ b/rmi_lib/src/models/radix.rs @@ -72,6 +72,10 @@ inline uint64_t radix(uint64_t prefix_length, uint64_t bits, uint64_t inp) { fn function_name(&self) -> String { return String::from("radix"); } + + fn model_name(&self) -> &'static str { + "radix" + } fn needs_bounds_check(&self) -> bool { return false; } @@ -161,6 +165,17 @@ inline uint64_t radix_table(const uint32_t* table, const uint64_t inp) {{ fn function_name(&self) -> String { return String::from("radix_table"); } + + fn model_name(&self) -> &'static str { + match self.table_bits { + 8 => "radix8", + 18 => 
"radix18", + 22 => "radix22", + 26 => "radix26", + 28 => "radix28", + _ => "radix_table", + } + } fn needs_bounds_check(&self) -> bool { return false; } diff --git a/rmi_lib/src/train/mod.rs b/rmi_lib/src/train/mod.rs index c52918f..4a3aefb 100644 --- a/rmi_lib/src/train/mod.rs +++ b/rmi_lib/src/train/mod.rs @@ -50,6 +50,7 @@ fn train_model(model_type: &str, "radix28" => Box::new(RadixTable::new(data, 28)), "bradix" => Box::new(BalancedRadixModel::new(data)), "histogram" => Box::new(EquidepthHistogramModel::new(data)), + "optimal_pla" => Box::new(OptimalPLAModel::new(data)), _ => panic!("Unknown model type: {}", model_type), }; diff --git a/src/load.rs b/src/load.rs index 9d9bbee..db91a5e 100644 --- a/src/load.rs +++ b/src/load.rs @@ -1,26 +1,25 @@ -// < begin copyright > +// < begin copyright > // Copyright Ryan Marcus 2020 -// +// // See root directory of this project for license terms. -// -// < end copyright > - - -use memmap::MmapOptions; -use rmi_lib::{RMITrainingData, RMITrainingDataIteratorProvider, KeyType}; +// +// < end copyright > + use byteorder::{LittleEndian, ReadBytesExt}; -use std::fs::File; +use memmap::MmapOptions; +use rmi_lib::{KeyType, RMITrainingData, RMITrainingDataIteratorProvider}; use std::convert::TryInto; +use std::fs::File; pub enum DataType { UINT64, UINT32, - FLOAT64 + FLOAT64, } struct SliceAdapterU64 { data: memmap::Mmap, - length: usize + length: usize, } impl RMITrainingDataIteratorProvider for SliceAdapterU64 { @@ -28,25 +27,31 @@ impl RMITrainingDataIteratorProvider for SliceAdapterU64 { fn cdf_iter(&self) -> Box + '_> { Box::new((0..self.length).map(move |i| self.get(i).unwrap())) } - + fn get(&self, idx: usize) -> Option<(Self::InpType, usize)> { - if idx >= self.length { return None; }; - let mi = u64::from_le_bytes((&self.data[8 + idx * 8..8 + (idx + 1) * 8]) - .try_into().unwrap()); + if idx >= self.length { + return None; + }; + let mi = u64::from_le_bytes( + (&self.data[8 + idx * 8..8 + (idx + 1) * 8]) + .try_into() + 
.unwrap(), + ); return Some((mi.into(), idx)); } - + fn key_type(&self) -> KeyType { KeyType::U64 } - - fn len(&self) -> usize { self.length } -} + fn len(&self) -> usize { + self.length + } +} struct SliceAdapterU32 { data: memmap::Mmap, - length: usize + length: usize, } impl RMITrainingDataIteratorProvider for SliceAdapterU32 { @@ -54,24 +59,30 @@ impl RMITrainingDataIteratorProvider for SliceAdapterU32 { fn cdf_iter(&self) -> Box + '_> { Box::new((0..self.length).map(move |i| self.get(i).unwrap())) } - + fn get(&self, idx: usize) -> Option<(Self::InpType, usize)> { - if idx >= self.length { return None; }; + if idx >= self.length { + return None; + }; let mi = (&self.data[8 + idx * 4..8 + (idx + 1) * 4]) - .read_u32::().unwrap().into(); + .read_u32::() + .unwrap() + .into(); return Some((mi, idx)); } - + fn key_type(&self) -> KeyType { KeyType::U32 } - - fn len(&self) -> usize { self.length } + + fn len(&self) -> usize { + self.length + } } struct SliceAdapterF64 { data: memmap::Mmap, - length: usize + length: usize, } impl RMITrainingDataIteratorProvider for SliceAdapterF64 { @@ -79,25 +90,31 @@ impl RMITrainingDataIteratorProvider for SliceAdapterF64 { fn cdf_iter(&self) -> Box + '_> { Box::new((0..self.length).map(move |i| self.get(i).unwrap())) } - + fn get(&self, idx: usize) -> Option<(Self::InpType, usize)> { - if idx >= self.length { return None; }; + if idx >= self.length { + return None; + }; let mi = (&self.data[8 + idx * 8..8 + (idx + 1) * 8]) - .read_f64::().unwrap().into(); + .read_f64::() + .unwrap() + .into(); return Some((mi, idx)); } - + fn key_type(&self) -> KeyType { KeyType::F64 } - - fn len(&self) -> usize { self.length } + + fn len(&self) -> usize { + self.length + } } pub enum RMIMMap { UINT64(RMITrainingData), UINT32(RMITrainingData), - FLOAT64(RMITrainingData) + FLOAT64(RMITrainingData), } macro_rules! dynamic { @@ -110,7 +127,6 @@ macro_rules! 
dynamic { } } - impl RMIMMap { pub fn soft_copy(&self) -> RMIMMap { match self { @@ -123,34 +139,31 @@ impl RMIMMap { pub fn into_u64(self) -> Option> { match self { RMIMMap::UINT64(x) => Some(x), - _ => None + _ => None, } } } - -pub fn load_data(filepath: &str, - dt: DataType) -> (usize, RMIMMap) { - let fd = File::open(filepath).unwrap_or_else(|_| { - panic!("Unable to open data file at {}", filepath) - }); +pub fn load_data(filepath: &str, dt: DataType) -> (usize, RMIMMap) { + let fd = + File::open(filepath).unwrap_or_else(|_| panic!("Unable to open data file at {}", filepath)); let mmap = unsafe { MmapOptions::new().map(&fd).unwrap() }; let num_items = (&mmap[0..8]).read_u64::().unwrap() as usize; let rtd = match dt { - DataType::UINT64 => - RMIMMap::UINT64(RMITrainingData::new(Box::new( - SliceAdapterU64 { data: mmap, length: num_items } - ))), - DataType::UINT32 => - RMIMMap::UINT32(RMITrainingData::new(Box::new( - SliceAdapterU32 { data: mmap, length: num_items } - ))), - DataType::FLOAT64 => - RMIMMap::FLOAT64(RMITrainingData::new(Box::new( - SliceAdapterF64 { data: mmap, length: num_items } - ))) + DataType::UINT64 => RMIMMap::UINT64(RMITrainingData::new(Box::new(SliceAdapterU64 { + data: mmap, + length: num_items, + }))), + DataType::UINT32 => RMIMMap::UINT32(RMITrainingData::new(Box::new(SliceAdapterU32 { + data: mmap, + length: num_items, + }))), + DataType::FLOAT64 => RMIMMap::FLOAT64(RMITrainingData::new(Box::new(SliceAdapterF64 { + data: mmap, + length: num_items, + }))), }; return (num_items, rtd); diff --git a/src/main.rs b/src/main.rs index b6e32ff..63b442c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,9 @@ -// < begin copyright > +// < begin copyright > // Copyright Ryan Marcus 2020 -// +// // See root directory of this project for license terms. 
-// -// < end copyright > - - +// +// < end copyright > #![allow(clippy::needless_return)] @@ -13,22 +11,22 @@ mod load; use load::{load_data, DataType}; -use rmi_lib::{train, train_bounded}; -use rmi_lib::KeyType; use rmi_lib::optimizer; +use rmi_lib::KeyType; +use rmi_lib::RMIModel; +use rmi_lib::{train, train_bounded}; use json::*; use log::*; +use rayon::prelude::*; use std::f64; +use std::fs; use std::fs::File; use std::io::BufWriter; -use std::fs; use std::path::Path; -use rayon::prelude::*; -use indicatif::{ProgressBar, ProgressStyle}; use clap::{App, Arg}; - +use indicatif::{ProgressBar, ProgressStyle}; fn main() { env_logger::init(); @@ -73,6 +71,15 @@ fn main() { .short("d") .value_name("dir") .help("exports parameters to files in this directory (default: rmi_data)")) + .arg(Arg::with_name("output-path") + .long("output-path") + .short("o") + .value_name("dir") + .help("exports generated model files to this directory (default: current directory)")) + .arg(Arg::with_name("binary-output") + .long("binary-output") + .value_name("PATH") + .help("Path to write binary model file")) .arg(Arg::with_name("no-errors") .long("no-errors") .help("do not save last-level errors, and modify the RMI function signature")) @@ -103,20 +110,25 @@ fn main() { // set the max number of threads to 4 by default, otherwise Rayon goes // crazy on larger machines and allocates too many workers for folds / reduces - let num_threads = matches.value_of("threads") + let num_threads = matches + .value_of("threads") .map(|x| x.parse::().unwrap()) .unwrap_or(4); - rayon::ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap(); - - let fp = matches.value_of("input").unwrap(); + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build_global() + .unwrap(); + let fp = matches.value_of("input").unwrap(); let data_dir = matches.value_of("data-path").unwrap_or("rmi_data"); - + let output_dir = matches.value_of("output-path").unwrap_or("."); + let binary_output = 
matches.value_of("binary-output").map(|s| s.to_string()); + if matches.value_of("namespace").is_some() && matches.value_of("param-grid").is_some() { panic!("Can only specify one of namespace or param-grid"); } - + info!("Reading {}...", fp); let mut key_type = KeyType::U64; @@ -132,8 +144,7 @@ fn main() { }; if matches.is_present("optimize") { - let results = dynamic!(optimizer::find_pareto_efficient_configs, - data, 10); + let results = dynamic!(optimizer::find_pareto_efficient_configs, data, 10); optimizer::RMIStatistics::display_table(&results); @@ -141,20 +152,24 @@ fn main() { matches.value_of("namespace").unwrap() } else { let path = Path::new(fp); - path.file_name().map(|s| s.to_str()).unwrap_or(Some("rmi")).unwrap() + path.file_name() + .map(|s| s.to_str()) + .unwrap_or(Some("rmi")) + .unwrap() }; - - let grid_specs: Vec = results.into_iter() + + let grid_specs: Vec = results + .into_iter() .enumerate() .map(|(idx, v)| { let nmspc = format!("{}_{}", nmspc_prefix, idx); v.to_grid_spec(&nmspc) - }).collect(); + }) + .collect(); let grid_specs_json = object!("configs" => grid_specs); let fp = matches.value_of("optimize").unwrap(); - let f = File::create(fp) - .expect("Could not write optimization results file"); + let f = File::create(fp).expect("Could not write optimization results file"); let mut bw = BufWriter::new(f); grid_specs_json.write(&mut bw).unwrap(); return; @@ -162,13 +177,28 @@ fn main() { // if we aren't optimizing, we should make sure the RMI data directory exists. if !Path::new(data_dir).exists() { - info!("The RMI data directory specified {} does not exist. Creating it.", - data_dir); + info!( + "The RMI data directory specified {} does not exist. Creating it.", + data_dir + ); std::fs::create_dir_all(data_dir) .expect("The RMI data directory did not exist, and it could not be created."); } - + + if !Path::new(output_dir).exists() { + info!( + "The RMI output directory specified {} does not exist. 
Creating it.", + output_dir + ); + std::fs::create_dir_all(output_dir) + .expect("The RMI output directory did not exist, and it could not be created."); + } + if let Some(param_grid) = matches.value_of("param-grid").map(|x| x.to_string()) { + if binary_output.is_some() { + warn!("Binary output is ignored when training multiple RMIs from a parameter grid"); + } + let pg = { let raw_json = fs::read_to_string(param_grid.clone()).unwrap(); let mut as_json = json::parse(raw_json.as_str()).unwrap(); @@ -182,7 +212,7 @@ fn main() { let branching = el["branching factor"].as_u64().unwrap(); let namespace = match el["namespace"].as_str() { Some(s) => Some(String::from(s)), - None => None + None => None, }; to_test.push((layers, branching, namespace)); @@ -191,19 +221,23 @@ fn main() { trace!("# RMIs to train: {}", to_test.len()); let pbar = ProgressBar::new(to_test.len() as u64); - pbar.set_style(ProgressStyle::default_bar() - .template("{pos} / {len} ({msg}) {wide_bar} {eta}")); + pbar.set_style( + ProgressStyle::default_bar().template("{pos} / {len} ({msg}) {wide_bar} {eta}"), + ); let train_func = |(models, branch_factor, namespace): &(String, u64, Option)| { - trace!("Training RMI {} with branching factor {}", - models, *branch_factor); - + trace!( + "Training RMI {} with branching factor {}", + models, + *branch_factor + ); + let loc_data = data.soft_copy(); let mut trained_model = dynamic!(train, loc_data, models, *branch_factor); - + let size_bs = rmi_lib::rmi_size(&trained_model); - + let result_obj = object! 
{ "layers" => models.clone(), "branching factor" => *branch_factor, @@ -223,61 +257,63 @@ fn main() { if matches.is_present("zero-build-time") { trained_model.build_time = 0; } - + if let Some(nmspc) = namespace { rmi_lib::output_rmi( &nmspc, trained_model, data_dir, + output_dir, key_type, - true).unwrap(); - + true, + ) + .unwrap(); } - + pbar.inc(1); return result_obj; }; - let results: Vec = - if matches.is_present("disable-parallel-training") { - trace!("Training models sequentially"); - to_test.iter().map(train_func).collect() - } else { - trace!("Training models in parallel"); - to_test.par_iter().map(train_func).collect() - }; - + let results: Vec = if matches.is_present("disable-parallel-training") { + trace!("Training models sequentially"); + to_test.iter().map(train_func).collect() + } else { + trace!("Training models in parallel"); + to_test.par_iter().map(train_func).collect() + }; + //let results: Vec = to_test //.par_iter().map( pbar.finish(); - let f = File::create(format!("{}_results", param_grid)).expect("Could not write results file"); + let f = File::create(format!("{}_results", param_grid)) + .expect("Could not write results file"); let mut bw = BufWriter::new(f); let json_results = object! 
{ "results" => results }; json_results.write(&mut bw).unwrap(); - } else { panic!("Configs must have an array as its value"); } - } else if matches.value_of("namespace").is_some() { let namespace = matches.value_of("namespace").unwrap().to_string(); let mut trained_model = match matches.value_of("max-size") { None => { - // assume they gave a model spec + // assume they gave a model spec let models = matches.value_of("models").unwrap(); let branch_factor = matches .value_of("branching factor") .unwrap() .parse::() .unwrap(); - + let trained_model = match matches.value_of("bounded") { None => dynamic!(train, data, models, branch_factor), Some(s) => { - let line_size = s.parse::() + let line_size = s + .parse::() .expect("Line size must be a positive integer."); - let d_u64 = data.into_u64() + let d_u64 = data + .into_u64() .expect("Can only construct a bounded RMI on u64 data."); train_bounded(&d_u64, models, branch_factor, line_size) } @@ -292,9 +328,12 @@ fn main() { trained_model } }; - + let no_errors = matches.is_present("no-errors"); - info!("Model build time: {} ms", trained_model.build_time / 1_000_000); + info!( + "Model build time: {} ms", + trained_model.build_time / 1_000_000 + ); info!( "Average model error: {} ({}%)", @@ -319,7 +358,15 @@ fn main() { trained_model.model_max_error, trained_model.model_max_error as f64 / num_rows as f64 * 100.0 ); - + + if let Some(path) = binary_output.as_ref() { + let rmi_model = RMIModel::from_trained(&trained_model, key_type, !no_errors); + rmi_model + .save_binary(path) + .expect("Failed to save binary RMI model"); + info!("Saved binary RMI model to {}", path); + } + if !matches.is_present("no-code") { if matches.is_present("zero-build-time") { trained_model.build_time = 0; @@ -329,8 +376,11 @@ fn main() { &namespace, trained_model, data_dir, + output_dir, key_type, - !no_errors).unwrap(); + !no_errors, + ) + .unwrap(); } else { trace!("Skipping code generation due to CLI flag"); } diff --git 
a/tests/cache_fix_osm/main.cpp b/tests/cache_fix_osm/main.cpp index d8d1773..e809068 100644 --- a/tests/cache_fix_osm/main.cpp +++ b/tests/cache_fix_osm/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." << std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; size_t err; @@ -29,7 +30,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? rmi_guess - true_index : true_index - rmi_guess); @@ -42,6 +43,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/cache_fix_wiki/main.cpp b/tests/cache_fix_wiki/main.cpp index 35afc50..3f150ac 100644 --- a/tests/cache_fix_wiki/main.cpp +++ b/tests/cache_fix_wiki/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." << std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; size_t err; @@ -29,7 +30,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? 
rmi_guess - true_index : true_index - rmi_guess); @@ -42,6 +43,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/common/rmi_learned_index.h b/tests/common/rmi_learned_index.h new file mode 100644 index 0000000..eb9cd00 --- /dev/null +++ b/tests/common/rmi_learned_index.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include + +#include "rmi.h" + +class RMILearnedIndex { + public: + bool Load(const std::string& root, + const std::string& dataset_namespace = rmi::NAME) { + std::filesystem::path full_path = std::filesystem::path(root) / dataset_namespace; + if (!std::filesystem::exists(full_path)) { + full_path = std::filesystem::path(root); + } + data_path_ = full_path.string(); + loaded_ = rmi::load(data_path_.c_str()); + return loaded_; + } + + uint64_t Lookup(uint64_t key, size_t* err) const { + return rmi::lookup(key, err); + } + + void Cleanup() { + rmi::cleanup(); + loaded_ = false; + } + + size_t SizeBytes() const { return rmi::RMI_SIZE; } + uint64_t BuildTimeNs() const { return rmi::BUILD_TIME_NS; } + + private: + std::string data_path_; + bool loaded_{false}; +}; diff --git a/tests/max_size_wiki/main.cpp b/tests/max_size_wiki/main.cpp index 959da88..025e181 100644 --- a/tests/max_size_wiki/main.cpp +++ b/tests/max_size_wiki/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." 
<< std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; if (rmi::RMI_SIZE > 50000000) { std::cout << "RMI was larger than 50MB" << std::endl; @@ -34,7 +35,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? rmi_guess - true_index : true_index - rmi_guess); if (diff > err) { @@ -46,6 +47,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/radix_model_wiki/main.cpp b/tests/radix_model_wiki/main.cpp index 264f549..eeda706 100644 --- a/tests/radix_model_wiki/main.cpp +++ b/tests/radix_model_wiki/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." << std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; size_t err; @@ -29,7 +30,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? 
rmi_guess - true_index : true_index - rmi_guess); if (diff > err) { @@ -41,6 +42,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/simple_model_osm/main.cpp b/tests/simple_model_osm/main.cpp index 06768c3..095694c 100644 --- a/tests/simple_model_osm/main.cpp +++ b/tests/simple_model_osm/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." << std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; size_t err; @@ -29,7 +30,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? rmi_guess - true_index : true_index - rmi_guess); if (diff > err) { @@ -41,6 +42,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/simple_model_wiki/main.cpp b/tests/simple_model_wiki/main.cpp index 264f549..eeda706 100644 --- a/tests/simple_model_wiki/main.cpp +++ b/tests/simple_model_wiki/main.cpp @@ -1,7 +1,7 @@ #include #include #include -#include "rmi.h" +#include "../common/rmi_learned_index.h" int main() { // load the data @@ -19,7 +19,8 @@ int main() { std::cout << "Data loaded." 
<< std::endl; - std::cout << "RMI status: " << rmi::load("rmi_data") << std::endl; + RMILearnedIndex learned_index; + std::cout << "RMI status: " << learned_index.Load("rmi_data") << std::endl; size_t err; @@ -29,7 +30,7 @@ int main() { std::distance(data.begin(), std::lower_bound(data.begin(), data.end(), lookup)); - uint64_t rmi_guess = rmi::lookup(lookup, &err); + uint64_t rmi_guess = learned_index.Lookup(lookup, &err); uint64_t diff = (rmi_guess > true_index ? rmi_guess - true_index : true_index - rmi_guess); if (diff > err) { @@ -41,6 +42,6 @@ int main() { } } - rmi::cleanup(); + learned_index.Cleanup(); exit(0); } diff --git a/tests/verify_rmi b/tests/verify_rmi new file mode 100755 index 0000000..fde25bd Binary files /dev/null and b/tests/verify_rmi differ