From 932150b70e8f1eb0059738a8db894789caee9c04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Tue, 13 Jan 2026 14:53:42 +0000 Subject: [PATCH] feat: Json --- examples/json_exa.rs | 312 +++++++ schemas/model-v1.json | 792 ++++++++++++++++++ src/json/codegen/analytical.rs | 11 + src/json/codegen/closures.rs | 571 +++++++++++++ src/json/codegen/mod.rs | 235 ++++++ src/json/codegen/ode.rs | 11 + src/json/codegen/sde.rs | 11 + src/json/errors.rs | 157 ++++ src/json/library/mod.rs | 517 ++++++++++++ src/json/library/models/pk_1cmt_iv.json | 17 + src/json/library/models/pk_1cmt_iv_ode.json | 20 + src/json/library/models/pk_1cmt_oral.json | 17 + src/json/library/models/pk_1cmt_oral_ode.json | 27 + src/json/library/models/pk_2cmt_iv.json | 17 + src/json/library/models/pk_2cmt_iv_ode.json | 21 + src/json/library/models/pk_2cmt_oral.json | 17 + src/json/library/models/pk_2cmt_oral_ode.json | 28 + src/json/library/models/pk_3cmt_iv.json | 17 + src/json/library/models/pk_3cmt_oral.json | 17 + src/json/mod.rs | 219 +++++ src/json/model.rs | 414 +++++++++ src/json/types.rs | 499 +++++++++++ src/json/validation.rs | 451 ++++++++++ src/lib.rs | 1 + tests/test_json.rs | 788 +++++++++++++++++ 25 files changed, 5187 insertions(+) create mode 100644 examples/json_exa.rs create mode 100644 schemas/model-v1.json create mode 100644 src/json/codegen/analytical.rs create mode 100644 src/json/codegen/closures.rs create mode 100644 src/json/codegen/mod.rs create mode 100644 src/json/codegen/ode.rs create mode 100644 src/json/codegen/sde.rs create mode 100644 src/json/errors.rs create mode 100644 src/json/library/mod.rs create mode 100644 src/json/library/models/pk_1cmt_iv.json create mode 100644 src/json/library/models/pk_1cmt_iv_ode.json create mode 100644 src/json/library/models/pk_1cmt_oral.json create mode 100644 src/json/library/models/pk_1cmt_oral_ode.json create mode 100644 src/json/library/models/pk_2cmt_iv.json create mode 100644 
src/json/library/models/pk_2cmt_iv_ode.json create mode 100644 src/json/library/models/pk_2cmt_oral.json create mode 100644 src/json/library/models/pk_2cmt_oral_ode.json create mode 100644 src/json/library/models/pk_3cmt_iv.json create mode 100644 src/json/library/models/pk_3cmt_oral.json create mode 100644 src/json/mod.rs create mode 100644 src/json/model.rs create mode 100644 src/json/types.rs create mode 100644 src/json/validation.rs create mode 100644 tests/test_json.rs diff --git a/examples/json_exa.rs b/examples/json_exa.rs new file mode 100644 index 0000000..cc8791a --- /dev/null +++ b/examples/json_exa.rs @@ -0,0 +1,312 @@ +// Run with: cargo run --example json_exa --features exa +// +// This example demonstrates JSON model compilation using the `exa` feature. +// It compares predictions from: +// 1. A statically defined ODE model (Rust code) +// 2. A dynamically compiled ODE model (via exa, raw Rust string) +// 3. A JSON-defined ODE model (via compile_json) +// 4. A JSON-defined Analytical model (via compile_json) + +#[cfg(feature = "exa")] +fn main() { + use pharmsol::prelude::*; + use pharmsol::{exa, json, Analytical, ODE}; + use std::path::PathBuf; + + // Create test subject with infusion and observations + let subject = Subject::builder("1") + .infusion(0.0, 500.0, 0, 0.5) + .observation(0.5, 1.645776, 0) + .observation(1.0, 1.216442, 0) + .observation(2.0, 0.4622729, 0) + .observation(3.0, 0.1697458, 0) + .observation(4.0, 0.06382178, 0) + .observation(6.0, 0.009099384, 0) + .observation(8.0, 0.001017932, 0) + .build(); + + // Parameters: ke (elimination rate constant), V (volume of distribution) + let params = vec![1.2, 50.0]; + + let test_dir = std::env::current_dir().expect("Failed to get current directory"); + + // Shared template path for all compilations (they run sequentially) + let template_path = std::env::temp_dir().join("exa_json_example"); + + // ========================================================================= + // 1. 
Create ODE model directly (static Rust code) + // ========================================================================= + println!("1. Creating static ODE model..."); + let static_ode = equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + println!(" ✓ Static ODE model created\n"); + + // ========================================================================= + // 2. Compile ODE model dynamically using exa (raw Rust string) + // ========================================================================= + println!("2. Compiling ODE model via exa (raw Rust)..."); + let exa_ode_path = test_dir.join("exa_ode_model.pkm"); + + let exa_ode_compiled = exa::build::compile::( + r#" + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _V); + dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, V); + y[0] = x[0] / V; + }, + (1, 1), + ) + "# + .to_string(), + Some(exa_ode_path.clone()), + vec!["ke".to_string(), "V".to_string()], + template_path.clone(), + |_, _| {}, + ) + .expect("Failed to compile ODE model via exa"); + + let exa_ode_path = PathBuf::from(&exa_ode_compiled); + let (_lib_exa_ode, (dynamic_exa_ode, _)) = + unsafe { exa::load::load::(exa_ode_path.clone()) }; + println!(" ✓ Compiled to: {}\n", exa_ode_compiled); + + // ========================================================================= + // 3. Compile ODE model from JSON using compile_json + // ========================================================================= + println!("3. 
Compiling ODE model from JSON..."); + + let json_ode = r#"{ + "schema": "1.0", + "id": "pk_1cmt_iv_ode", + "type": "ode", + "parameters": ["ke", "V"], + "compartments": ["central"], + "diffeq": { + "central": "-ke * central + rateiv[0]" + }, + "output": "central / V", + "display": { + "name": "One-Compartment IV ODE", + "category": "pk" + } + }"#; + + // First, show the generated code + let generated = json::generate_code(json_ode).expect("Failed to generate code from JSON"); + println!(" Generated Rust code:"); + println!(" ─────────────────────────────────────"); + for line in generated.equation_code.lines().take(15) { + println!(" {}", line); + } + println!(" ...\n"); + + let json_ode_path = test_dir.join("json_ode_model.pkm"); + + let json_ode_compiled = json::compile_json::( + json_ode, + Some(json_ode_path.clone()), + template_path.clone(), + |_, _| {}, + ) + .expect("Failed to compile JSON ODE model"); + + let json_ode_path = PathBuf::from(&json_ode_compiled); + let (_lib_json_ode, (dynamic_json_ode, meta_ode)) = + unsafe { exa::load::load::(json_ode_path.clone()) }; + println!( + " ✓ Compiled to: {} (params: {:?})\n", + json_ode_compiled, + meta_ode.get_params() + ); + + // ========================================================================= + // 4. Compile Analytical model from JSON using compile_json + // ========================================================================= + println!("4. 
Compiling Analytical model from JSON..."); + + let json_analytical = r#"{ + "schema": "1.0", + "id": "pk_1cmt_iv_analytical", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V", + "display": { + "name": "One-Compartment IV Analytical", + "category": "pk" + } + }"#; + + let json_analytical_path = test_dir.join("json_analytical_model.pkm"); + + let json_analytical_compiled = json::compile_json::( + json_analytical, + Some(json_analytical_path.clone()), + template_path.clone(), + |_, _| {}, + ) + .expect("Failed to compile JSON Analytical model"); + + let json_analytical_path = PathBuf::from(&json_analytical_compiled); + let (_lib_json_analytical, (dynamic_json_analytical, meta_analytical)) = + unsafe { exa::load::load::(json_analytical_path.clone()) }; + println!( + " ✓ Compiled to: {} (params: {:?})\n", + json_analytical_compiled, + meta_analytical.get_params() + ); + + // ========================================================================= + // 5. 
Compare predictions from all four models + // ========================================================================= + println!("{}", "═".repeat(80)); + println!("Comparing predictions (ke={}, V={})", params[0], params[1]); + println!("{}", "═".repeat(80)); + + let static_preds = static_ode + .estimate_predictions(&subject, ¶ms) + .expect("Static ODE prediction failed"); + let exa_ode_preds = dynamic_exa_ode + .estimate_predictions(&subject, ¶ms) + .expect("Exa ODE prediction failed"); + let json_ode_preds = dynamic_json_ode + .estimate_predictions(&subject, ¶ms) + .expect("JSON ODE prediction failed"); + let json_analytical_preds = dynamic_json_analytical + .estimate_predictions(&subject, ¶ms) + .expect("JSON Analytical prediction failed"); + + let static_flat = static_preds.flat_predictions(); + let exa_ode_flat = exa_ode_preds.flat_predictions(); + let json_ode_flat = json_ode_preds.flat_predictions(); + let json_analytical_flat = json_analytical_preds.flat_predictions(); + + println!( + "\n{:<8} {:>14} {:>14} {:>14} {:>14}", + "Time", "Static ODE", "Exa ODE", "JSON ODE", "JSON Analyt." + ); + println!("{}", "─".repeat(80)); + + let times = [0.5, 1.0, 2.0, 3.0, 4.0, 6.0, 8.0]; + for (i, &time) in times.iter().enumerate() { + println!( + "{:<8.1} {:>14.6} {:>14.6} {:>14.6} {:>14.6}", + time, static_flat[i], exa_ode_flat[i], json_ode_flat[i], json_analytical_flat[i] + ); + } + + // ========================================================================= + // 6. 
Verification + // ========================================================================= + println!("\n{}", "═".repeat(80)); + println!("Verification:"); + println!("{}", "─".repeat(80)); + + // Static ODE vs Exa ODE + let static_vs_exa = static_flat + .iter() + .zip(exa_ode_flat.iter()) + .all(|(a, b)| (a - b).abs() < 1e-10); + println!( + " Static ODE vs Exa ODE: {} (tolerance: 1e-10)", + if static_vs_exa { + "✓ MATCH" + } else { + "✗ MISMATCH" + } + ); + + // Static ODE vs JSON ODE + let static_vs_json_ode = static_flat + .iter() + .zip(json_ode_flat.iter()) + .all(|(a, b)| (a - b).abs() < 1e-10); + println!( + " Static ODE vs JSON ODE: {} (tolerance: 1e-10)", + if static_vs_json_ode { + "✓ MATCH" + } else { + "✗ MISMATCH" + } + ); + + // Static ODE vs JSON Analytical + let static_vs_json_analytical = static_flat + .iter() + .zip(json_analytical_flat.iter()) + .all(|(a, b)| (a - b).abs() < 1e-3); + println!( + " Static ODE vs JSON Analytical: {} (tolerance: 1e-3)", + if static_vs_json_analytical { + "✓ CLOSE" + } else { + "✗ DIFFERS" + } + ); + + // ========================================================================= + // 7. Demonstrate JSON Model Library + // ========================================================================= + println!("\n{}", "═".repeat(80)); + println!("JSON Model Library:"); + println!("{}", "─".repeat(80)); + + let library = json::ModelLibrary::builtin(); + println!(" Available builtin models ({}):", library.list().len()); + for id in library.list() { + let model = library.get(id).unwrap(); + let model_type = match &model.model_type { + json::ModelType::Analytical => "Analytical", + json::ModelType::Ode => "ODE", + json::ModelType::Sde => "SDE", + }; + let name = model + .display + .as_ref() + .and_then(|d| d.name.as_ref()) + .map(|s| s.as_str()) + .unwrap_or("(unnamed)"); + println!(" • {} [{}]: {}", id, model_type, name); + } + + // ========================================================================= + // 8. 
Clean up + // ========================================================================= + println!("\n{}", "═".repeat(80)); + println!("Cleaning up..."); + + std::fs::remove_file(&exa_ode_path).ok(); + std::fs::remove_file(&json_ode_path).ok(); + std::fs::remove_file(&json_analytical_path).ok(); + std::fs::remove_dir_all(&template_path).ok(); + + println!(" ✓ Removed compiled model files"); + println!(" ✓ Removed temporary build directory"); + println!("\nDone!"); +} + +#[cfg(not(feature = "exa"))] +fn main() { + eprintln!("This example requires the 'exa' feature."); + eprintln!("Run with: cargo run --example json_exa --features exa"); + std::process::exit(1); +} diff --git a/schemas/model-v1.json b/schemas/model-v1.json new file mode 100644 index 0000000..cc798bc --- /dev/null +++ b/schemas/model-v1.json @@ -0,0 +1,792 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://pharmsol.rs/schemas/model-v1.json", + "title": "pharmsol Model Definition", + "description": "JSON Schema for pharmacometric model definitions in pharmsol. 
Supports analytical, ODE, and SDE model types.", + "type": "object", + + "$defs": { + "parameterName": { + "type": "string", + "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$", + "description": "Valid parameter name (starts with letter or underscore)" + }, + + "compartmentName": { + "type": "string", + "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$", + "description": "Valid compartment name (starts with letter or underscore)" + }, + + "expression": { + "type": "string", + "minLength": 1, + "description": "A Rust expression (e.g., 'x[0] / V', '-ka * x[0]')" + }, + + "analyticalFunction": { + "type": "string", + "enum": [ + "one_compartment", + "one_compartment_with_absorption", + "two_compartments", + "two_compartments_with_absorption", + "three_compartments", + "three_compartments_with_absorption" + ], + "description": "Built-in analytical solution function name" + }, + + "parameterScale": { + "type": "string", + "enum": ["linear", "log", "logit"], + "default": "log", + "description": "Parameter transformation scale for estimation" + }, + + "parameterDefinition": { + "type": "object", + "properties": { + "name": { + "$ref": "#/$defs/parameterName", + "description": "Parameter symbol/name" + }, + "bounds": { + "type": "array", + "items": { "type": "number" }, + "minItems": 2, + "maxItems": 2, + "description": "Lower and upper bounds [min, max]" + }, + "scale": { + "$ref": "#/$defs/parameterScale" + }, + "units": { + "type": "string", + "description": "Parameter units (e.g., 'L/h', '1/h', 'L')" + }, + "description": { + "type": "string", + "description": "Human-readable description" + }, + "typical": { + "type": "number", + "description": "Typical/initial value" + } + }, + "required": ["name"], + "additionalProperties": false + }, + + "derivedParameter": { + "type": "object", + "properties": { + "symbol": { + "$ref": "#/$defs/parameterName", + "description": "Symbol for the derived parameter" + }, + "expression": { + "$ref": "#/$defs/expression", + "description": "Expression to compute 
the derived parameter" + } + }, + "required": ["symbol", "expression"], + "additionalProperties": false + }, + + "parameterization": { + "type": "object", + "properties": { + "id": { + "type": "string", + "pattern": "^[a-z][a-z0-9_]*$", + "description": "Unique identifier for this parameterization" + }, + "name": { + "type": "string", + "description": "Human-readable name" + }, + "default": { + "type": "boolean", + "default": false, + "description": "Whether this is the default parameterization" + }, + "parameters": { + "type": "array", + "items": { "$ref": "#/$defs/parameterDefinition" }, + "minItems": 1, + "description": "Parameter definitions for this parameterization" + }, + "derived": { + "type": "array", + "items": { "$ref": "#/$defs/derivedParameter" }, + "description": "Parameters derived from the primary parameters" + }, + "nonmem": { + "type": "string", + "description": "NONMEM TRANS equivalent (e.g., 'TRANS1', 'TRANS2')" + } + }, + "required": ["id", "parameters"], + "additionalProperties": false + }, + + "covariateType": { + "type": "string", + "enum": ["continuous", "categorical"], + "default": "continuous" + }, + + "interpolationMethod": { + "type": "string", + "enum": ["linear", "constant", "locf"], + "default": "linear", + "description": "How to interpolate covariate values between time points" + }, + + "covariateDefinition": { + "type": "object", + "properties": { + "id": { + "type": "string", + "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$", + "description": "Covariate identifier (used in code)" + }, + "name": { + "type": "string", + "description": "Human-readable name" + }, + "type": { + "$ref": "#/$defs/covariateType" + }, + "units": { + "type": "string", + "description": "Units for continuous covariates" + }, + "reference": { + "type": "number", + "description": "Reference value for centering (e.g., 70 for weight)" + }, + "interpolation": { + "$ref": "#/$defs/interpolationMethod" + }, + "levels": { + "type": "array", + "items": { "type": "string" }, + 
"description": "Possible values for categorical covariates" + } + }, + "required": ["id"], + "additionalProperties": false + }, + + "covariateEffectType": { + "type": "string", + "enum": [ + "allometric", + "linear", + "exponential", + "proportional", + "categorical", + "custom" + ], + "description": "Type of covariate effect relationship" + }, + + "covariateEffect": { + "type": "object", + "properties": { + "on": { + "$ref": "#/$defs/parameterName", + "description": "Parameter affected by this covariate" + }, + "covariate": { + "type": "string", + "description": "Covariate ID" + }, + "type": { + "$ref": "#/$defs/covariateEffectType" + }, + "exponent": { + "type": "number", + "description": "Exponent for allometric scaling (e.g., 0.75 for CL)" + }, + "slope": { + "type": "number", + "description": "Slope for linear/exponential effects" + }, + "reference": { + "type": "number", + "description": "Reference value for centering" + }, + "expression": { + "$ref": "#/$defs/expression", + "description": "Custom expression for type='custom'" + }, + "levels": { + "type": "object", + "additionalProperties": { "type": "number" }, + "description": "Multipliers for each categorical level" + } + }, + "required": ["on", "type"], + "allOf": [ + { + "if": { "properties": { "type": { "const": "allometric" } } }, + "then": { "required": ["covariate", "exponent"] } + }, + { + "if": { "properties": { "type": { "const": "linear" } } }, + "then": { "required": ["covariate", "slope"] } + }, + { + "if": { "properties": { "type": { "const": "custom" } } }, + "then": { "required": ["expression"] } + }, + { + "if": { "properties": { "type": { "const": "categorical" } } }, + "then": { "required": ["covariate", "levels"] } + } + ], + "additionalProperties": false + }, + + "errorModelType": { + "type": "string", + "enum": ["additive", "proportional", "combined", "polynomial"], + "description": "Type of residual error model" + }, + + "errorModel": { + "type": "object", + "properties": { + "type": 
{ + "$ref": "#/$defs/errorModelType" + }, + "additive": { + "type": "number", + "minimum": 0, + "description": "Additive error standard deviation" + }, + "proportional": { + "type": "number", + "minimum": 0, + "description": "Proportional error coefficient (CV)" + }, + "cv": { + "type": "number", + "minimum": 0, + "description": "Coefficient of variation (alias for proportional)" + }, + "sd": { + "type": "number", + "minimum": 0, + "description": "Standard deviation (alias for additive)" + }, + "coefficients": { + "type": "array", + "items": { "type": "number" }, + "minItems": 4, + "maxItems": 4, + "description": "Polynomial coefficients [c0, c1, c2, c3]" + }, + "lambda": { + "type": "number", + "default": 0, + "description": "Lambda parameter for polynomial error" + } + }, + "required": ["type"], + "additionalProperties": false + }, + + "outputDefinition": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Output identifier" + }, + "equation": { + "$ref": "#/$defs/expression", + "description": "Output equation expression" + }, + "name": { + "type": "string", + "description": "Human-readable name" + }, + "units": { + "type": "string", + "description": "Output units" + } + }, + "required": ["equation"], + "additionalProperties": false + }, + + "diffeqObject": { + "type": "object", + "additionalProperties": { + "$ref": "#/$defs/expression" + }, + "description": "Map of compartment name to differential equation expression" + }, + + "lagObject": { + "type": "object", + "additionalProperties": { + "oneOf": [{ "$ref": "#/$defs/expression" }, { "type": "number" }] + }, + "description": "Map of compartment index (as string) to lag time expression or value" + }, + + "faObject": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { "$ref": "#/$defs/expression" }, + { "type": "number", "minimum": 0, "maximum": 1 } + ] + }, + "description": "Map of compartment index (as string) to bioavailability expression or value" + }, + + 
"initObject": { + "type": "object", + "additionalProperties": { + "oneOf": [{ "$ref": "#/$defs/expression" }, { "type": "number" }] + }, + "description": "Map of compartment name or index to initial condition" + }, + + "diffusionObject": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { "$ref": "#/$defs/expression" }, + { "type": "number", "minimum": 0 } + ] + }, + "description": "Map of state name to diffusion coefficient" + }, + + "position": { + "type": "object", + "properties": { + "x": { "type": "number" }, + "y": { "type": "number" } + }, + "required": ["x", "y"], + "additionalProperties": false + }, + + "layoutObject": { + "type": "object", + "additionalProperties": { + "$ref": "#/$defs/position" + }, + "description": "Map of compartment/element name to position" + }, + + "complexity": { + "type": "string", + "enum": ["basic", "intermediate", "advanced"], + "description": "Model complexity level" + }, + + "category": { + "type": "string", + "enum": ["pk", "pd", "pkpd", "disease", "other"], + "description": "Model category" + }, + + "displayInfo": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Human-readable model name" + }, + "shortName": { + "type": "string", + "description": "Abbreviated name" + }, + "category": { + "$ref": "#/$defs/category" + }, + "subcategory": { + "type": "string", + "description": "Model subcategory" + }, + "complexity": { + "$ref": "#/$defs/complexity" + }, + "icon": { + "type": "string", + "description": "Icon identifier" + }, + "tags": { + "type": "array", + "items": { "type": "string" }, + "description": "Searchable tags" + } + }, + "additionalProperties": false + }, + + "reference": { + "type": "object", + "properties": { + "authors": { "type": "string" }, + "title": { "type": "string" }, + "journal": { "type": "string" }, + "year": { "type": "integer" }, + "doi": { "type": "string" }, + "pmid": { "type": "string" } + }, + "additionalProperties": false + }, + + 
"documentation": { + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "One-line summary" + }, + "description": { + "type": "string", + "description": "Detailed description" + }, + "equations": { + "type": "object", + "properties": { + "differential": { "type": "string" }, + "solution": { "type": "string" } + }, + "description": "LaTeX equations for display" + }, + "assumptions": { + "type": "array", + "items": { "type": "string" }, + "description": "Model assumptions" + }, + "whenToUse": { + "type": "array", + "items": { "type": "string" }, + "description": "When to use this model" + }, + "whenNotToUse": { + "type": "array", + "items": { "type": "string" }, + "description": "When NOT to use this model" + }, + "references": { + "type": "array", + "items": { "$ref": "#/$defs/reference" }, + "description": "Literature references" + } + }, + "additionalProperties": false + } + }, + + "properties": { + "schema": { + "type": "string", + "pattern": "^[0-9]+\\.[0-9]+$", + "description": "Schema version (e.g., '1.0')" + }, + "id": { + "type": "string", + "pattern": "^[a-z][a-z0-9_]*$", + "description": "Unique model identifier (snake_case)" + }, + "type": { + "type": "string", + "enum": ["analytical", "ode", "sde"], + "description": "Model equation type" + }, + "extends": { + "type": "string", + "description": "Library model ID to inherit from" + }, + "version": { + "type": "string", + "pattern": "^[0-9]+\\.[0-9]+\\.[0-9]+", + "description": "Model version (semver)" + }, + "aliases": { + "type": "array", + "items": { "type": "string" }, + "description": "Alternative names (e.g., NONMEM ADVAN codes)" + }, + + "parameters": { + "type": "array", + "items": { "$ref": "#/$defs/parameterName" }, + "minItems": 1, + "uniqueItems": true, + "description": "Parameter names in fetch order" + }, + "compartments": { + "type": "array", + "items": { "$ref": "#/$defs/compartmentName" }, + "uniqueItems": true, + "description": "Compartment names 
(indexed in order)" + }, + "states": { + "type": "array", + "items": { "type": "string" }, + "uniqueItems": true, + "description": "State variable names (for SDE)" + }, + + "analytical": { + "$ref": "#/$defs/analyticalFunction", + "description": "Built-in analytical solution function" + }, + "diffeq": { + "oneOf": [ + { "$ref": "#/$defs/expression" }, + { "$ref": "#/$defs/diffeqObject" } + ], + "description": "Differential equations (string or object)" + }, + "drift": { + "oneOf": [ + { "$ref": "#/$defs/expression" }, + { "$ref": "#/$defs/diffeqObject" } + ], + "description": "SDE drift term (deterministic part)" + }, + "diffusion": { + "$ref": "#/$defs/diffusionObject", + "description": "SDE diffusion coefficients" + }, + "secondary": { + "$ref": "#/$defs/expression", + "description": "Secondary equations (for analytical)" + }, + + "output": { + "$ref": "#/$defs/expression", + "description": "Single output equation" + }, + "outputs": { + "type": "array", + "items": { "$ref": "#/$defs/outputDefinition" }, + "minItems": 1, + "description": "Multiple output definitions" + }, + + "init": { + "oneOf": [ + { "$ref": "#/$defs/expression" }, + { "$ref": "#/$defs/initObject" } + ], + "description": "Initial conditions" + }, + "lag": { + "$ref": "#/$defs/lagObject", + "description": "Lag times per input compartment" + }, + "fa": { + "$ref": "#/$defs/faObject", + "description": "Bioavailability per input compartment" + }, + "neqs": { + "type": "array", + "items": { "type": "integer", "minimum": 1 }, + "minItems": 2, + "maxItems": 2, + "description": "[num_states, num_outputs]" + }, + "particles": { + "type": "integer", + "minimum": 100, + "default": 1000, + "description": "Number of particles for SDE simulation" + }, + + "parameterization": { + "oneOf": [{ "type": "string" }, { "$ref": "#/$defs/parameterization" }], + "description": "Active parameterization (ID or inline definition)" + }, + "parameterizations": { + "type": "array", + "items": { "$ref": 
"#/$defs/parameterization" }, + "description": "Available parameterization variants" + }, + + "features": { + "type": "array", + "items": { + "type": "string", + "enum": ["lag_time", "bioavailability", "initial_conditions"] + }, + "description": "Enabled optional features" + }, + "covariates": { + "type": "array", + "items": { "$ref": "#/$defs/covariateDefinition" }, + "description": "Covariate definitions" + }, + "covariateEffects": { + "type": "array", + "items": { "$ref": "#/$defs/covariateEffect" }, + "description": "Covariate effect specifications" + }, + "errorModel": { + "$ref": "#/$defs/errorModel", + "description": "Residual error model" + }, + "errorModels": { + "type": "object", + "additionalProperties": { "$ref": "#/$defs/errorModel" }, + "description": "Error models per output (keyed by output ID)" + }, + + "display": { + "$ref": "#/$defs/displayInfo", + "description": "UI display information" + }, + "layout": { + "$ref": "#/$defs/layoutObject", + "description": "Visual diagram layout" + }, + "documentation": { + "$ref": "#/$defs/documentation", + "description": "Rich documentation" + } + }, + + "required": ["schema", "id", "type"], + + "allOf": [ + { + "if": { + "properties": { "type": { "const": "analytical" } }, + "required": ["type"] + }, + "then": { + "required": ["analytical"], + "properties": { + "diffeq": false, + "drift": false, + "diffusion": false, + "particles": false + } + } + }, + { + "if": { + "properties": { "type": { "const": "ode" } }, + "required": ["type"] + }, + "then": { + "required": ["diffeq"], + "properties": { + "analytical": false, + "drift": false, + "diffusion": false, + "particles": false + } + } + }, + { + "if": { + "properties": { "type": { "const": "sde" } }, + "required": ["type"] + }, + "then": { + "required": ["drift", "diffusion"], + "properties": { + "analytical": false, + "diffeq": false + } + } + }, + { + "if": { + "not": { "required": ["extends"] } + }, + "then": { + "anyOf": [{ "required": ["output"] }, { 
"required": ["outputs"] }] + } + }, + { + "if": { + "not": { "required": ["extends"] } + }, + "then": { + "required": ["parameters"] + } + } + ], + + "additionalProperties": false, + + "examples": [ + { + "schema": "1.0", + "id": "pk_1cmt_iv", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V" + }, + { + "schema": "1.0", + "id": "pk_1cmt_oral", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V"], + "output": "x[1] / V" + }, + { + "schema": "1.0", + "id": "pk_1cmt_oral_lag", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V", "tlag"], + "lag": { "0": "tlag" }, + "output": "x[1] / V", + "neqs": [2, 1] + }, + { + "schema": "1.0", + "id": "pk_2cmt_ode", + "type": "ode", + "compartments": ["depot", "central", "peripheral"], + "parameters": ["ka", "ke", "k12", "k21", "V"], + "diffeq": { + "depot": "-ka * x[0]", + "central": "ka * x[0] - ke * x[1] - k12 * x[1] + k21 * x[2] + rateiv[1]", + "peripheral": "k12 * x[1] - k21 * x[2]" + }, + "output": "x[1] / V", + "neqs": [3, 1] + }, + { + "schema": "1.0", + "id": "pk_1cmt_sde", + "type": "sde", + "parameters": ["ke0", "sigma_ke", "V"], + "states": ["amount", "ke"], + "drift": { + "amount": "-ke * x[0]", + "ke": "-0.5 * (ke - ke0)" + }, + "diffusion": { + "ke": "sigma_ke" + }, + "init": { + "ke": "ke0" + }, + "output": "x[0] / V", + "neqs": [2, 1], + "particles": 1000 + } + ] +} diff --git a/src/json/codegen/analytical.rs b/src/json/codegen/analytical.rs new file mode 100644 index 0000000..d6c48a1 --- /dev/null +++ b/src/json/codegen/analytical.rs @@ -0,0 +1,11 @@ +//! Analytical model code generation +//! +//! This module contains specialized code generation logic for analytical models. +//! Most of the heavy lifting is done by the ClosureGenerator in closures.rs. + +// Currently, all analytical-specific generation is handled in mod.rs +// and closures.rs. 
This module is reserved for future specialized logic +// such as: +// - Analytical function parameter validation +// - Secondary equation optimization +// - Symbolic differentiation for sensitivity analysis diff --git a/src/json/codegen/closures.rs b/src/json/codegen/closures.rs new file mode 100644 index 0000000..e7724b1 --- /dev/null +++ b/src/json/codegen/closures.rs @@ -0,0 +1,571 @@ +//! Closure generation for model equations +//! +//! This module generates the closure functions that are passed to +//! equation constructors (Analytical, ODE, SDE). + +use std::collections::HashMap; + +use crate::json::errors::JsonModelError; +use crate::json::model::JsonModel; +use crate::json::types::*; + +/// Generator for closure functions +pub struct ClosureGenerator<'a> { + model: &'a JsonModel, + compartment_map: HashMap, + state_map: HashMap, +} + +impl<'a> ClosureGenerator<'a> { + /// Create a new closure generator + pub fn new(model: &'a JsonModel) -> Self { + Self { + model, + compartment_map: model.compartment_map(), + state_map: model.state_map(), + } + } + + /// Generate the fetch_params! 
macro call + fn fetch_params(&self) -> String { + let params = self.model.get_parameters(); + if params.is_empty() { + return String::new(); + } + format!("fetch_params!(p, {});", params.join(", ")) + } + + /// Generate compartment bindings (e.g., let central = x[0];) + fn generate_compartment_bindings(&self) -> String { + if self.compartment_map.is_empty() { + return String::new(); + } + + let mut bindings: Vec<_> = self + .compartment_map + .iter() + .map(|(name, &idx)| format!("let {} = x[{}];", name, idx)) + .collect(); + bindings.sort(); // Consistent ordering + bindings.join("\n ") + } + + /// Generate state bindings for SDE (e.g., let state0 = x[0];) + fn generate_state_bindings(&self) -> String { + if self.state_map.is_empty() { + return String::new(); + } + + let mut bindings: Vec<_> = self + .state_map + .iter() + .map(|(name, &idx)| format!("let {} = x[{}];", name, idx)) + .collect(); + bindings.sort(); // Consistent ordering + bindings.join("\n ") + } + + /// Generate fetch_cov! 
macro call for covariates used in covariate effects + fn fetch_covariates(&self) -> String { + // Collect all covariate names used in effects + let Some(effects) = &self.model.covariate_effects else { + return String::new(); + }; + + let cov_names: Vec<_> = effects + .iter() + .filter_map(|e| e.covariate.as_ref()) + .map(|c| c.as_str()) + .collect::>() + .into_iter() + .collect(); + + if cov_names.is_empty() { + return String::new(); + } + + // Generate code to fetch each covariate + let fetch_lines: Vec<_> = cov_names + .iter() + .map(|name| { + format!( + "let {} = cov.get_covariate(\"{}\", t).unwrap_or(0.0);", + name, name + ) + }) + .collect(); + + fetch_lines.join("\n ") + } + + /// Generate covariate effect code to inject before equations + fn generate_covariate_effects(&self) -> String { + let Some(effects) = &self.model.covariate_effects else { + return String::new(); + }; + + if effects.is_empty() { + return String::new(); + } + + // First, fetch all covariates used + let fetch_cov = self.fetch_covariates(); + + let mut lines = Vec::new(); + + for effect in effects { + let param = &effect.on; + let code = match effect.effect_type { + CovariateEffectType::Allometric => { + let cov = effect.covariate.as_ref().unwrap(); + let exp = effect.exponent.unwrap_or(0.75); + let reference = effect.reference.unwrap_or(70.0); + format!( + "let {param} = {param} * ({cov} / {:.1}).powf({:.4});", + reference, exp + ) + } + CovariateEffectType::Linear => { + let cov = effect.covariate.as_ref().unwrap(); + let slope = effect.slope.unwrap_or(0.0); + let reference = effect.reference.unwrap_or(0.0); + format!( + "let {param} = {param} * (1.0 + {:.6} * ({cov} - {:.6}));", + slope, reference + ) + } + CovariateEffectType::Exponential => { + let cov = effect.covariate.as_ref().unwrap(); + let slope = effect.slope.unwrap_or(0.0); + let reference = effect.reference.unwrap_or(0.0); + format!( + "let {param} = {param} * ({:.6} * ({cov} - {:.6})).exp();", + slope, reference + ) + } + 
CovariateEffectType::Proportional => {
                    // NOTE(review): `unwrap()` assumes upstream validation
                    // guaranteed a covariate name for this effect type —
                    // TODO confirm against src/json/validation.rs.
                    let cov = effect.covariate.as_ref().unwrap();
                    // slope defaults to 0.0, which makes the effect a no-op.
                    let slope = effect.slope.unwrap_or(0.0);
                    format!("let {param} = {param} * (1.0 + {:.6} * {cov});", slope)
                }
                CovariateEffectType::Custom => {
                    // NOTE(review): the user-supplied expression is spliced
                    // verbatim into generated Rust source; only trusted model
                    // definitions should reach this point.
                    let expr = effect.expression.as_ref().unwrap();
                    format!("let {param} = {expr};")
                }
                CovariateEffectType::Categorical => {
                    // Categorical effects require match statement
                    // NOTE(review): `fetch_covariates` binds covariates as f64
                    // (`get_covariate(...).unwrap_or(0.0)`), but the arms
                    // generated here match against string literals — the
                    // emitted `match` will not compile as-is. Verify how
                    // categorical covariates are represented before relying
                    // on this path.
                    let cov = effect.covariate.as_ref().unwrap();
                    if let Some(levels) = &effect.levels {
                        let arms: Vec<_> = levels
                            .iter()
                            .map(|(k, v)| format!("\"{}\" => {:.6}", k, v))
                            .collect();
                        format!(
                            "let {param} = {param} * match {cov} {{ {}, _ => 1.0 }};",
                            arms.join(", ")
                        )
                    } else {
                        // No levels declared: emit nothing (filtered out below).
                        String::new()
                    }
                }
            };
            if !code.is_empty() {
                lines.push(code);
            }
        }

        // Prepend fetch code
        if !fetch_cov.is_empty() {
            return format!("{}\n {}", fetch_cov, lines.join("\n "));
        }

        lines.join("\n ")
    }

    /// Generate derived parameters code
    ///
    /// Emits one `let <symbol> = <expression>;` line per model-level derived
    /// parameter, in declaration order; empty string when none are declared.
    fn generate_derived_params(&self) -> String {
        // Use model-level derived parameters
        if let Some(derived) = &self.model.derived {
            let lines: Vec<_> = derived
                .iter()
                .map(|d| format!("let {} = {};", d.symbol, d.expression))
                .collect();
            return lines.join("\n ");
        }
        String::new()
    }

    // ═══════════════════════════════════════════════════════════════════════════
    // Closure Generators
    // ═══════════════════════════════════════════════════════════════════════════

    /// Generate the output closure
    /// Signature: fn(&V, &V, T, &Covariates, &mut V)
    ///
    /// Uses the single `output` expression when present; otherwise numbers
    /// the entries of `outputs` as `y[i] = ...;` lines. Errors with
    /// `MissingOutput` when neither field exists.
    pub fn generate_output(&self) -> Result {
        let output_expr = if let Some(output) = &self.model.output {
            output.clone()
        } else if let Some(outputs) = &self.model.outputs {
            // Multiple outputs
            outputs
                .iter()
                .enumerate()
                .map(|(i, o)| format!("y[{}] = {};", i, o.equation))
                .collect::>()
                .join("\n ")
        } else {
            return Err(JsonModelError::MissingOutput);
        };

        let fetch_params = self.fetch_params();
        let derived =
self.generate_derived_params(); + let cov_effects = self.generate_covariate_effects(); + + // Determine if we have a single expression or multiple statements + let body = if output_expr.contains("y[") { + // Already has y[] assignments + output_expr + } else { + // Single expression, wrap it + format!("y[0] = {};", output_expr) + }; + + let compartments = self.generate_compartment_bindings(); + + Ok(format!( + r#"|x, p, _t, _cov, y| {{ + {fetch_params} + {compartments} + {derived} + {cov_effects} + {body} + }}"# + )) + } + + /// Generate the differential equation closure + /// Signature: fn(&V, &V, T, &mut V, &V, &V, &Covariates) + pub fn generate_diffeq(&self) -> Result { + let diffeq = self + .model + .diffeq + .as_ref() + .ok_or_else(|| JsonModelError::missing_field("diffeq", "ode"))?; + + let body = match diffeq { + DiffEqSpec::String(s) => s.clone(), + DiffEqSpec::Object(map) => { + // Convert named compartments to dx[n] format + let mut lines = Vec::new(); + for (name, expr) in map { + let idx = self.compartment_map.get(name).copied().unwrap_or_else(|| { + // Try parsing as number + name.parse::().unwrap_or(0) + }); + lines.push(format!("dx[{}] = {};", idx, expr)); + } + lines.join("\n ") + } + }; + + let fetch_params = self.fetch_params(); + let compartments = self.generate_compartment_bindings(); + let derived = self.generate_derived_params(); + let cov_effects = self.generate_covariate_effects(); + + Ok(format!( + r#"|x, p, _t, dx, _b, rateiv, _cov| {{ + {fetch_params} + {compartments} + {derived} + {cov_effects} + {body} + }}"# + )) + } + + /// Generate the drift closure for SDE + /// Signature: fn(&V, &V, T, &mut V, V, &Covariates) + pub fn generate_drift(&self) -> Result { + let drift = self + .model + .drift + .as_ref() + .ok_or_else(|| JsonModelError::missing_field("drift", "sde"))?; + + let body = match drift { + DiffEqSpec::String(s) => s.clone(), + DiffEqSpec::Object(map) => { + let mut lines = Vec::new(); + for (name, expr) in map { + let idx = 
self.state_map.get(name).copied().unwrap_or_else(|| { + self.compartment_map + .get(name) + .copied() + .unwrap_or_else(|| name.parse::().unwrap_or(0)) + }); + lines.push(format!("dx[{}] = {};", idx, expr)); + } + lines.join("\n ") + } + }; + + let fetch_params = self.fetch_params(); + let states = self.generate_state_bindings(); + let derived = self.generate_derived_params(); + let cov_effects = self.generate_covariate_effects(); + + Ok(format!( + r#"|x, p, _t, dx, rateiv, _cov| {{ + {fetch_params} + {states} + {derived} + {cov_effects} + {body} + }}"# + )) + } + + /// Generate the diffusion closure for SDE + /// Signature: fn(&V, &mut V) + pub fn generate_diffusion(&self) -> Result { + let diffusion = self + .model + .diffusion + .as_ref() + .ok_or_else(|| JsonModelError::missing_field("diffusion", "sde"))?; + + let fetch_params = self.fetch_params(); + let states = self.generate_state_bindings(); + + let mut lines = Vec::new(); + for (name, expr) in diffusion { + let idx = self.state_map.get(name).copied().unwrap_or_else(|| { + self.compartment_map + .get(name) + .copied() + .unwrap_or_else(|| name.parse::().unwrap_or(0)) + }); + lines.push(format!("d[{}] = {};", idx, expr.to_rust_expr())); + } + let body = lines.join("\n "); + + Ok(format!( + r#"|x, p, d| {{ + {fetch_params} + {states} + {body} + }}"# + )) + } + + /// Generate the lag closure + /// Signature: fn(&V, T, &Covariates) -> HashMap + pub fn generate_lag(&self) -> Result { + let Some(lag) = &self.model.lag else { + return Ok("|_p, _t, _cov| lag! {}".to_string()); + }; + + if lag.is_empty() { + return Ok("|_p, _t, _cov| lag! 
{}".to_string()); + } + + let fetch_params = self.fetch_params(); + + let entries: Vec<_> = lag + .iter() + .map(|(name, expr)| { + // Convert compartment name to index + let idx = self + .compartment_map + .get(name) + .copied() + .unwrap_or_else(|| name.parse::().unwrap_or(0)); + format!("{} => {}", idx, expr.to_rust_expr()) + }) + .collect(); + + Ok(format!( + r#"|p, _t, _cov| {{ + {fetch_params} + lag! {{ {} }} + }}"#, + entries.join(", ") + )) + } + + /// Generate the fa (bioavailability) closure + /// Signature: fn(&V, T, &Covariates) -> HashMap + pub fn generate_fa(&self) -> Result { + let Some(fa) = &self.model.fa else { + return Ok("|_p, _t, _cov| fa! {}".to_string()); + }; + + if fa.is_empty() { + return Ok("|_p, _t, _cov| fa! {}".to_string()); + } + + let fetch_params = self.fetch_params(); + + let entries: Vec<_> = fa + .iter() + .map(|(name, expr)| { + // Convert compartment name to index + let idx = self + .compartment_map + .get(name) + .copied() + .unwrap_or_else(|| name.parse::().unwrap_or(0)); + format!("{} => {}", idx, expr.to_rust_expr()) + }) + .collect(); + + Ok(format!( + r#"|p, _t, _cov| {{ + {fetch_params} + fa! 
{{ {} }} + }}"#, + entries.join(", ") + )) + } + + /// Generate the init closure + /// Signature: fn(&V, T, &Covariates, &mut V) + pub fn generate_init(&self) -> Result { + let Some(init) = &self.model.init else { + return Ok("|_p, _t, _cov, _x| {}".to_string()); + }; + + let body = match init { + InitSpec::String(s) => s.clone(), + InitSpec::Object(map) => { + let mut lines = Vec::new(); + for (name, expr) in map { + let idx = self.state_map.get(name).copied().unwrap_or_else(|| { + self.compartment_map + .get(name) + .copied() + .unwrap_or_else(|| name.parse::().unwrap_or(0)) + }); + lines.push(format!("x[{}] = {};", idx, expr.to_rust_expr())); + } + lines.join("\n ") + } + }; + + let fetch_params = self.fetch_params(); + + Ok(format!( + r#"|p, _t, _cov, x| {{ + {fetch_params} + {body} + }}"# + )) + } + + /// Generate the secondary equation closure (for analytical) + /// Signature: fn(&mut V, T, &Covariates) + pub fn generate_secondary(&self) -> Result { + let Some(secondary) = &self.model.secondary else { + return Ok("|_p, _t, _cov| {}".to_string()); + }; + + let fetch_params = self.fetch_params(); + let cov_effects = self.generate_covariate_effects(); + + Ok(format!( + r#"|p, _t, _cov| {{ + {fetch_params} + {cov_effects} + {secondary} + }}"# + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_output() { + let json = r#"{ + "schema": "1.0", + "id": "test", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let gen = ClosureGenerator::new(&model); + let output = gen.generate_output().unwrap(); + + assert!(output.contains("fetch_params!(p, ke, V)")); + assert!(output.contains("y[0] = x[0] / V")); + } + + #[test] + fn test_generate_lag() { + let json = r#"{ + "schema": "1.0", + "id": "test", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V", "tlag"], + 
"lag": { "0": "tlag" }, + "output": "x[1] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let gen = ClosureGenerator::new(&model); + let lag = gen.generate_lag().unwrap(); + + assert!(lag.contains("lag!")); + assert!(lag.contains("0 => tlag")); + } + + #[test] + fn test_generate_diffeq_object() { + let json = r#"{ + "schema": "1.0", + "id": "test", + "type": "ode", + "compartments": ["depot", "central"], + "parameters": ["ka", "ke", "V"], + "diffeq": { + "depot": "-ka * x[0]", + "central": "ka * x[0] - ke * x[1] + rateiv[1]" + }, + "output": "x[1] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let gen = ClosureGenerator::new(&model); + let diffeq = gen.generate_diffeq().unwrap(); + + assert!(diffeq.contains("dx[0] = -ka * x[0]")); + assert!(diffeq.contains("dx[1] = ka * x[0] - ke * x[1] + rateiv[1]")); + } + + #[test] + fn test_generate_empty_lag_fa() { + let json = r#"{ + "schema": "1.0", + "id": "test", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let gen = ClosureGenerator::new(&model); + + let lag = gen.generate_lag().unwrap(); + let fa = gen.generate_fa().unwrap(); + + assert!(lag.contains("lag! {}")); + assert!(fa.contains("fa! {}")); + } +} diff --git a/src/json/codegen/mod.rs b/src/json/codegen/mod.rs new file mode 100644 index 0000000..a37ced0 --- /dev/null +++ b/src/json/codegen/mod.rs @@ -0,0 +1,235 @@ +//! Code generation from JSON models to Rust code +//! +//! This module transforms validated JSON models into Rust code strings +//! that can be compiled by the `exa` module. 
+ +mod analytical; +mod closures; +mod ode; +mod sde; + +use crate::json::errors::JsonModelError; +use crate::json::model::JsonModel; +use crate::json::types::*; +use crate::simulator::equation::EqnKind; + +pub use closures::ClosureGenerator; + +/// Generated Rust code ready for compilation +#[derive(Debug, Clone)] +pub struct GeneratedCode { + /// The complete equation constructor code + pub equation_code: String, + + /// Parameter names in fetch order + pub parameters: Vec, + + /// The equation kind (ODE, Analytical, SDE) + pub kind: EqnKind, +} + +/// Code generator for JSON models +pub struct CodeGenerator<'a> { + model: &'a JsonModel, + closure_gen: ClosureGenerator<'a>, +} + +impl<'a> CodeGenerator<'a> { + /// Create a new code generator for a model + pub fn new(model: &'a JsonModel) -> Self { + Self { + model, + closure_gen: ClosureGenerator::new(model), + } + } + + /// Generate the complete Rust code + pub fn generate(&self) -> Result { + let (equation_code, kind) = match self.model.model_type { + ModelType::Analytical => { + let code = self.generate_analytical()?; + (code, EqnKind::Analytical) + } + ModelType::Ode => { + let code = self.generate_ode()?; + (code, EqnKind::ODE) + } + ModelType::Sde => { + let code = self.generate_sde()?; + (code, EqnKind::SDE) + } + }; + + Ok(GeneratedCode { + equation_code, + parameters: self.model.get_parameters(), + kind, + }) + } + + /// Generate analytical model code + fn generate_analytical(&self) -> Result { + let func = self + .model + .analytical + .as_ref() + .ok_or_else(|| JsonModelError::missing_field("analytical", "analytical"))?; + + let seq_eq = self.closure_gen.generate_secondary()?; + let lag = self.closure_gen.generate_lag()?; + let fa = self.closure_gen.generate_fa()?; + let init = self.closure_gen.generate_init()?; + let out = self.closure_gen.generate_output()?; + let neqs = self.model.get_neqs(); + + Ok(format!( + r#"equation::Analytical::new( + {func_name}, + {seq_eq}, + {lag}, + {fa}, + {init}, + 
{out}, + ({nstates}, {nouts}), +)"#, + func_name = func.rust_name(), + seq_eq = seq_eq, + lag = lag, + fa = fa, + init = init, + out = out, + nstates = neqs.0, + nouts = neqs.1, + )) + } + + /// Generate ODE model code + fn generate_ode(&self) -> Result { + let diffeq = self.closure_gen.generate_diffeq()?; + let lag = self.closure_gen.generate_lag()?; + let fa = self.closure_gen.generate_fa()?; + let init = self.closure_gen.generate_init()?; + let out = self.closure_gen.generate_output()?; + let neqs = self.model.get_neqs(); + + Ok(format!( + r#"equation::ODE::new( + {diffeq}, + {lag}, + {fa}, + {init}, + {out}, + ({nstates}, {nouts}), +)"#, + diffeq = diffeq, + lag = lag, + fa = fa, + init = init, + out = out, + nstates = neqs.0, + nouts = neqs.1, + )) + } + + /// Generate SDE model code + fn generate_sde(&self) -> Result { + let drift = self.closure_gen.generate_drift()?; + let diffusion = self.closure_gen.generate_diffusion()?; + let lag = self.closure_gen.generate_lag()?; + let fa = self.closure_gen.generate_fa()?; + let init = self.closure_gen.generate_init()?; + let out = self.closure_gen.generate_output()?; + let neqs = self.model.get_neqs(); + let particles = self.model.particles.unwrap_or(1000); + + Ok(format!( + r#"equation::SDE::new( + {drift}, + {diffusion}, + {lag}, + {fa}, + {init}, + {out}, + ({nstates}, {nouts}), + {particles}, +)"#, + drift = drift, + diffusion = diffusion, + lag = lag, + fa = fa, + init = init, + out = out, + nstates = neqs.0, + nouts = neqs.1, + particles = particles, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_analytical() { + let json = r#"{ + "schema": "1.0", + "id": "pk_1cmt_oral", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V"], + "output": "x[1] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let generator = CodeGenerator::new(&model); + let result = generator.generate().unwrap(); + + assert!(result + 
.equation_code + .contains("one_compartment_with_absorption")); + assert!(result.equation_code.contains("equation::Analytical::new")); + assert_eq!(result.parameters, vec!["ka", "ke", "V"]); + } + + #[test] + fn test_generate_ode() { + let json = r#"{ + "schema": "1.0", + "id": "pk_1cmt_ode", + "type": "ode", + "parameters": ["ke", "V"], + "diffeq": "dx[0] = -ke * x[0] + rateiv[0];", + "output": "x[0] / V", + "neqs": [1, 1] + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let generator = CodeGenerator::new(&model); + let result = generator.generate().unwrap(); + + assert!(result.equation_code.contains("equation::ODE::new")); + assert!(result.equation_code.contains("-ke * x[0]")); + } + + #[test] + fn test_generate_with_lag() { + let json = r#"{ + "schema": "1.0", + "id": "pk_1cmt_oral_lag", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V", "tlag"], + "lag": { "0": "tlag" }, + "output": "x[1] / V", + "neqs": [2, 1] + }"#; + + let model = JsonModel::from_str(json).unwrap(); + let generator = CodeGenerator::new(&model); + let result = generator.generate().unwrap(); + + assert!(result.equation_code.contains("lag!")); + assert!(result.equation_code.contains("0 => tlag")); + } +} diff --git a/src/json/codegen/ode.rs b/src/json/codegen/ode.rs new file mode 100644 index 0000000..b410b43 --- /dev/null +++ b/src/json/codegen/ode.rs @@ -0,0 +1,11 @@ +//! ODE model code generation +//! +//! This module contains specialized code generation logic for ODE models. +//! Most of the heavy lifting is done by the ClosureGenerator in closures.rs. + +// Currently, all ODE-specific generation is handled in mod.rs +// and closures.rs. 
This module is reserved for future specialized logic +// such as: +// - Automatic Jacobian generation +// - Stiffness detection +// - Compartment flow analysis diff --git a/src/json/codegen/sde.rs b/src/json/codegen/sde.rs new file mode 100644 index 0000000..cd9253d --- /dev/null +++ b/src/json/codegen/sde.rs @@ -0,0 +1,11 @@ +//! SDE model code generation +//! +//! This module contains specialized code generation logic for SDE models. +//! Most of the heavy lifting is done by the ClosureGenerator in closures.rs. + +// Currently, all SDE-specific generation is handled in mod.rs +// and closures.rs. This module is reserved for future specialized logic +// such as: +// - Diffusion coefficient validation +// - Particle count optimization +// - Noise process analysis diff --git a/src/json/errors.rs b/src/json/errors.rs new file mode 100644 index 0000000..b4bb2c3 --- /dev/null +++ b/src/json/errors.rs @@ -0,0 +1,157 @@ +//! Error types for JSON model parsing and code generation + +use thiserror::Error; + +/// Errors that can occur when working with JSON models +#[derive(Debug, Error)] +pub enum JsonModelError { + // ───────────────────────────────────────────────────────────────────────── + // Parsing Errors + // ───────────────────────────────────────────────────────────────────────── + /// Failed to parse JSON + #[error("Failed to parse JSON: {0}")] + ParseError(#[from] serde_json::Error), + + /// Unsupported schema version + #[error("Unsupported schema version '{version}'. 
Supported versions: {supported}")] + UnsupportedSchema { version: String, supported: String }, + + // ───────────────────────────────────────────────────────────────────────── + // Structural Errors + // ───────────────────────────────────────────────────────────────────────── + /// Missing required field for model type + #[error("Missing required field '{field}' for {model_type} models")] + MissingField { field: String, model_type: String }, + + /// Invalid field for model type + #[error("Field '{field}' is not valid for {model_type} models")] + InvalidFieldForType { field: String, model_type: String }, + + /// Missing output equation + #[error("Model must have either 'output' or 'outputs' field")] + MissingOutput, + + /// Missing parameters + #[error("Model must have 'parameters' field (unless using 'extends')")] + MissingParameters, + + // ───────────────────────────────────────────────────────────────────────── + // Semantic Errors + // ───────────────────────────────────────────────────────────────────────── + /// Undefined parameter used in expression + #[error("Undefined parameter '{name}' used in {context}")] + UndefinedParameter { name: String, context: String }, + + /// Undefined compartment + #[error("Undefined compartment '{name}'")] + UndefinedCompartment { name: String }, + + /// Undefined covariate + #[error("Undefined covariate '{name}' referenced in covariate effect")] + UndefinedCovariate { name: String }, + + /// Parameter order mismatch for analytical function + #[error( + "Parameter order warning for '{function}': expected parameters in order {expected:?}, \ + but got {actual:?}. This may cause incorrect model behavior." 
+ )] + ParameterOrderWarning { + function: String, + expected: Vec, + actual: Vec, + }, + + /// Duplicate parameter name + #[error("Duplicate parameter name: '{name}'")] + DuplicateParameter { name: String }, + + /// Duplicate compartment name + #[error("Duplicate compartment name: '{name}'")] + DuplicateCompartment { name: String }, + + /// Invalid neqs specification + #[error("Invalid neqs: expected [num_states, num_outputs], got {0:?}")] + InvalidNeqs(Vec), + + // ───────────────────────────────────────────────────────────────────────── + // Expression Errors + // ───────────────────────────────────────────────────────────────────────── + /// Invalid expression syntax + #[error("Invalid expression in {context}: {message}")] + InvalidExpression { context: String, message: String }, + + /// Empty expression + #[error("Empty expression in {context}")] + EmptyExpression { context: String }, + + // ───────────────────────────────────────────────────────────────────────── + // Library Errors + // ───────────────────────────────────────────────────────────────────────── + /// Model not found in library + #[error("Model '{0}' not found in library")] + ModelNotFound(String), + + /// Circular inheritance detected + #[error("Circular inheritance detected: {0}")] + CircularInheritance(String), + + /// General library error (file I/O, etc.) 
+ #[error("Library error: {0}")] + LibraryError(String), + + // ───────────────────────────────────────────────────────────────────────── + // Code Generation Errors + // ───────────────────────────────────────────────────────────────────────── + /// Code generation failed + #[error("Code generation failed: {0}")] + CodeGenError(String), + + /// Compilation failed + #[error("Compilation failed: {0}")] + CompilationError(String), + + // ───────────────────────────────────────────────────────────────────────── + // Covariate Effect Errors + // ───────────────────────────────────────────────────────────────────────── + /// Missing required field for covariate effect type + #[error("Covariate effect type '{effect_type}' requires field '{field}'")] + MissingCovariateEffectField { effect_type: String, field: String }, + + /// Invalid covariate effect target + #[error("Covariate effect targets unknown parameter '{parameter}'")] + InvalidCovariateEffectTarget { parameter: String }, +} + +impl JsonModelError { + /// Create a missing field error + pub fn missing_field(field: impl Into, model_type: impl Into) -> Self { + Self::MissingField { + field: field.into(), + model_type: model_type.into(), + } + } + + /// Create an invalid field error + pub fn invalid_field(field: impl Into, model_type: impl Into) -> Self { + Self::InvalidFieldForType { + field: field.into(), + model_type: model_type.into(), + } + } + + /// Create an undefined parameter error + pub fn undefined_param(name: impl Into, context: impl Into) -> Self { + Self::UndefinedParameter { + name: name.into(), + context: context.into(), + } + } + + /// Create an invalid expression error + pub fn invalid_expr(context: impl Into, message: impl Into) -> Self { + Self::InvalidExpression { + context: context.into(), + message: message.into(), + } + } +} diff --git a/src/json/library/mod.rs b/src/json/library/mod.rs new file mode 100644 index 0000000..06cebc3 --- /dev/null +++ b/src/json/library/mod.rs @@ -0,0 +1,517 @@ 
+//! Model Library +//! +//! Provides a registry of built-in pharmacometric models that can be: +//! - Used directly via their ID +//! - Extended via the `extends` field for customization +//! +//! # Example +//! +//! ```rust,ignore +//! use pharmsol::json::library::ModelLibrary; +//! +//! let library = ModelLibrary::builtin(); +//! +//! // List available models +//! for id in library.list() { +//! println!("Available: {}", id); +//! } +//! +//! // Get a model +//! if let Some(model) = library.get("pk/1cmt-iv") { +//! println!("Found model: {}", model.id); +//! } +//! ``` + +use crate::json::errors::JsonModelError; +use crate::json::model::JsonModel; +use crate::json::types::{DisplayInfo, Documentation, ModelType}; +use std::collections::HashMap; +use std::path::Path; + +/// A registry of JSON model definitions +#[derive(Debug, Clone)] +pub struct ModelLibrary { + models: HashMap, +} + +// Embed built-in models at compile time +mod embedded { + // PK Analytical Models + pub const PK_1CMT_IV: &str = include_str!("models/pk_1cmt_iv.json"); + pub const PK_1CMT_ORAL: &str = include_str!("models/pk_1cmt_oral.json"); + pub const PK_2CMT_IV: &str = include_str!("models/pk_2cmt_iv.json"); + pub const PK_2CMT_ORAL: &str = include_str!("models/pk_2cmt_oral.json"); + pub const PK_3CMT_IV: &str = include_str!("models/pk_3cmt_iv.json"); + pub const PK_3CMT_ORAL: &str = include_str!("models/pk_3cmt_oral.json"); + + // PK ODE Models + pub const PK_1CMT_IV_ODE: &str = include_str!("models/pk_1cmt_iv_ode.json"); + pub const PK_1CMT_ORAL_ODE: &str = include_str!("models/pk_1cmt_oral_ode.json"); + pub const PK_2CMT_IV_ODE: &str = include_str!("models/pk_2cmt_iv_ode.json"); + pub const PK_2CMT_ORAL_ODE: &str = include_str!("models/pk_2cmt_oral_ode.json"); +} + +impl ModelLibrary { + /// Create a new empty library + pub fn new() -> Self { + Self { + models: HashMap::new(), + } + } + + /// Create a library with all built-in models + pub fn builtin() -> Self { + let mut library = 
Self::new(); + + // Load embedded models + let embedded_models = [ + embedded::PK_1CMT_IV, + embedded::PK_1CMT_ORAL, + embedded::PK_2CMT_IV, + embedded::PK_2CMT_ORAL, + embedded::PK_3CMT_IV, + embedded::PK_3CMT_ORAL, + embedded::PK_1CMT_IV_ODE, + embedded::PK_1CMT_ORAL_ODE, + embedded::PK_2CMT_IV_ODE, + embedded::PK_2CMT_ORAL_ODE, + ]; + + for json in embedded_models { + if let Ok(model) = JsonModel::from_str(json) { + library.models.insert(model.id.clone(), model); + } + } + + library + } + + /// Load models from a directory (recursively searches for .json files) + pub fn from_dir(path: &Path) -> Result { + let mut library = Self::new(); + library.load_dir(path)?; + Ok(library) + } + + /// Load models from a directory into this library + pub fn load_dir(&mut self, path: &Path) -> Result<(), JsonModelError> { + if !path.exists() { + return Err(JsonModelError::LibraryError(format!( + "Directory not found: {}", + path.display() + ))); + } + + Self::load_dir_recursive(path, &mut self.models)?; + Ok(()) + } + + fn load_dir_recursive( + path: &Path, + models: &mut HashMap, + ) -> Result<(), JsonModelError> { + let entries = std::fs::read_dir(path).map_err(|e| { + JsonModelError::LibraryError(format!("Failed to read directory: {}", e)) + })?; + + for entry in entries { + let entry = entry.map_err(|e| { + JsonModelError::LibraryError(format!("Failed to read entry: {}", e)) + })?; + let file_path = entry.path(); + + if file_path.is_dir() { + Self::load_dir_recursive(&file_path, models)?; + } else if file_path.extension().is_some_and(|ext| ext == "json") { + let content = std::fs::read_to_string(&file_path).map_err(|e| { + JsonModelError::LibraryError(format!( + "Failed to read {}: {}", + file_path.display(), + e + )) + })?; + + match JsonModel::from_str(&content) { + Ok(model) => { + models.insert(model.id.clone(), model); + } + Err(e) => { + // Log warning but continue loading other models + eprintln!("Warning: Failed to parse {}: {}", file_path.display(), e); + } + } + } 
+        }

        Ok(())
    }

    /// Look up a model by its unique ID.
    pub fn get(&self, id: &str) -> Option<&JsonModel> {
        self.models.get(id)
    }

    /// Whether a model with this ID is registered.
    pub fn contains(&self, id: &str) -> bool {
        self.models.contains_key(id)
    }

    /// Register a model, replacing any existing entry with the same ID.
    pub fn add(&mut self, model: JsonModel) {
        let key = model.id.clone();
        self.models.insert(key, model);
    }

    /// Unregister a model, returning it if it was present.
    pub fn remove(&mut self, id: &str) -> Option<JsonModel> {
        self.models.remove(id)
    }

    /// All registered model IDs, sorted alphabetically.
    pub fn list(&self) -> Vec<&str> {
        let mut ids: Vec<&str> = self.models.keys().map(String::as_str).collect();
        ids.sort_unstable(); // map keys are unique, so unstable sort is equivalent
        ids
    }

    /// Number of registered models.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// True when no models are registered.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }

    /// Case-insensitive search over model IDs and display names.
    pub fn search(&self, query: &str) -> Vec<&JsonModel> {
        let needle = query.to_lowercase();
        self.models
            .values()
            .filter(|model| {
                let id_hit = model.id.to_lowercase().contains(&needle);
                let name_hit = model
                    .display
                    .as_ref()
                    .and_then(|display| display.name.as_ref())
                    .is_some_and(|name| name.to_lowercase().contains(&needle));
                id_hit || name_hit
            })
            .collect()
    }

    /// All models of the given kind (analytical / ODE / SDE).
    pub fn filter_by_type(&self, model_type: ModelType) -> Vec<&JsonModel> {
        self.models
            .values()
            .filter(|m| m.model_type == model_type)
            .collect()
    }

    /// All models carrying the given display tag (case-insensitive).
    pub fn filter_by_tag(&self, tag: &str) -> Vec<&JsonModel> {
        let wanted = tag.to_lowercase();
        self.models
            .values()
            .filter(|model| {
                model
                    .display
                    .as_ref()
                    .and_then(|display| display.tags.as_ref())
                    .is_some_and(|tags| tags.iter().any(|t| t.to_lowercase() == wanted))
            })
            .collect()
    }

    /// Resolve a model's inheritance chain, returning a fully resolved model
