diff --git a/bench/codegen.rs b/bench/codegen.rs index c9578333..c7ba73a7 100644 --- a/bench/codegen.rs +++ b/bench/codegen.rs @@ -16,6 +16,7 @@ fn bench(c: &mut Criterion) { CodegenSettings { is_async: false, derive_ser: true, + is_recursive: false, }, ) .unwrap() @@ -30,6 +31,7 @@ fn bench(c: &mut Criterion) { CodegenSettings { is_async: true, derive_ser: true, + is_recursive: false, }, ) .unwrap() diff --git a/cornucopia/src/cli.rs b/cornucopia/src/cli.rs index e83efda3..2135780f 100644 --- a/cornucopia/src/cli.rs +++ b/cornucopia/src/cli.rs @@ -23,6 +23,9 @@ struct Args { /// Derive serde's `Serialize` trait for generated types. #[clap(long)] serialize: bool, + /// Recursive lookup + #[clap(long)] + recursive: bool, } #[derive(Debug, Subcommand)] @@ -48,6 +51,7 @@ pub fn run() -> Result<(), Error> { action, sync, serialize, + recursive, } = Args::parse(); match action { @@ -60,6 +64,7 @@ pub fn run() -> Result<(), Error> { CodegenSettings { is_async: !sync, derive_ser: serialize, + is_recursive: recursive, }, )?; } @@ -73,6 +78,7 @@ pub fn run() -> Result<(), Error> { CodegenSettings { is_async: !sync, derive_ser: serialize, + is_recursive: recursive, }, ) { container::cleanup(podman).ok(); diff --git a/cornucopia/src/codegen.rs b/cornucopia/src/codegen.rs index 82315c5c..a84c99c6 100644 --- a/cornucopia/src/codegen.rs +++ b/cornucopia/src/codegen.rs @@ -311,6 +311,7 @@ fn gen_row_structs( CodegenSettings { is_async, derive_ser, + .. }: CodegenSettings, ) { let PreparedItem { @@ -614,6 +615,7 @@ fn gen_custom_type( CodegenSettings { derive_ser, is_async, + .. }: CodegenSettings, ) { let PreparedType { diff --git a/cornucopia/src/lib.rs b/cornucopia/src/lib.rs index 6547d04f..6a3ebf5b 100644 --- a/cornucopia/src/lib.rs +++ b/cornucopia/src/lib.rs @@ -20,7 +20,7 @@ use codegen::generate as generate_internal; use error::WriteOutputError; use parser::parse_query_module; use prepare_queries::prepare; -use read_queries::read_query_modules; +use read_queries::{read_query_modules, read_query_modules_recursive}; #[doc(hidden)] pub use cli::run; @@ -33,6 +33,7 @@ pub use load_schema::load_schema; pub struct CodegenSettings { pub is_async: bool, pub derive_ser: bool, + pub is_recursive: bool, } /// Generates Rust queries from PostgreSQL queries located at `queries_path`, @@ -46,10 +47,17 @@ pub fn generate_live( settings: CodegenSettings, ) -> Result { // Read - let modules = read_query_modules(queries_path)? - .into_iter() - .map(parse_query_module) - .collect::>()?; + let modules = if settings.is_recursive { + read_query_modules_recursive(queries_path)? + .into_iter() + .map(parse_query_module) + .collect::>()? + } else { + read_query_modules(queries_path)? + .into_iter() + .map(parse_query_module) + .collect::>()? + }; // Generate let prepared_modules = prepare(client, modules)?; let generated_code = generate_internal(prepared_modules, settings); diff --git a/cornucopia/src/read_queries.rs b/cornucopia/src/read_queries.rs index 75449fa8..c8f6d0ee 100644 --- a/cornucopia/src/read_queries.rs +++ b/cornucopia/src/read_queries.rs @@ -1,3 +1,5 @@ +use std::path::{Path, PathBuf}; + use miette::NamedSource; use self::error::Error; @@ -68,6 +70,76 @@ pub(crate) fn read_query_modules(dir_path: &str) -> Result, Erro Ok(modules_info) } +/// Reads queries in the directory and checks each directory found within given path. +/// Only .sql files are considered. +/// +/// # Error +/// Returns an error if `dir_path` does not point to a valid directory or if a query file cannot be parsed. +pub(crate) fn read_query_modules_recursive(dir_path: &str) -> Result, Error> { + let mut modules_info = Vec::new(); + for entry_result in std::fs::read_dir(dir_path).map_err(|err| Error { + err, + path: String::from(dir_path), + })? { + // Directory entry + let entry = entry_result.map_err(|err| Error { + err, + path: dir_path.to_owned(), + })?; + let path_buf = entry.path(); + + let path_bufs = if path_buf.is_dir() { + find_queries(&path_buf, Vec::::new()) + } else { + vec![path_buf] + }; + + // Check we're dealing with a .sql file + for path_buf in path_bufs { + if path_buf + .extension() + .map(|extension| extension == "sql") + .unwrap_or_default() + { + let module_name = path_buf + .file_stem() + .expect("is a file") + .to_str() + .expect("file name is valid utf8") + .to_string(); + + let file_contents = std::fs::read_to_string(&path_buf).map_err(|err| Error { + err, + path: dir_path.to_owned(), + })?; + + modules_info.push(ModuleInfo { + path: String::from(path_buf.to_string_lossy()), + name: module_name, + content: file_contents, + }); + } + } + } + // Sort module for consistent codegen + modules_info.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(modules_info) +} + +fn find_queries(start: &Path, mut queries: Vec) -> Vec { + for entry in start.read_dir().unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_dir() { + queries = find_queries(&path, queries); + } else { + queries.push(path); + } + } + + queries +} + pub(crate) mod error { use miette::Diagnostic; use thiserror::Error as ThisError; diff --git a/integration/src/main.rs b/integration/src/main.rs index eb42a1b2..ee3e7e1e 100644 --- a/integration/src/main.rs +++ b/integration/src/main.rs @@ -161,6 +161,7 @@ fn run_errors_test( CodegenSettings { is_async: false, derive_ser: false, + is_recursive: false, }, )?; Ok(()) @@ -235,6 +236,7 @@ fn run_codegen_test( CodegenSettings { is_async, derive_ser, + is_recursive: false, }, ) .map_err(Error::report)?; @@ -254,6 +256,7 @@ fn run_codegen_test( CodegenSettings { is_async, derive_ser, + is_recursive: false, }, ) .map_err(Error::report)?;