+ /// + /// This processes the `extends` field to merge base model properties + /// with the derived model's overrides. + pub fn resolve(&self, model: &JsonModel) -> Result { + self.resolve_with_chain(model, &mut Vec::new()) + } + + fn resolve_with_chain( + &self, + model: &JsonModel, + chain: &mut Vec, + ) -> Result { + // Check for circular inheritance + if chain.contains(&model.id) { + return Err(JsonModelError::CircularInheritance(format!( + "{} -> {}", + chain.join(" -> "), + model.id + ))); + } + + // If no base, return model as-is + let Some(ref base_id) = model.extends else { + return Ok(model.clone()); + }; + + // Track inheritance chain + chain.push(model.id.clone()); + + // Get base model + let base = self + .get(base_id) + .ok_or_else(|| JsonModelError::ModelNotFound(base_id.clone()))?; + + // Recursively resolve base + let resolved_base = self.resolve_with_chain(base, chain)?; + + // Merge: derived model overrides base + Ok(merge_models(&resolved_base, model)) + } +} + +impl Default for ModelLibrary { + fn default() -> Self { + Self::new() + } +} + +/// Merge two models, with derived overriding base +fn merge_models(base: &JsonModel, derived: &JsonModel) -> JsonModel { + JsonModel { + // ───────────────────────────────────────────────────────────────────── + // Layer 1: Identity (derived always owns these) + // ───────────────────────────────────────────────────────────────────── + schema: derived.schema.clone(), + id: derived.id.clone(), + model_type: derived.model_type, + extends: None, // Clear extends after resolution + version: derived.version.clone().or_else(|| base.version.clone()), + aliases: merge_option_vec(&base.aliases, &derived.aliases), + + // ───────────────────────────────────────────────────────────────────── + // Layer 2: Structural Model + // ───────────────────────────────────────────────────────────────────── + parameters: derived + .parameters + .clone() + .or_else(|| base.parameters.clone()), + compartments: derived + 
.compartments + .clone() + .or_else(|| base.compartments.clone()), + states: derived.states.clone().or_else(|| base.states.clone()), + + // ───────────────────────────────────────────────────────────────────── + // Equation Fields + // ───────────────────────────────────────────────────────────────────── + analytical: derived.analytical.or(base.analytical), + diffeq: derived.diffeq.clone().or_else(|| base.diffeq.clone()), + drift: derived.drift.clone().or_else(|| base.drift.clone()), + diffusion: derived.diffusion.clone().or_else(|| base.diffusion.clone()), + secondary: derived.secondary.clone().or_else(|| base.secondary.clone()), + + // ───────────────────────────────────────────────────────────────────── + // Output + // ───────────────────────────────────────────────────────────────────── + output: derived.output.clone().or_else(|| base.output.clone()), + outputs: derived.outputs.clone().or_else(|| base.outputs.clone()), + + // ───────────────────────────────────────────────────────────────────── + // Optional Features + // ───────────────────────────────────────────────────────────────────── + init: derived.init.clone().or_else(|| base.init.clone()), + lag: derived.lag.clone().or_else(|| base.lag.clone()), + fa: derived.fa.clone().or_else(|| base.fa.clone()), + neqs: derived.neqs.or(base.neqs), + particles: derived.particles.or(base.particles), + + // ───────────────────────────────────────────────────────────────────── + // Layer 3: Model Extensions + // ───────────────────────────────────────────────────────────────────── + derived: merge_option_vec(&base.derived, &derived.derived), + features: merge_option_vec(&base.features, &derived.features), + covariates: merge_option_vec(&base.covariates, &derived.covariates), + covariate_effects: merge_option_vec(&base.covariate_effects, &derived.covariate_effects), + + // ───────────────────────────────────────────────────────────────────── + // Layer 4: UI Metadata + // 
───────────────────────────────────────────────────────────────────── + display: merge_display(&base.display, &derived.display), + layout: merge_option_hashmap(&base.layout, &derived.layout), + documentation: merge_documentation(&base.documentation, &derived.documentation), + } +} + +/// Merge optional vectors (append derived items) +fn merge_option_vec(base: &Option>, derived: &Option>) -> Option> { + match (base, derived) { + (None, None) => None, + (Some(b), None) => Some(b.clone()), + (None, Some(d)) => Some(d.clone()), + (Some(b), Some(d)) => { + let mut merged = b.clone(); + merged.extend(d.iter().cloned()); + Some(merged) + } + } +} + +/// Merge optional HashMaps (derived overrides base keys) +fn merge_option_hashmap( + base: &Option>, + derived: &Option>, +) -> Option> { + match (base, derived) { + (None, None) => None, + (Some(b), None) => Some(b.clone()), + (None, Some(d)) => Some(d.clone()), + (Some(b), Some(d)) => { + let mut merged = b.clone(); + merged.extend(d.iter().map(|(k, v)| (k.clone(), v.clone()))); + Some(merged) + } + } +} + +/// Merge display info (derived overrides base) +fn merge_display(base: &Option, derived: &Option) -> Option { + match (base, derived) { + (None, None) => None, + (Some(b), None) => Some(b.clone()), + (None, Some(d)) => Some(d.clone()), + (Some(b), Some(d)) => Some(DisplayInfo { + name: d.name.clone().or_else(|| b.name.clone()), + short_name: d.short_name.clone().or_else(|| b.short_name.clone()), + category: d.category.or(b.category), + subcategory: d.subcategory.clone().or_else(|| b.subcategory.clone()), + complexity: d.complexity.or(b.complexity), + icon: d.icon.clone().or_else(|| b.icon.clone()), + tags: merge_option_vec(&b.tags, &d.tags), + }), + } +} + +/// Merge documentation (derived overrides base) +fn merge_documentation( + base: &Option, + derived: &Option, +) -> Option { + match (base, derived) { + (None, None) => None, + (Some(b), None) => Some(b.clone()), + (None, Some(d)) => Some(d.clone()), + (Some(b), 
Some(d)) => Some(Documentation { + summary: d.summary.clone().or_else(|| b.summary.clone()), + description: d.description.clone().or_else(|| b.description.clone()), + equations: d.equations.clone().or_else(|| b.equations.clone()), + assumptions: merge_option_vec(&b.assumptions, &d.assumptions), + when_to_use: merge_option_vec(&b.when_to_use, &d.when_to_use), + when_not_to_use: merge_option_vec(&b.when_not_to_use, &d.when_not_to_use), + references: merge_option_vec(&b.references, &d.references), + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_builtin_library() { + let library = ModelLibrary::builtin(); + assert!(!library.is_empty()); + + // Should have analytical models + let analytical = library.filter_by_type(ModelType::Analytical); + assert!(!analytical.is_empty()); + } + + #[test] + fn test_search() { + let library = ModelLibrary::builtin(); + + // Search by ID + let results = library.search("1cmt"); + assert!(!results.is_empty()); + } + + #[test] + fn test_resolve_simple() { + let mut library = ModelLibrary::new(); + + // Add a base model + let base = JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "base-model", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V" + }"#, + ) + .unwrap(); + library.add(base); + + // Add a derived model + let derived = JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "derived-model", + "extends": "base-model", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V", "extra"] + }"#, + ) + .unwrap(); + + // Resolve should merge + let resolved = library.resolve(&derived).unwrap(); + assert_eq!(resolved.parameters.as_ref().unwrap().len(), 3); + assert!(resolved.output.is_some()); // Inherited from base + } + + #[test] + fn test_circular_inheritance() { + let mut library = ModelLibrary::new(); + + let model_a = JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "model-a", + "extends": "model-b", + 
"type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"] + }"#, + ) + .unwrap(); + + let model_b = JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "model-b", + "extends": "model-a", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"] + }"#, + ) + .unwrap(); + + library.add(model_a.clone()); + library.add(model_b); + + // Should detect circular inheritance + let result = library.resolve(&model_a); + assert!(matches!( + result, + Err(JsonModelError::CircularInheritance(_)) + )); + } +} diff --git a/src/json/library/models/pk_1cmt_iv.json b/src/json/library/models/pk_1cmt_iv.json new file mode 100644 index 0000000..6b80469 --- /dev/null +++ b/src/json/library/models/pk_1cmt_iv.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/1cmt-iv", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V", + "neqs": [1, 1], + "display": { + "name": "One-Compartment IV Bolus", + "category": "pk", + "tags": ["1-compartment", "iv", "linear"] + }, + "documentation": { + "summary": "Single compartment model with intravenous bolus administration and first-order elimination" + } +} diff --git a/src/json/library/models/pk_1cmt_iv_ode.json b/src/json/library/models/pk_1cmt_iv_ode.json new file mode 100644 index 0000000..af5103a --- /dev/null +++ b/src/json/library/models/pk_1cmt_iv_ode.json @@ -0,0 +1,20 @@ +{ + "schema": "1.0", + "id": "pk/1cmt-iv-ode", + "type": "ode", + "parameters": ["CL", "V"], + "compartments": ["central"], + "diffeq": { + "central": "-CL/V * central" + }, + "output": "central / V", + "neqs": [1, 1], + "display": { + "name": "One-Compartment IV Bolus (ODE)", + "category": "pk", + "tags": ["1-compartment", "iv", "ode", "clearance"] + }, + "documentation": { + "summary": "One-compartment ODE model using clearance (CL) and volume (V) parameterization" + } +} diff --git a/src/json/library/models/pk_1cmt_oral.json 
b/src/json/library/models/pk_1cmt_oral.json new file mode 100644 index 0000000..814f121 --- /dev/null +++ b/src/json/library/models/pk_1cmt_oral.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/1cmt-oral", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V"], + "output": "x[1] / V", + "neqs": [2, 1], + "display": { + "name": "One-Compartment First-Order Absorption", + "category": "pk", + "tags": ["1-compartment", "oral", "linear", "first-order-absorption"] + }, + "documentation": { + "summary": "Single compartment model with first-order oral absorption and first-order elimination" + } +} diff --git a/src/json/library/models/pk_1cmt_oral_ode.json b/src/json/library/models/pk_1cmt_oral_ode.json new file mode 100644 index 0000000..94e1b59 --- /dev/null +++ b/src/json/library/models/pk_1cmt_oral_ode.json @@ -0,0 +1,27 @@ +{ + "schema": "1.0", + "id": "pk/1cmt-oral-ode", + "type": "ode", + "parameters": ["ka", "CL", "V"], + "compartments": ["depot", "central"], + "diffeq": { + "depot": "-ka * depot", + "central": "ka * depot - CL/V * central" + }, + "output": "central / V", + "neqs": [2, 1], + "display": { + "name": "One-Compartment Oral (ODE)", + "category": "pk", + "tags": [ + "1-compartment", + "oral", + "ode", + "clearance", + "first-order-absorption" + ] + }, + "documentation": { + "summary": "One-compartment ODE model for oral dosing with clearance (CL) and volume (V) parameterization" + } +} diff --git a/src/json/library/models/pk_2cmt_iv.json b/src/json/library/models/pk_2cmt_iv.json new file mode 100644 index 0000000..9b312b1 --- /dev/null +++ b/src/json/library/models/pk_2cmt_iv.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/2cmt-iv", + "type": "analytical", + "analytical": "two_compartments", + "parameters": ["ke", "kcp", "kpc", "V"], + "output": "x[0] / V", + "neqs": [2, 1], + "display": { + "name": "Two-Compartment IV Bolus", + "category": "pk", + "tags": ["2-compartment", "iv", 
"linear"] + }, + "documentation": { + "summary": "Two-compartment model with intravenous bolus administration and first-order elimination" + } +} diff --git a/src/json/library/models/pk_2cmt_iv_ode.json b/src/json/library/models/pk_2cmt_iv_ode.json new file mode 100644 index 0000000..2ecc693 --- /dev/null +++ b/src/json/library/models/pk_2cmt_iv_ode.json @@ -0,0 +1,21 @@ +{ + "schema": "1.0", + "id": "pk/2cmt-iv-ode", + "type": "ode", + "parameters": ["CL", "V1", "Q", "V2"], + "compartments": ["central", "peripheral"], + "diffeq": { + "central": "-CL/V1 * central - Q/V1 * central + Q/V2 * peripheral", + "peripheral": "Q/V1 * central - Q/V2 * peripheral" + }, + "output": "central / V1", + "neqs": [2, 1], + "display": { + "name": "Two-Compartment IV Bolus (ODE)", + "category": "pk", + "tags": ["2-compartment", "iv", "ode", "clearance"] + }, + "documentation": { + "summary": "Two-compartment ODE model using clearance and inter-compartmental clearance parameterization" + } +} diff --git a/src/json/library/models/pk_2cmt_oral.json b/src/json/library/models/pk_2cmt_oral.json new file mode 100644 index 0000000..fb96c24 --- /dev/null +++ b/src/json/library/models/pk_2cmt_oral.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/2cmt-oral", + "type": "analytical", + "analytical": "two_compartments_with_absorption", + "parameters": ["ke", "ka", "kcp", "kpc", "V"], + "output": "x[1] / V", + "neqs": [3, 1], + "display": { + "name": "Two-Compartment First-Order Absorption", + "category": "pk", + "tags": ["2-compartment", "oral", "linear", "first-order-absorption"] + }, + "documentation": { + "summary": "Two-compartment model with first-order oral absorption and first-order elimination" + } +} diff --git a/src/json/library/models/pk_2cmt_oral_ode.json b/src/json/library/models/pk_2cmt_oral_ode.json new file mode 100644 index 0000000..c2f0a0b --- /dev/null +++ b/src/json/library/models/pk_2cmt_oral_ode.json @@ -0,0 +1,28 @@ +{ + "schema": "1.0", + "id": "pk/2cmt-oral-ode", + 
"type": "ode", + "parameters": ["ka", "CL", "V1", "Q", "V2"], + "compartments": ["depot", "central", "peripheral"], + "diffeq": { + "depot": "-ka * depot", + "central": "ka * depot - CL/V1 * central - Q/V1 * central + Q/V2 * peripheral", + "peripheral": "Q/V1 * central - Q/V2 * peripheral" + }, + "output": "central / V1", + "neqs": [3, 1], + "display": { + "name": "Two-Compartment Oral (ODE)", + "category": "pk", + "tags": [ + "2-compartment", + "oral", + "ode", + "clearance", + "first-order-absorption" + ] + }, + "documentation": { + "summary": "Two-compartment ODE model for oral dosing with clearance and inter-compartmental clearance parameterization" + } +} diff --git a/src/json/library/models/pk_3cmt_iv.json b/src/json/library/models/pk_3cmt_iv.json new file mode 100644 index 0000000..ac11517 --- /dev/null +++ b/src/json/library/models/pk_3cmt_iv.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/3cmt-iv", + "type": "analytical", + "analytical": "three_compartments", + "parameters": ["k10", "k12", "k13", "k21", "k31", "V"], + "output": "x[0] / V", + "neqs": [3, 1], + "display": { + "name": "Three-Compartment IV Bolus", + "category": "pk", + "tags": ["3-compartment", "iv", "linear"] + }, + "documentation": { + "summary": "Three-compartment model with intravenous bolus administration and first-order elimination" + } +} diff --git a/src/json/library/models/pk_3cmt_oral.json b/src/json/library/models/pk_3cmt_oral.json new file mode 100644 index 0000000..e2877a1 --- /dev/null +++ b/src/json/library/models/pk_3cmt_oral.json @@ -0,0 +1,17 @@ +{ + "schema": "1.0", + "id": "pk/3cmt-oral", + "type": "analytical", + "analytical": "three_compartments_with_absorption", + "parameters": ["ka", "k10", "k12", "k13", "k21", "k31", "V"], + "output": "x[1] / V", + "neqs": [4, 1], + "display": { + "name": "Three-Compartment First-Order Absorption", + "category": "pk", + "tags": ["3-compartment", "oral", "linear", "first-order-absorption"] + }, + "documentation": { + "summary": 
"Three-compartment model with first-order oral absorption and first-order elimination" + } +} diff --git a/src/json/mod.rs b/src/json/mod.rs new file mode 100644 index 0000000..091d0fb --- /dev/null +++ b/src/json/mod.rs @@ -0,0 +1,219 @@ +//! JSON Model Definition and Code Generation +//! +//! This module provides functionality for defining pharmacometric models using JSON +//! and generating Rust code that can be compiled by the `exa` module. +//! +//! # Overview +//! +//! The JSON model system provides a declarative way to define pharmacometric models +//! without writing Rust code directly. Models are defined in JSON following a +//! structured schema, then validated and compiled to native code. +//! +//! The system supports three equation types: +//! - **Analytical**: Built-in closed-form solutions (fastest execution) +//! - **ODE**: Custom ordinary differential equations +//! - **SDE**: Stochastic differential equations with particle filtering +//! +//! # Quick Start +//! +//! ```ignore +//! use pharmsol::json::{parse_json, validate_json, generate_code}; +//! +//! // Define a model in JSON +//! let json = r#"{ +//! "schema": "1.0", +//! "id": "pk_1cmt_oral", +//! "type": "analytical", +//! "analytical": "one_compartment_with_absorption", +//! "parameters": ["ka", "ke", "V"], +//! "output": "x[1] / V" +//! }"#; +//! +//! // Parse and validate +//! let validated = validate_json(json)?; +//! +//! // Generate Rust code +//! let code = generate_code(json)?; +//! println!("Generated: {}", code.equation_code); +//! ``` +//! +//! # Using the Model Library +//! +//! The library provides pre-built standard PK models: +//! +//! ```ignore +//! use pharmsol::json::ModelLibrary; +//! +//! let library = ModelLibrary::builtin(); +//! +//! // List available models +//! for id in library.list() { +//! println!("Available: {}", id); +//! } +//! +//! // Get a specific model +//! let model = library.get("pk/1cmt-oral").unwrap(); +//! +//! // Search by keyword +//! 
let oral_models = library.search("oral"); +//! +//! // Filter by type +//! let ode_models = library.filter_by_type(ModelType::Ode); +//! ``` +//! +//! # Model Inheritance +//! +//! Models can extend base models to add customizations: +//! +//! ```ignore +//! use pharmsol::json::{JsonModel, ModelLibrary}; +//! +//! let mut library = ModelLibrary::builtin(); +//! +//! // Define a model that extends a library model +//! let derived = JsonModel::from_str(r#"{ +//! "schema": "1.0", +//! "id": "pk_1cmt_wt", +//! "extends": "pk/1cmt-oral", +//! "type": "analytical", +//! "analytical": "one_compartment_with_absorption", +//! "parameters": ["ka", "ke", "V"], +//! "covariates": [{ "id": "WT", "reference": 70.0 }], +//! "covariateEffects": [{ +//! "on": "V", +//! "covariate": "WT", +//! "type": "allometric", +//! "exponent": 1.0, +//! "reference": 70.0 +//! }] +//! }"#)?; +//! +//! // Resolve inherits base model's output expression +//! let resolved = library.resolve(&derived)?; +//! ``` +//! +//! # JSON Schema +//! +//! ## Required Fields +//! +//! | Field | Description | +//! |-------|-------------| +//! | `schema` | Schema version (currently `"1.0"`) | +//! | `id` | Unique model identifier | +//! | `type` | Equation type: `"analytical"`, `"ode"`, or `"sde"` | +//! +//! ## Model Type Specific Fields +//! +//! ### Analytical Models +//! - `analytical`: One of the built-in functions (e.g., `"one_compartment_with_absorption"`) +//! - `parameters`: Parameter names in order expected by the analytical function +//! - `output`: Output equation expression +//! +//! ### ODE Models +//! - `compartments`: List of compartment names +//! - `diffeq`: Differential equations (object or string) +//! - `parameters`: Parameter names +//! - `output`: Output equation expression +//! +//! ### SDE Models +//! - `states`: List of state variable names +//! - `drift`: Drift equations (deterministic part) +//! - `diffusion`: Diffusion coefficients +//! 
- `particles`: Number of particles for simulation +//! +//! ## Optional Features +//! +//! - `lag`: Lag times per compartment +//! - `fa`: Bioavailability factors +//! - `init`: Initial conditions +//! - `covariates`: Covariate definitions +//! - `covariateEffects`: Covariate effect specifications +//! - `errorModel`: Residual error model +//! +//! # Available Analytical Functions +//! +//! | Function | Parameters | States | +//! |----------|------------|--------| +//! | `one_compartment` | ke | 1 | +//! | `one_compartment_with_absorption` | ka, ke | 2 | +//! | `two_compartments` | ke, kcp, kpc | 2 | +//! | `two_compartments_with_absorption` | ke, ka, kcp, kpc | 3 | +//! | `three_compartments` | k10, k12, k13, k21, k31 | 3 | +//! | `three_compartments_with_absorption` | ka, k10, k12, k13, k21, k31 | 4 | +//! +//! # Error Handling +//! +//! All functions return `Result` with descriptive errors: +//! +//! ```ignore +//! match validate_json(json) { +//! Ok(model) => println!("Valid model: {}", model.inner().id), +//! Err(JsonModelError::MissingField { field, model_type }) => { +//! eprintln!("Missing {} for {} model", field, model_type); +//! } +//! Err(JsonModelError::UnsupportedSchema { version, .. }) => { +//! eprintln!("Schema {} not supported", version); +//! } +//! Err(e) => eprintln!("Error: {}", e), +//! } +//! 
``` + +mod codegen; +mod errors; +pub mod library; +mod model; +mod types; +mod validation; + +pub use codegen::{CodeGenerator, GeneratedCode}; +pub use errors::JsonModelError; +pub use library::ModelLibrary; +pub use model::JsonModel; +pub use types::*; +pub use validation::{ValidatedModel, Validator}; + +/// Parse a JSON string into a JsonModel +pub fn parse_json(json: &str) -> Result { + JsonModel::from_str(json) +} + +/// Parse and validate a JSON model +pub fn validate_json(json: &str) -> Result { + let model = JsonModel::from_str(json)?; + let validator = Validator::new(); + validator.validate(&model) +} + +/// Parse, validate, and generate code from a JSON model +pub fn generate_code(json: &str) -> Result { + let model = JsonModel::from_str(json)?; + let validator = Validator::new(); + let validated = validator.validate(&model)?; + let generator = CodeGenerator::new(validated.inner()); + generator.generate() +} + +/// Compile a JSON model to a dynamic library +/// +/// This is the high-level API that combines parsing, validation, +/// code generation, and compilation into a single call. +/// +/// Requires the `exa` feature to be enabled. +#[cfg(feature = "exa")] +pub fn compile_json( + json: &str, + output_path: Option, + template_path: std::path::PathBuf, + event_callback: impl Fn(String, String) + Send + Sync + 'static, +) -> Result { + let generated = generate_code(json)?; + + crate::exa::build::compile::( + generated.equation_code, + output_path, + generated.parameters, + template_path, + event_callback, + ) + .map_err(|e| JsonModelError::CompilationError(e.to_string())) +} diff --git a/src/json/model.rs b/src/json/model.rs new file mode 100644 index 0000000..96fb00e --- /dev/null +++ b/src/json/model.rs @@ -0,0 +1,414 @@ +//! 
Main JSON Model struct + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::json::errors::JsonModelError; +use crate::json::types::*; + +/// Supported schema versions +pub const SUPPORTED_SCHEMA_VERSIONS: &[&str] = &["1.0"]; + +/// A pharmacometric model defined in JSON +/// +/// This is the main struct that represents a parsed JSON model file. +/// It supports all three equation types (analytical, ODE, SDE) and +/// includes optional fields for covariates, error models, and UI metadata. +/// +/// # Example +/// +/// ```ignore +/// use pharmsol::json::JsonModel; +/// +/// let json = r#"{ +/// "schema": "1.0", +/// "id": "pk_1cmt_oral", +/// "type": "analytical", +/// "analytical": "one_compartment_with_absorption", +/// "parameters": ["ka", "ke", "V"], +/// "output": "x[1] / V" +/// }"#; +/// +/// let model = JsonModel::from_str(json)?; +/// assert_eq!(model.id, "pk_1cmt_oral"); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct JsonModel { + // ───────────────────────────────────────────────────────────────────────── + // Layer 1: Identity (always required) + // ───────────────────────────────────────────────────────────────────────── + /// Schema version (e.g., "1.0") + pub schema: String, + + /// Unique model identifier (snake_case) + pub id: String, + + /// Model equation type + #[serde(rename = "type")] + pub model_type: ModelType, + + /// Library model ID to inherit from + #[serde(skip_serializing_if = "Option::is_none")] + pub extends: Option, + + /// Model version (semver) + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + + /// Alternative names (e.g., NONMEM ADVAN codes) + #[serde(skip_serializing_if = "Option::is_none")] + pub aliases: Option>, + + // ───────────────────────────────────────────────────────────────────────── + // Layer 2: Structural Model + // ───────────────────────────────────────────────────────────────────────── + 
/// Parameter names in fetch order + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option>, + + /// Compartment names (indexed in declaration order) + #[serde(skip_serializing_if = "Option::is_none")] + pub compartments: Option>, + + /// State variable names (for SDE) + #[serde(skip_serializing_if = "Option::is_none")] + pub states: Option>, + + // ───────────────────────────────────────────────────────────────────────── + // Equation Fields (type-dependent) + // ───────────────────────────────────────────────────────────────────────── + /// Built-in analytical solution function (for analytical type) + #[serde(skip_serializing_if = "Option::is_none")] + pub analytical: Option, + + /// Differential equations (for ODE type) + #[serde(skip_serializing_if = "Option::is_none")] + pub diffeq: Option, + + /// SDE drift term (deterministic part) + #[serde(skip_serializing_if = "Option::is_none")] + pub drift: Option, + + /// SDE diffusion coefficients + #[serde(skip_serializing_if = "Option::is_none")] + pub diffusion: Option>, + + /// Secondary equations (for analytical) + #[serde(skip_serializing_if = "Option::is_none")] + pub secondary: Option, + + // ───────────────────────────────────────────────────────────────────────── + // Output + // ───────────────────────────────────────────────────────────────────────── + /// Single output equation + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, + + /// Multiple output definitions + #[serde(skip_serializing_if = "Option::is_none")] + pub outputs: Option>, + + // ───────────────────────────────────────────────────────────────────────── + // Optional Features + // ───────────────────────────────────────────────────────────────────────── + /// Initial conditions + #[serde(skip_serializing_if = "Option::is_none")] + pub init: Option, + + /// Lag times per input compartment + #[serde(skip_serializing_if = "Option::is_none")] + pub lag: Option>, + + /// Bioavailability per input 
compartment + #[serde(skip_serializing_if = "Option::is_none")] + pub fa: Option>, + + /// [num_states, num_outputs] + #[serde(skip_serializing_if = "Option::is_none")] + pub neqs: Option<(usize, usize)>, + + /// Number of particles for SDE simulation + #[serde(skip_serializing_if = "Option::is_none")] + pub particles: Option, + + // ───────────────────────────────────────────────────────────────────────── + // Layer 3: Model Extensions + // ───────────────────────────────────────────────────────────────────────── + /// Derived parameters (computed from primary parameters) + #[serde(skip_serializing_if = "Option::is_none")] + pub derived: Option>, + + /// Enabled optional features + #[serde(skip_serializing_if = "Option::is_none")] + pub features: Option>, + + /// Covariate definitions + #[serde(skip_serializing_if = "Option::is_none")] + pub covariates: Option>, + + /// Covariate effect specifications + #[serde(rename = "covariateEffects", skip_serializing_if = "Option::is_none")] + pub covariate_effects: Option>, + + // ───────────────────────────────────────────────────────────────────────── + // Layer 4: UI Metadata (ignored by compiler) + // ───────────────────────────────────────────────────────────────────────── + /// UI display information + #[serde(skip_serializing_if = "Option::is_none")] + pub display: Option, + + /// Visual diagram layout + #[serde(skip_serializing_if = "Option::is_none")] + pub layout: Option>, + + /// Rich documentation + #[serde(skip_serializing_if = "Option::is_none")] + pub documentation: Option, +} + +impl JsonModel { + /// Parse a JSON string into a JsonModel + pub fn from_str(json: &str) -> Result { + let model: Self = serde_json::from_str(json)?; + model.check_schema_version()?; + Ok(model) + } + + /// Parse from a JSON Value + pub fn from_value(value: serde_json::Value) -> Result { + let model: Self = serde_json::from_value(value)?; + model.check_schema_version()?; + Ok(model) + } + + /// Serialize to a JSON string + pub fn 
to_json(&self) -> Result { + Ok(serde_json::to_string_pretty(self)?) + } + + /// Check if the schema version is supported + fn check_schema_version(&self) -> Result<(), JsonModelError> { + if !SUPPORTED_SCHEMA_VERSIONS.contains(&self.schema.as_str()) { + return Err(JsonModelError::UnsupportedSchema { + version: self.schema.clone(), + supported: SUPPORTED_SCHEMA_VERSIONS.join(", "), + }); + } + Ok(()) + } + + /// Get the number of states (inferred or explicit) + pub fn num_states(&self) -> usize { + if let Some((nstates, _)) = self.neqs { + return nstates; + } + + match self.model_type { + ModelType::Analytical => { + if let Some(func) = &self.analytical { + func.num_states() + } else { + 1 + } + } + ModelType::Ode => { + if let Some(compartments) = &self.compartments { + compartments.len() + } else if let Some(DiffEqSpec::Object(map)) = &self.diffeq { + map.len() + } else { + // Try to count from dx[n] in the string + 1 + } + } + ModelType::Sde => { + if let Some(states) = &self.states { + states.len() + } else if let Some(DiffEqSpec::Object(map)) = &self.drift { + map.len() + } else { + 1 + } + } + } + } + + /// Get the number of outputs (inferred or explicit) + pub fn num_outputs(&self) -> usize { + if let Some((_, nout)) = self.neqs { + return nout; + } + + if let Some(outputs) = &self.outputs { + outputs.len() + } else if self.output.is_some() { + 1 + } else { + 1 + } + } + + /// Get the neqs tuple + pub fn get_neqs(&self) -> (usize, usize) { + self.neqs.unwrap_or((self.num_states(), self.num_outputs())) + } + + /// Get compartment-to-index mapping + pub fn compartment_map(&self) -> HashMap { + let mut map = HashMap::new(); + if let Some(compartments) = &self.compartments { + for (i, name) in compartments.iter().enumerate() { + map.insert(name.clone(), i); + } + } + map + } + + /// Get state-to-index mapping (for SDE) + pub fn state_map(&self) -> HashMap { + let mut map = HashMap::new(); + if let Some(states) = &self.states { + for (i, name) in 
states.iter().enumerate() { + map.insert(name.clone(), i); + } + } + map + } + + /// Check if the model uses covariates + pub fn has_covariates(&self) -> bool { + self.covariates.is_some() && !self.covariates.as_ref().unwrap().is_empty() + } + + /// Check if the model uses lag times + pub fn has_lag(&self) -> bool { + self.lag.is_some() && !self.lag.as_ref().unwrap().is_empty() + } + + /// Check if the model uses bioavailability + pub fn has_fa(&self) -> bool { + self.fa.is_some() && !self.fa.as_ref().unwrap().is_empty() + } + + /// Check if the model has initial conditions + pub fn has_init(&self) -> bool { + self.init.is_some() + } + + /// Get the parameters as a vector (guaranteed non-empty after validation) + pub fn get_parameters(&self) -> Vec { + self.parameters.clone().unwrap_or_default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_minimal_analytical() { + let json = r#"{ + "schema": "1.0", + "id": "pk_1cmt_iv", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V" + }"#; + + let model = JsonModel::from_str(json).unwrap(); + assert_eq!(model.id, "pk_1cmt_iv"); + assert_eq!(model.model_type, ModelType::Analytical); + assert_eq!(model.analytical, Some(AnalyticalFunction::OneCompartment)); + assert_eq!(model.num_states(), 1); + assert_eq!(model.num_outputs(), 1); + } + + #[test] + fn test_parse_minimal_ode() { + let json = r#"{ + "schema": "1.0", + "id": "pk_2cmt_ode", + "type": "ode", + "compartments": ["depot", "central", "peripheral"], + "parameters": ["ka", "ke", "k12", "k21", "V"], + "diffeq": { + "depot": "-ka * x[0]", + "central": "ka * x[0] - ke * x[1] - k12 * x[1] + k21 * x[2] + rateiv[1]", + "peripheral": "k12 * x[1] - k21 * x[2]" + }, + "output": "x[1] / V", + "neqs": [3, 1] + }"#; + + let model = JsonModel::from_str(json).unwrap(); + assert_eq!(model.id, "pk_2cmt_ode"); + assert_eq!(model.model_type, ModelType::Ode); + assert_eq!(model.num_states(), 3); + 
assert_eq!(model.compartment_map().get("central"), Some(&1)); + } + + #[test] + fn test_parse_sde() { + let json = r#"{ + "schema": "1.0", + "id": "pk_1cmt_sde", + "type": "sde", + "parameters": ["ke0", "sigma_ke", "V"], + "states": ["amount", "ke"], + "drift": { + "amount": "-ke * x[0]", + "ke": "-0.5 * (ke - ke0)" + }, + "diffusion": { + "ke": "sigma_ke" + }, + "init": { + "ke": "ke0" + }, + "output": "x[0] / V", + "neqs": [2, 1], + "particles": 1000 + }"#; + + let model = JsonModel::from_str(json).unwrap(); + assert_eq!(model.model_type, ModelType::Sde); + assert_eq!(model.particles, Some(1000)); + assert_eq!(model.state_map().get("ke"), Some(&1)); + } + + #[test] + fn test_unsupported_schema() { + let json = r#"{ + "schema": "999.0", + "id": "test", + "type": "ode", + "parameters": ["ke"], + "diffeq": "dx[0] = -ke * x[0];", + "output": "x[0]" + }"#; + + let result = JsonModel::from_str(json); + assert!(matches!( + result, + Err(JsonModelError::UnsupportedSchema { .. }) + )); + } + + #[test] + fn test_unknown_field_rejected() { + let json = r#"{ + "schema": "1.0", + "id": "test", + "type": "ode", + "parameters": ["ke"], + "diffeq": "dx[0] = -ke * x[0];", + "output": "x[0]", + "unknown_field": "should fail" + }"#; + + let result = JsonModel::from_str(json); + assert!(result.is_err()); + } +} diff --git a/src/json/types.rs b/src/json/types.rs new file mode 100644 index 0000000..bb5f56a --- /dev/null +++ b/src/json/types.rs @@ -0,0 +1,499 @@ +//! 
Core type definitions for JSON models + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ═══════════════════════════════════════════════════════════════════════════════ +// Model Type +// ═══════════════════════════════════════════════════════════════════════════════ + +/// The type of equation system used by the model +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ModelType { + /// Analytical (closed-form) solution + Analytical, + /// Ordinary differential equations + Ode, + /// Stochastic differential equations + Sde, +} + +impl std::fmt::Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Analytical => write!(f, "analytical"), + Self::Ode => write!(f, "ode"), + Self::Sde => write!(f, "sde"), + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Analytical Functions +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Built-in analytical solution functions +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AnalyticalFunction { + /// One compartment IV (ke) + OneCompartment, + /// One compartment with first-order absorption (ka, ke) + OneCompartmentWithAbsorption, + /// Two compartments IV (ke, kcp, kpc) + TwoCompartments, + /// Two compartments with absorption (ke, ka, kcp, kpc) + TwoCompartmentsWithAbsorption, + /// Three compartments IV (k10, k12, k13, k21, k31) + ThreeCompartments, + /// Three compartments with absorption (ka, k10, k12, k13, k21, k31) + ThreeCompartmentsWithAbsorption, +} + +impl AnalyticalFunction { + /// Get the Rust function name for code generation + pub fn rust_name(&self) -> &'static str { + match self { + Self::OneCompartment => "one_compartment", + Self::OneCompartmentWithAbsorption => "one_compartment_with_absorption", + 
Self::TwoCompartments => "two_compartments", + Self::TwoCompartmentsWithAbsorption => "two_compartments_with_absorption", + Self::ThreeCompartments => "three_compartments", + Self::ThreeCompartmentsWithAbsorption => "three_compartments_with_absorption", + } + } + + /// Get the expected parameter names for this function (in order) + pub fn expected_parameters(&self) -> Vec<&'static str> { + match self { + Self::OneCompartment => vec!["ke"], + Self::OneCompartmentWithAbsorption => vec!["ka", "ke"], + Self::TwoCompartments => vec!["ke", "kcp", "kpc"], + Self::TwoCompartmentsWithAbsorption => vec!["ke", "ka", "kcp", "kpc"], + Self::ThreeCompartments => vec!["k10", "k12", "k13", "k21", "k31"], + Self::ThreeCompartmentsWithAbsorption => { + vec!["ka", "k10", "k12", "k13", "k21", "k31"] + } + } + } + + /// Get the number of states for this function + pub fn num_states(&self) -> usize { + match self { + Self::OneCompartment => 1, + Self::OneCompartmentWithAbsorption => 2, + Self::TwoCompartments => 2, + Self::TwoCompartmentsWithAbsorption => 3, + Self::ThreeCompartments => 3, + Self::ThreeCompartmentsWithAbsorption => 4, + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Expression Types +// ═══════════════════════════════════════════════════════════════════════════════ + +/// A Rust expression string +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Expression(pub String); + +impl Expression { + /// Create a new expression + pub fn new(s: impl Into) -> Self { + Self(s.into()) + } + + /// Get the expression string + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Check if the expression is empty + pub fn is_empty(&self) -> bool { + self.0.trim().is_empty() + } +} + +impl From for Expression { + fn from(s: String) -> Self { + Self(s) + } +} + +impl From<&str> for Expression { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl AsRef for Expression { + fn 
as_ref(&self) -> &str { + &self.0 + } +} + +/// Either an expression or a numeric value +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ExpressionOrNumber { + /// A numeric constant + Number(f64), + /// A Rust expression + Expression(String), +} + +impl ExpressionOrNumber { + /// Convert to a Rust expression string + pub fn to_rust_expr(&self) -> String { + match self { + Self::Number(n) => format!("{:.6}", n), + Self::Expression(s) => s.clone(), + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Differential Equation Specification +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Differential equation specification (string or object format) +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum DiffEqSpec { + /// Single string with all equations + String(String), + /// Map of compartment name to equation + Object(HashMap), +} + +impl DiffEqSpec { + /// Check if empty + pub fn is_empty(&self) -> bool { + match self { + Self::String(s) => s.trim().is_empty(), + Self::Object(m) => m.is_empty(), + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Initial Conditions +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Initial condition specification +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum InitSpec { + /// Single string with all init code + String(String), + /// Map of compartment/state name to initial value + Object(HashMap), +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Output Definition +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Definition of a model output +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct OutputDefinition { + /// Output identifier + 
#[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// Output equation expression + pub equation: String, + + /// Human-readable name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// Output units + #[serde(skip_serializing_if = "Option::is_none")] + pub units: Option, +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Derived Parameters +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Derived parameter definition +/// +/// Derived parameters are computed from primary parameters using expressions. +/// For example, ke = CL / V computes elimination rate constant from +/// clearance and volume. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DerivedParameter { + /// Symbol for the derived parameter + pub symbol: String, + + /// Expression to compute the derived parameter + pub expression: String, +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Covariates +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Covariate type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum CovariateType { + /// Continuous covariate + #[default] + Continuous, + /// Categorical covariate + Categorical, +} + +/// Interpolation method for time-varying covariates +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum InterpolationMethod { + /// Linear interpolation + #[default] + Linear, + /// Constant (use value at time point) + Constant, + /// Last observation carried forward + Locf, +} + +/// Covariate definition +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CovariateDefinition { + /// Covariate identifier + pub id: String, + + /// Human-readable name + #[serde(skip_serializing_if = 
"Option::is_none")] + pub name: Option, + + /// Covariate type + #[serde(rename = "type", default)] + pub cov_type: CovariateType, + + /// Units for continuous covariates + #[serde(skip_serializing_if = "Option::is_none")] + pub units: Option, + + /// Reference value for centering + #[serde(skip_serializing_if = "Option::is_none")] + pub reference: Option, + + /// Interpolation method + #[serde(default)] + pub interpolation: InterpolationMethod, + + /// Possible values for categorical covariates + #[serde(skip_serializing_if = "Option::is_none")] + pub levels: Option>, +} + +/// Covariate effect type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CovariateEffectType { + /// Allometric scaling: P * (cov/ref)^exp + Allometric, + /// Linear effect: P * (1 + slope * (cov - ref)) + Linear, + /// Exponential effect: P * exp(slope * (cov - ref)) + Exponential, + /// Proportional effect: P * (1 + slope * cov) + Proportional, + /// Categorical effect: P * theta_level + Categorical, + /// Custom expression + Custom, +} + +/// Covariate effect specification +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CovariateEffect { + /// Parameter affected by this covariate + pub on: String, + + /// Covariate ID + #[serde(skip_serializing_if = "Option::is_none")] + pub covariate: Option, + + /// Effect type + #[serde(rename = "type")] + pub effect_type: CovariateEffectType, + + /// Exponent for allometric scaling + #[serde(skip_serializing_if = "Option::is_none")] + pub exponent: Option, + + /// Slope for linear/exponential effects + #[serde(skip_serializing_if = "Option::is_none")] + pub slope: Option, + + /// Reference value for centering + #[serde(skip_serializing_if = "Option::is_none")] + pub reference: Option, + + /// Custom expression + #[serde(skip_serializing_if = "Option::is_none")] + pub expression: Option, + + /// Multipliers for categorical levels + #[serde(skip_serializing_if = 
"Option::is_none")] + pub levels: Option>, +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Error Model Type (hint only, values provided by PMcore Settings) +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Error model type (for documentation/hints only) +/// +/// Note: The actual error model parameters (σ values) should be configured +/// in PMcore's Settings struct, not in the JSON model. This enum is kept +/// for documentation purposes and to indicate the intended error structure. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ErrorModelType { + /// Additive error: σ = a + Additive, + /// Proportional error: σ = b × f + Proportional, + /// Combined error: σ = √(a² + b²×f²) + Combined, + /// Polynomial error: σ = c₀ + c₁f + c₂f² + c₃f³ + Polynomial, +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// UI Metadata (ignored by compiler) +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Model complexity level +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Complexity { + Basic, + Intermediate, + Advanced, +} + +/// Model category +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Category { + Pk, + Pd, + Pkpd, + Disease, + Other, +} + +/// Position for layout +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub struct Position { + pub x: f64, + pub y: f64, +} + +/// Display information for UI +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct DisplayInfo { + /// Human-readable model name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// Abbreviated name + #[serde(skip_serializing_if = "Option::is_none")] + pub short_name: 
Option, + + /// Model category + #[serde(skip_serializing_if = "Option::is_none")] + pub category: Option, + + /// Model subcategory + #[serde(skip_serializing_if = "Option::is_none")] + pub subcategory: Option, + + /// Complexity level + #[serde(skip_serializing_if = "Option::is_none")] + pub complexity: Option, + + /// Icon identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub icon: Option, + + /// Searchable tags + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, +} + +/// Literature reference +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Reference { + #[serde(skip_serializing_if = "Option::is_none")] + pub authors: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub journal: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub year: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub doi: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub pmid: Option, +} + +/// LaTeX equations for display +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct EquationDocs { + #[serde(skip_serializing_if = "Option::is_none")] + pub differential: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub solution: Option, +} + +/// Rich documentation +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct Documentation { + /// One-line summary + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + + /// Detailed description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + /// LaTeX equations + #[serde(skip_serializing_if = "Option::is_none")] + pub equations: Option, + + /// Model assumptions + #[serde(skip_serializing_if = "Option::is_none")] + pub assumptions: Option>, + + /// When to use this model + #[serde(skip_serializing_if = "Option::is_none")] + pub when_to_use: 
Option>, + + /// When NOT to use this model + #[serde(skip_serializing_if = "Option::is_none")] + pub when_not_to_use: Option>, + + /// Literature references + #[serde(skip_serializing_if = "Option::is_none")] + pub references: Option>, +} + +/// Optional features that can be enabled +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Feature { + LagTime, + Bioavailability, + InitialConditions, +} diff --git a/src/json/validation.rs b/src/json/validation.rs new file mode 100644 index 0000000..8a966e1 --- /dev/null +++ b/src/json/validation.rs @@ -0,0 +1,451 @@ +//! Validation for JSON models + +use std::collections::HashSet; + +use crate::json::errors::JsonModelError; +use crate::json::model::JsonModel; +use crate::json::types::*; + +/// A validated JSON model +/// +/// This wrapper type guarantees that the contained model has passed +/// all validation checks and is ready for code generation. +#[derive(Debug, Clone)] +pub struct ValidatedModel(JsonModel); + +impl ValidatedModel { + /// Get the inner JsonModel + pub fn inner(&self) -> &JsonModel { + &self.0 + } + + /// Consume the wrapper and return the inner JsonModel + pub fn into_inner(self) -> JsonModel { + self.0 + } +} + +/// Validator for JSON models +pub struct Validator { + /// Whether to treat warnings as errors + strict: bool, +} + +impl Default for Validator { + fn default() -> Self { + Self::new() + } +} + +impl Validator { + /// Create a new validator + pub fn new() -> Self { + Self { strict: false } + } + + /// Create a strict validator that treats warnings as errors + pub fn strict() -> Self { + Self { strict: true } + } + + /// Validate a JSON model + pub fn validate(&self, model: &JsonModel) -> Result { + // 1. Validate type-specific requirements + self.validate_type_requirements(model)?; + + // 2. Validate parameters + self.validate_parameters(model)?; + + // 3. Validate output + self.validate_output(model)?; + + // 4. 
Validate compartments/states + self.validate_compartments(model)?; + + // 5. Validate covariates + self.validate_covariates(model)?; + + // 6. Validate covariate effects + self.validate_covariate_effects(model)?; + + // 7. Validate analytical function parameters + if let Some(func) = &model.analytical { + self.validate_analytical_params(model, func)?; + } + + Ok(ValidatedModel(model.clone())) + } + + /// Validate type-specific field requirements + fn validate_type_requirements(&self, model: &JsonModel) -> Result<(), JsonModelError> { + match model.model_type { + ModelType::Analytical => { + // Must have analytical function + if model.analytical.is_none() { + return Err(JsonModelError::missing_field("analytical", "analytical")); + } + // Must not have ODE/SDE fields + if model.diffeq.is_some() { + return Err(JsonModelError::invalid_field("diffeq", "analytical")); + } + if model.drift.is_some() { + return Err(JsonModelError::invalid_field("drift", "analytical")); + } + if model.diffusion.is_some() { + return Err(JsonModelError::invalid_field("diffusion", "analytical")); + } + } + ModelType::Ode => { + // Must have diffeq + if model.diffeq.is_none() { + return Err(JsonModelError::missing_field("diffeq", "ode")); + } + // Must not have analytical/SDE fields + if model.analytical.is_some() { + return Err(JsonModelError::invalid_field("analytical", "ode")); + } + if model.drift.is_some() { + return Err(JsonModelError::invalid_field("drift", "ode")); + } + if model.diffusion.is_some() { + return Err(JsonModelError::invalid_field("diffusion", "ode")); + } + } + ModelType::Sde => { + // Must have drift and diffusion + if model.drift.is_none() { + return Err(JsonModelError::missing_field("drift", "sde")); + } + if model.diffusion.is_none() { + return Err(JsonModelError::missing_field("diffusion", "sde")); + } + // Must not have analytical/ODE fields + if model.analytical.is_some() { + return Err(JsonModelError::invalid_field("analytical", "sde")); + } + if 
model.diffeq.is_some() { + return Err(JsonModelError::invalid_field("diffeq", "sde")); + } + } + } + Ok(()) + } + + /// Validate parameters + fn validate_parameters(&self, model: &JsonModel) -> Result<(), JsonModelError> { + // Parameters required unless using extends + if model.extends.is_none() && model.parameters.is_none() { + return Err(JsonModelError::MissingParameters); + } + + if let Some(params) = &model.parameters { + // Check for duplicates + let mut seen = HashSet::new(); + for param in params { + if !seen.insert(param.clone()) { + return Err(JsonModelError::DuplicateParameter { + name: param.clone(), + }); + } + } + + // Check for empty parameters + if params.is_empty() && model.extends.is_none() { + return Err(JsonModelError::MissingParameters); + } + } + + Ok(()) + } + + /// Validate output + fn validate_output(&self, model: &JsonModel) -> Result<(), JsonModelError> { + // Output required unless using extends + if model.extends.is_none() && model.output.is_none() && model.outputs.is_none() { + return Err(JsonModelError::MissingOutput); + } + + // Check for empty output + if let Some(output) = &model.output { + if output.trim().is_empty() { + return Err(JsonModelError::EmptyExpression { + context: "output".to_string(), + }); + } + } + + // Check outputs array + if let Some(outputs) = &model.outputs { + for (i, out) in outputs.iter().enumerate() { + if out.equation.trim().is_empty() { + return Err(JsonModelError::EmptyExpression { + context: format!("outputs[{}]", i), + }); + } + } + } + + Ok(()) + } + + /// Validate compartments + fn validate_compartments(&self, model: &JsonModel) -> Result<(), JsonModelError> { + if let Some(compartments) = &model.compartments { + let mut seen = HashSet::new(); + for cmt in compartments { + if !seen.insert(cmt.clone()) { + return Err(JsonModelError::DuplicateCompartment { name: cmt.clone() }); + } + } + } + + if let Some(states) = &model.states { + let mut seen = HashSet::new(); + for state in states { + if 
!seen.insert(state.clone()) { + return Err(JsonModelError::DuplicateCompartment { + name: state.clone(), + }); + } + } + } + + Ok(()) + } + + /// Validate covariate definitions + fn validate_covariates(&self, model: &JsonModel) -> Result<(), JsonModelError> { + if let Some(covariates) = &model.covariates { + let mut seen = HashSet::new(); + for cov in covariates { + if !seen.insert(cov.id.clone()) { + return Err(JsonModelError::UndefinedCovariate { + name: format!("duplicate covariate: {}", cov.id), + }); + } + } + } + Ok(()) + } + + /// Validate covariate effects + fn validate_covariate_effects(&self, model: &JsonModel) -> Result<(), JsonModelError> { + if let Some(effects) = &model.covariate_effects { + let params: HashSet<_> = model + .parameters + .as_ref() + .map(|p| p.iter().cloned().collect()) + .unwrap_or_default(); + + let covariates: HashSet<_> = model + .covariates + .as_ref() + .map(|c| c.iter().map(|cov| cov.id.clone()).collect()) + .unwrap_or_default(); + + for effect in effects { + // Check that target parameter exists + if !params.is_empty() && !params.contains(&effect.on) { + return Err(JsonModelError::InvalidCovariateEffectTarget { + parameter: effect.on.clone(), + }); + } + + // Check type-specific requirements + match effect.effect_type { + CovariateEffectType::Allometric => { + if effect.covariate.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "allometric".to_string(), + field: "covariate".to_string(), + }); + } + if effect.exponent.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "allometric".to_string(), + field: "exponent".to_string(), + }); + } + } + CovariateEffectType::Linear | CovariateEffectType::Exponential => { + if effect.covariate.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: format!("{:?}", effect.effect_type).to_lowercase(), + field: "covariate".to_string(), + }); + } + if effect.slope.is_none() { + return 
Err(JsonModelError::MissingCovariateEffectField { + effect_type: format!("{:?}", effect.effect_type).to_lowercase(), + field: "slope".to_string(), + }); + } + } + CovariateEffectType::Custom => { + if effect.expression.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "custom".to_string(), + field: "expression".to_string(), + }); + } + } + CovariateEffectType::Categorical => { + if effect.covariate.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "categorical".to_string(), + field: "covariate".to_string(), + }); + } + if effect.levels.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "categorical".to_string(), + field: "levels".to_string(), + }); + } + } + CovariateEffectType::Proportional => { + if effect.covariate.is_none() { + return Err(JsonModelError::MissingCovariateEffectField { + effect_type: "proportional".to_string(), + field: "covariate".to_string(), + }); + } + } + } + + // Check that referenced covariate exists + if let Some(cov_name) = &effect.covariate { + if !covariates.is_empty() && !covariates.contains(cov_name) { + return Err(JsonModelError::UndefinedCovariate { + name: cov_name.clone(), + }); + } + } + } + } + Ok(()) + } + + /// Validate analytical function parameters + fn validate_analytical_params( + &self, + model: &JsonModel, + func: &AnalyticalFunction, + ) -> Result<(), JsonModelError> { + let expected = func.expected_parameters(); + let actual = model.get_parameters(); + + // Check if expected parameters are present at the start (in order) + // Extra parameters (like V, tlag) are allowed after + if self.strict && actual.len() >= expected.len() { + let actual_prefix: Vec<_> = actual.iter().take(expected.len()).cloned().collect(); + let expected_vec: Vec<_> = expected.iter().map(|s| s.to_string()).collect(); + + if actual_prefix != expected_vec { + return Err(JsonModelError::ParameterOrderWarning { + function: 
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a JSON literal and run it through the default (lenient) validator.
    fn run(json: &str) -> Result<ValidatedModel, JsonModelError> {
        let parsed = JsonModel::from_str(json).unwrap();
        Validator::new().validate(&parsed)
    }

    #[test]
    fn test_validate_missing_analytical() {
        let outcome = run(r#"{
            "schema": "1.0",
            "id": "test",
            "type": "analytical",
            "parameters": ["ke"],
            "output": "x[0]"
        }"#);
        assert!(matches!(
            outcome,
            Err(JsonModelError::MissingField { field, .. }) if field == "analytical"
        ));
    }

    #[test]
    fn test_validate_missing_diffeq() {
        let outcome = run(r#"{
            "schema": "1.0",
            "id": "test",
            "type": "ode",
            "parameters": ["ke"],
            "output": "x[0]"
        }"#);
        assert!(matches!(
            outcome,
            Err(JsonModelError::MissingField { field, .. }) if field == "diffeq"
        ));
    }

    #[test]
    fn test_validate_invalid_field_for_type() {
        // An analytical model must not also carry an ODE `diffeq` block.
        let outcome = run(r#"{
            "schema": "1.0",
            "id": "test",
            "type": "analytical",
            "analytical": "one_compartment",
            "diffeq": "dx[0] = -ke * x[0];",
            "parameters": ["ke"],
            "output": "x[0]"
        }"#);
        assert!(matches!(
            outcome,
            Err(JsonModelError::InvalidFieldForType { field, .. }) if field == "diffeq"
        ));
    }

    #[test]
    fn test_validate_duplicate_parameter() {
        let outcome = run(r#"{
            "schema": "1.0",
            "id": "test",
            "type": "ode",
            "parameters": ["ke", "V", "ke"],
            "diffeq": "dx[0] = -ke * x[0];",
            "output": "x[0]"
        }"#);
        assert!(matches!(
            outcome,
            Err(JsonModelError::DuplicateParameter { name }) if name == "ke"
        ));
    }

    #[test]
    fn test_validate_valid_model() {
        let outcome = run(r#"{
            "schema": "1.0",
            "id": "pk_1cmt_oral",
            "type": "analytical",
            "analytical": "one_compartment_with_absorption",
            "parameters": ["ka", "ke", "V"],
            "output": "x[1] / V"
        }"#);
        assert!(outcome.is_ok());
    }
}
// ───────────────────────────────────────────────────────────────────────────
// Parsing Tests
// ───────────────────────────────────────────────────────────────────────────

mod parsing {
    use super::*;

    #[test]
    fn test_parse_complete_analytical_model() {
        let parsed = parse_json(
            r#"{
            "schema": "1.0",
            "id": "pk_2cmt_oral",
            "type": "analytical",
            "version": "1.0.0",
            "analytical": "two_compartments_with_absorption",
            "parameters": ["ke", "ka", "kcp", "kpc", "V"],
            "output": "x[1] / V",
            "neqs": [3, 1],
            "display": {
                "name": "Two-Compartment Oral",
                "category": "pk",
                "tags": ["2-compartment", "oral"]
            },
            "documentation": {
                "summary": "Standard two-compartment oral PK model"
            }
        }"#,
        )
        .expect("Should parse successfully");

        assert_eq!(parsed.model_type, ModelType::Analytical);
        assert_eq!(parsed.id, "pk_2cmt_oral");
        assert_eq!(parsed.parameters.as_ref().unwrap().len(), 5);
    }

    #[test]
    fn test_parse_complete_ode_model() {
        let parsed = parse_json(
            r#"{
            "schema": "1.0",
            "id": "pk_mm_1cmt",
            "type": "ode",
            "parameters": ["Vmax", "Km", "V"],
            "compartments": ["central"],
            "diffeq": {
                "central": "-Vmax * (central/V) / (Km + central/V)"
            },
            "output": "central / V",
            "neqs": [1, 1]
        }"#,
        )
        .expect("Should parse successfully");

        assert!(parsed.diffeq.is_some());
        assert_eq!(parsed.model_type, ModelType::Ode);
    }

    #[test]
    fn test_parse_with_covariates() {
        let parsed = parse_json(
            r#"{
            "schema": "1.0",
            "id": "pk_1cmt_wt",
            "type": "analytical",
            "analytical": "one_compartment",
            "parameters": ["ke", "V"],
            "output": "x[0] / V",
            "covariates": [
                { "id": "WT", "reference": 70.0, "units": "kg" }
            ],
            "covariateEffects": [
                {
                    "covariate": "WT",
                    "on": "V",
                    "type": "allometric",
                    "exponent": 0.75,
                    "reference": 70.0
                }
            ]
        }"#,
        )
        .expect("Should parse successfully");

        assert!(parsed.covariates.is_some());
        assert!(parsed.covariate_effects.is_some());
        assert_eq!(parsed.covariate_effects.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_parse_with_lag_and_fa() {
        let parsed = parse_json(
            r#"{
            "schema": "1.0",
            "id": "pk_1cmt_lag",
            "type": "ode",
            "parameters": ["ka", "CL", "V", "APTS", "FFA"],
            "compartments": ["depot", "central"],
            "diffeq": {
                "depot": "-ka * depot",
                "central": "ka * depot - CL/V * central"
            },
            "output": "central / V",
            "lag": {
                "depot": "APTS"
            },
            "fa": {
                "depot": "FFA"
            }
        }"#,
        )
        .expect("Should parse successfully");

        assert!(parsed.lag.is_some());
        assert!(parsed.fa.is_some());
    }

    #[test]
    fn test_reject_unknown_fields() {
        let outcome = parse_json(
            r#"{
            "schema": "1.0",
            "id": "bad_model",
            "type": "ode",
            "unknownField": "should fail"
        }"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_reject_unsupported_schema() {
        let outcome = parse_json(
            r#"{
            "schema": "99.0",
            "id": "future_model",
            "type": "ode"
        }"#,
        );
        assert!(outcome.is_err());
    }
}

// ───────────────────────────────────────────────────────────────────────────
// Validation Tests
// ───────────────────────────────────────────────────────────────────────────

mod validation {
    use super::*;

    #[test]
    fn test_validate_complete_model() {
        let checked = validate_json(
            r#"{
            "schema": "1.0",
            "id": "pk_1cmt",
            "type": "analytical",
            "analytical": "one_compartment",
            "parameters": ["ke", "V"],
            "output": "x[0] / V"
        }"#,
        )
        .expect("Should validate successfully");

        assert_eq!(checked.inner().id, "pk_1cmt");
    }

    #[test]
    fn test_validate_rejects_missing_analytical() {
        let outcome = validate_json(
            r#"{
            "schema": "1.0",
            "id": "bad_analytical",
            "type": "analytical",
            "parameters": ["ke", "V"],
            "output": "x[0] / V"
        }"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_rejects_missing_diffeq() {
        let outcome = validate_json(
            r#"{
            "schema": "1.0",
            "id": "bad_ode",
            "type": "ode",
            "parameters": ["ke", "V"],
            "output": "x[0] / V"
        }"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_rejects_duplicate_parameters() {
        let outcome = validate_json(
            r#"{
            "schema": "1.0",
            "id": "dup_params",
            "type": "analytical",
            "analytical": "one_compartment",
            "parameters": ["ke", "V", "ke"],
            "output": "x[0] / V"
        }"#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_ode_with_compartments() {
        let checked = validate_json(
            r#"{
            "schema": "1.0",
            "id": "ode_with_cmt",
            "type": "ode",
            "parameters": ["ka", "CL", "V"],
            "compartments": ["depot", "central"],
            "diffeq": {
                "depot": "-ka * depot",
                "central": "ka * depot - CL/V * central"
            },
            "output": "central / V"
        }"#,
        )
        .expect("Should validate successfully");

        assert_eq!(checked.inner().compartments.as_ref().unwrap().len(), 2);
    }
}

// ───────────────────────────────────────────────────────────────────────────
// Code Generation Tests
// ───────────────────────────────────────────────────────────────────────────

mod codegen {
    use super::*;

    #[test]
    fn test_generate_analytical_code() {
        let generated = generate_code(
            r#"{
            "schema": "1.0",
            "id": "pk_1cmt",
            "type": "analytical",
            "analytical": "one_compartment_with_absorption",
            "parameters": ["ka", "ke", "V"],
            "output": "x[1] / V"
        }"#,
        )
        .expect("Should generate code");

        // The emitted equation must wire up the builtin solver and outputs.
        let src = &generated.equation_code;
        assert!(src.contains("Analytical::new"));
        assert!(src.contains("one_compartment_with_absorption"));
        assert!(src.contains("fetch_params!"));
        assert!(src.contains("y[0] = x[1] / V"));
        assert_eq!(generated.parameters, vec!["ka", "ke", "V"]);
    }

    #[test]
    fn test_generate_ode_code() {
        let generated = generate_code(
            r#"{
            "schema": "1.0",
            "id": "pk_1cmt_ode",
            "type": "ode",
            "parameters": ["CL", "V"],
            "compartments": ["central"],
            "diffeq": {
                "central": "-CL/V * central"
            },
            "output": "central / V"
        }"#,
        )
        .expect("Should generate code");

        let src = &generated.equation_code;
        assert!(src.contains("ODE::new"));
        assert!(src.contains("fetch_params!"));
        // ODE uses dx[idx] = expression format
        assert!(src.contains("dx[0]"));
    }

    #[test]
    fn test_generate_with_lag() {
        let generated = generate_code(
            r#"{
            "schema": "1.0",
            "id": "pk_with_lag",
            "type": "ode",
            "parameters": ["ka", "CL", "V", "APTS"],
            "compartments": ["depot", "central"],
            "diffeq": {
                "depot": "-ka * depot",
                "central": "ka * depot - CL/V * central"
            },
            "output": "central / V",
            "lag": {
                "depot": "APTS"
            }
        }"#,
        )
        .expect("Should generate code");

        assert!(generated.equation_code.contains("lag!"));
        // depot is compartment 0, so should be "0 => APTS"
        assert!(generated.equation_code.contains("=> APTS"));
    }

    #[test]
    fn test_generate_with_init() {
        let generated = generate_code(
            r#"{
            "schema": "1.0",
            "id": "pk_with_init",
            "type": "ode",
            "parameters": ["CL", "V", "A0"],
            "compartments": ["central"],
            "diffeq": {
                "central": "-CL/V * central"
            },
            "init": {
                "central": "A0"
            },
            "output": "central / V"
        }"#,
        )
        .expect("Should generate code");

        assert!(generated.equation_code.contains("x[0] = A0"));
    }

    #[test]
    fn test_generate_with_covariates() {
        let generated = generate_code(
            r#"{
            "schema": "1.0",
            "id": "pk_cov",
            "type": "analytical",
            "analytical": "one_compartment",
            "parameters": ["ke", "V"],
            "output": "x[0] / V",
            "covariates": [
                { "id": "WT", "reference": 70.0 }
            ],
            "covariateEffects": [
                {
                    "covariate": "WT",
                    "on": "V",
                    "type": "allometric",
                    "exponent": 0.75,
                    "reference": 70.0
                }
            ]
        }"#,
        )
        .expect("Should generate code");

        // Should include covariate access and effect
        assert!(generated.equation_code.contains("cov.get_covariate"));
        // Allometric: V * (WT / ref)^exp
        assert!(generated.equation_code.contains("powf"));
    }
}

// ───────────────────────────────────────────────────────────────────────────
// Library Tests
// ───────────────────────────────────────────────────────────────────────────

mod library {
    use super::*;

    #[test]
    fn test_builtin_library_contains_standard_models() {
        let lib = ModelLibrary::builtin();

        // Every bundled model must be registered under its canonical key.
        for key in [
            "pk/1cmt-iv",
            "pk/1cmt-oral",
            "pk/2cmt-iv",
            "pk/2cmt-oral",
            "pk/1cmt-iv-ode",
            "pk/1cmt-oral-ode",
        ] {
            assert!(lib.contains(key));
        }
    }

    #[test]
    fn test_library_search() {
        let lib = ModelLibrary::builtin();

        // Search by ID substring
        let hits = lib.search("oral");
        assert!(!hits.is_empty());
        assert!(hits.iter().all(|m| m.id.contains("oral")));
    }

    #[test]
    fn test_library_filter_by_type() {
        let lib = ModelLibrary::builtin();

        let analytical = lib.filter_by_type(ModelType::Analytical);
        let ode = lib.filter_by_type(ModelType::Ode);

        assert!(!analytical.is_empty());
        assert!(!ode.is_empty());

        // All filtered models should have correct type
        assert!(analytical
            .iter()
            .all(|m| m.model_type == ModelType::Analytical));
        assert!(ode.iter().all(|m| m.model_type == ModelType::Ode));
    }

    #[test]
    fn test_library_filter_by_tag() {
        let lib = ModelLibrary::builtin();

        let tagged = lib.filter_by_tag("oral");
        assert!(!tagged.is_empty());
    }
}
JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "base/pk-1cmt", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "output": "x[0] / V", + "display": { + "name": "Base One-Compartment", + "category": "pk" + } + }"#, + ) + .unwrap(); + library.add(base); + + // Create derived model with weight covariate + let derived = JsonModel::from_str( + r#"{ + "schema": "1.0", + "id": "derived/pk-1cmt-wt", + "extends": "base/pk-1cmt", + "type": "analytical", + "analytical": "one_compartment", + "parameters": ["ke", "V"], + "covariates": [ + { "id": "WT", "reference": 70.0 } + ], + "covariateEffects": [ + { + "covariate": "WT", + "on": "V", + "type": "allometric", + "exponent": 0.75, + "reference": 70.0 + } + ] + }"#, + ) + .unwrap(); + + let resolved = library.resolve(&derived).unwrap(); + + // Should inherit output from base + assert!(resolved.output.is_some()); + assert_eq!(resolved.output.as_ref().unwrap(), "x[0] / V"); + + // Should have covariates from derived + assert!(resolved.covariates.is_some()); + assert!(resolved.covariate_effects.is_some()); + } + + #[test] + fn test_library_generates_code_from_model() { + let library = ModelLibrary::builtin(); + + let model = library.get("pk/1cmt-oral").unwrap(); + let generator = CodeGenerator::new(model); + let code = generator.generate().expect("Should generate code"); + + assert!(code + .equation_code + .contains("one_compartment_with_absorption")); + assert_eq!(code.parameters, vec!["ka", "ke", "V"]); + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// End-to-End Tests +// ═══════════════════════════════════════════════════════════════════════════════ + +mod end_to_end { + use super::*; + + #[test] + fn test_full_pipeline_analytical() { + // 1. 
Define model in JSON + let json = r#"{ + "schema": "1.0", + "id": "e2e_1cmt", + "type": "analytical", + "analytical": "one_compartment_with_absorption", + "parameters": ["ka", "ke", "V"], + "output": "x[1] / V", + "display": { + "name": "E2E Test Model", + "category": "pk" + } + }"#; + + // 2. Parse + let model = parse_json(json).unwrap(); + assert_eq!(model.id, "e2e_1cmt"); + + // 3. Validate + let validator = Validator::new(); + let validated = validator.validate(&model).unwrap(); + + // 4. Generate code + let generator = CodeGenerator::new(validated.inner()); + let code = generator.generate().unwrap(); + + // 5. Verify code is valid Rust syntax (basic check) + assert!(code.equation_code.contains("Analytical::new")); + assert!(!code.equation_code.is_empty()); + assert_eq!(code.parameters.len(), 3); + } + + #[test] + fn test_full_pipeline_ode() { + let json = r#"{ + "schema": "1.0", + "id": "e2e_mm", + "type": "ode", + "parameters": ["Vmax", "Km", "V"], + "compartments": ["central"], + "diffeq": { + "central": "-Vmax * (central/V) / (Km + central/V)" + }, + "output": "central / V" + }"#; + + // Full pipeline + let code = generate_code(json).unwrap(); + + assert!(code.equation_code.contains("ODE::new")); + assert!(code.equation_code.contains("Vmax")); + assert!(code.equation_code.contains("Km")); + } + + #[test] + fn test_library_to_code_pipeline() { + let library = ModelLibrary::builtin(); + + // Get all models and verify they all generate valid code + for id in library.list() { + let model = library.get(id).unwrap(); + let generator = CodeGenerator::new(model); + let result = generator.generate(); + + assert!(result.is_ok(), "Failed to generate code for model: {}", id); + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// EXA Compilation Tests (requires `exa` feature) +// ═══════════════════════════════════════════════════════════════════════════════ + +#[cfg(feature = "exa")] +mod exa_integration { + use 
approx::assert_relative_eq;
+    use pharmsol::json::compile_json;
+    use pharmsol::{equation, exa, Equation, Subject, SubjectBuilderExt, ODE};
+    use pharmsol::{fa, fetch_params, lag};
+    use std::path::PathBuf;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
+    // Unique counter for test file names
+    static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);
+
+    fn unique_model_path(prefix: &str) -> PathBuf {
+        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
+        let pid = std::process::id();
+        std::env::current_dir()
+            .expect("Failed to get current directory")
+            .join(format!(
+                "{}_{}_{}_{}.pkm",
+                prefix,
+                pid,
+                count,
+                std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap()
+                    .as_nanos()
+            ))
+    }
+
+    /// Create a unique temp path for each test to avoid race conditions
+    fn unique_temp_path() -> PathBuf {
+        let count = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
+        let pid = std::process::id();
+        std::env::temp_dir().join(format!("exa_test_{}_{}", pid, count))
+    }
+
+    #[test]
+    fn test_compile_json_ode_model() {
+        // Define a simple ODE model in JSON
+        let json = r#"{
+            "schema": "1.0",
+            "id": "test_compiled_ode",
+            "type": "ode",
+            "parameters": ["ke", "V"],
+            "compartments": ["central"],
+            "diffeq": {
+                "central": "-ke * central + rateiv[0]"
+            },
+            "output": "central / V"
+        }"#;
+
+        let model_output_path = unique_model_path("test_json_compiled");
+        let template_path = unique_temp_path();
+
+        // Compile using compile_json
+        let model_path = compile_json::<ODE>(
+            json,
+            Some(model_output_path.clone()),
+            template_path.clone(),
+            |_, _| {}, // Empty callback for tests
+        )
+        .expect("compile_json should succeed");
+
+        // Load the compiled model
+        let model_path = PathBuf::from(&model_path);
+        let (_lib, (dyn_ode, _meta)) = unsafe { exa::load::load::<ODE>(model_path.clone()) };
+
+        // Create a test subject
+        let subject = Subject::builder("1")
+            .infusion(0.0, 500.0, 0, 0.5)
+            .observation(0.5, 1.5, 0)
+            .observation(1.0, 1.2, 0)
+
.observation(2.0, 0.5, 0)
+            .build();
+
+        // Test that the model produces predictions
+        let params = vec![1.0, 100.0]; // ke=1.0, V=100
+        let predictions = dyn_ode.estimate_predictions(&subject, &params);
+        assert!(predictions.is_ok(), "Should produce predictions");
+
+        let preds = predictions.unwrap().flat_predictions();
+        assert_eq!(preds.len(), 3, "Should have 3 predictions");
+
+        // Predictions should be positive (concentrations)
+        for p in &preds {
+            assert!(*p > 0.0, "Concentration should be positive");
+        }
+
+        // Clean up
+        std::fs::remove_file(model_path).ok();
+        std::fs::remove_dir_all(template_path).ok();
+    }
+
+    #[test]
+    fn test_compile_json_matches_handwritten_ode() {
+        // Define model in JSON
+        let json = r#"{
+            "schema": "1.0",
+            "id": "compare_ode",
+            "type": "ode",
+            "parameters": ["ke", "V"],
+            "compartments": ["central"],
+            "diffeq": {
+                "central": "-ke * central + rateiv[0]"
+            },
+            "output": "central / V"
+        }"#;
+
+        // Compile JSON model
+        let model_output_path = unique_model_path("test_json_vs_handwritten");
+        let template_path = unique_temp_path();
+
+        let model_path = compile_json::<ODE>(
+            json,
+            Some(model_output_path.clone()),
+            template_path.clone(),
+            |_, _| {},
+        )
+        .expect("compile_json should succeed");
+
+        let model_path = PathBuf::from(&model_path);
+        let (_lib, (dyn_ode, _meta)) = unsafe { exa::load::load::<ODE>(model_path.clone()) };
+
+        // Create equivalent handwritten ODE
+        let handwritten_ode = equation::ODE::new(
+            |x, p, _t, dx, _b, rateiv, _cov| {
+                fetch_params!(p, ke, _V);
+                dx[0] = -ke * x[0] + rateiv[0];
+            },
+            |_p, _t, _cov| lag! {},
+            |_p, _t, _cov| fa!
{},
+            |_p, _t, _cov, _x| {},
+            |x, p, _t, _cov, y| {
+                fetch_params!(p, _ke, V);
+                y[0] = x[0] / V;
+            },
+            (1, 1),
+        );
+
+        // Test subject
+        let subject = Subject::builder("1")
+            .infusion(0.0, 500.0, 0, 0.5)
+            .observation(0.5, 1.645776, 0)
+            .observation(1.0, 1.216442, 0)
+            .observation(2.0, 0.4622729, 0)
+            .build();
+
+        let params = vec![1.02282724609375, 194.51904296875];
+
+        // Compare predictions
+        let json_preds = dyn_ode.estimate_predictions(&subject, &params).unwrap();
+        let hand_preds = handwritten_ode
+            .estimate_predictions(&subject, &params)
+            .unwrap();
+
+        let json_flat = json_preds.flat_predictions();
+        let hand_flat = hand_preds.flat_predictions();
+
+        assert_eq!(json_flat.len(), hand_flat.len());
+
+        for (json_val, hand_val) in json_flat.iter().zip(hand_flat.iter()) {
+            assert_relative_eq!(json_val, hand_val, max_relative = 1e-10, epsilon = 1e-10);
+        }
+
+        // Clean up
+        std::fs::remove_file(model_path).ok();
+        std::fs::remove_dir_all(template_path).ok();
+    }
+
+    #[test]
+    fn test_compile_json_library_model() {
+        use pharmsol::json::ModelLibrary;
+
+        let library = ModelLibrary::builtin();
+
+        // Get an ODE model from the library
+        let model = library
+            .get("pk/1cmt-iv-ode")
+            .expect("Should have pk/1cmt-iv-ode");
+
+        // Convert back to JSON and compile
+        let json = serde_json::to_string(model).expect("Should serialize");
+
+        let model_output_path = unique_model_path("test_library_compiled");
+        let template_path = unique_temp_path();
+
+        let model_path = compile_json::<ODE>(
+            &json,
+            Some(model_output_path.clone()),
+            template_path.clone(),
+            |_, _| {},
+        )
+        .expect("compile_json should succeed for library model");
+
+        let model_path = PathBuf::from(&model_path);
+
+        // Verify it loads
+        let (_lib, (dyn_ode, meta)) = unsafe { exa::load::load::<ODE>(model_path.clone()) };
+
+        // Verify metadata
+        assert_eq!(meta.get_params(), &vec!["CL".to_string(), "V".to_string()]);
+
+        // Test it produces valid predictions
+        let subject = Subject::builder("1")
+            .bolus(0.0,
100.0, 0)
+            .observation(1.0, 50.0, 0)
+            .build();
+
+        let params = vec![5.0, 10.0]; // CL=5, V=10 (ke = CL/V = 0.5)
+        let predictions = dyn_ode.estimate_predictions(&subject, &params);
+        assert!(predictions.is_ok());
+
+        // Clean up
+        std::fs::remove_file(model_path).ok();
+        std::fs::remove_dir_all(template_path).ok();
+    }
+}