diff --git a/RATapi/inputs.py b/RATapi/inputs.py index e2baebf6..757aa78a 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -151,11 +151,23 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: The problem input used in the compiled RAT code. """ - hydrate_id = {"bulk in": 1, "bulk out": 2} prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3} - # Ensure backgrounds and resolutions have a source defined + # Ensure all contrast fields are properly defined for contrast in project.contrasts: + contrast_fields = ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"] + + if project.calculation == Calculations.Domains: + contrast_fields.append("domain_ratio") + + for field in contrast_fields: + if getattr(contrast, field) == "": + raise ValueError( + f'In the input project, the "{field}" field of contrast "{contrast.name}" does not have a ' + f"value defined. A value must be supplied before running the project." + ) + + # Ensure backgrounds and resolutions have a source defined background = project.backgrounds[contrast.background] resolution = project.resolutions[contrast.resolution] if background.source == "": @@ -191,22 +203,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts] # Get details of defined layers - layer_details = [] - for layer in project.layers: - if project.absorption: - layer_params = [ - project.parameters.index(getattr(layer, attribute), True) - for attribute in list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2] - ] - else: - layer_params = [ - project.parameters.index(getattr(layer, attribute), True) - for attribute in list(RATapi.models.Layer.model_fields.keys())[1:-2] - ] - layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN")) - layer_params.append(hydrate_id[layer.hydrate_with]) - - layer_details.append(layer_params) + layer_details = get_layer_details(project) contrast_background_params = [] contrast_background_types = [] @@ -387,6 +384,35 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: return problem +def get_layer_details(project: RATapi.Project) -> list[int]: + """Get parameter indices for all layers defined in the project.""" + hydrate_id = {"bulk in": 1, "bulk out": 2} + layer_details = [] + + # Get the thickness, SLD, roughness fields from the appropriate model + if project.absorption: + layer_fields = list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2] + else: + layer_fields = list(RATapi.models.Layer.model_fields.keys())[1:-2] + + for layer in project.layers: + for field in layer_fields: + if getattr(layer, field) == "": + raise ValueError( + f'In the input project, the "{field}" field of layer {layer.name} does not have a value ' + f"defined. A value must be supplied before running the project." + ) + + layer_params = [project.parameters.index(getattr(layer, attribute), True) for attribute in list(layer_fields)] + + layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN")) + layer_params.append(hydrate_id[layer.hydrate_with]) + + layer_details.append(layer_params) + + return layer_details + + def make_resample(project: RATapi.Project) -> list[int]: """Construct the "resample" field of the problem input required for the compiled RAT code. diff --git a/tests/test_inputs.py b/tests/test_inputs.py index e7efd7b4..d3ed276e 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -572,6 +572,69 @@ def test_background_params_value_indices(self, test_problem, bad_value, request) check_indices(test_problem) +@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"]) +@pytest.mark.parametrize("field", ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"]) +def test_undefined_contrast_fields(test_project, field, request): + """If a field in a contrast is empty, we should raise an error.""" + test_project = request.getfixturevalue(test_project) + setattr(test_project.contrasts[0], field, "") + + with pytest.raises( + ValueError, + match=f'In the input project, the "{field}" field of contrast ' + f'"{test_project.contrasts[0].name}" does not have a value defined. ' + f"A value must be supplied before running the project.", + ): + make_problem(test_project) + + +@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"]) +def test_undefined_background(test_project, request): + """If the source field of a background defined in a contrast is empty, we should raise an error.""" + test_project = request.getfixturevalue(test_project) + background = test_project.backgrounds[test_project.contrasts[0].background] + background.source = "" + + with pytest.raises( + ValueError, + match=f"All backgrounds must have a source defined. For a {background.type} type " + f"background, the source must be defined in " + f'"{RATapi.project.values_defined_in[f"backgrounds.{background.type}.source"]}"', + ): + make_problem(test_project) + + +@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"]) +def test_undefined_resolution(test_project, request): + """If the source field of a resolution defined in a contrast is empty, we should raise an error.""" + test_project = request.getfixturevalue(test_project) + resolution = test_project.resolutions[test_project.contrasts[0].resolution] + resolution.source = "" + + with pytest.raises( + ValueError, + match=f"Constant resolutions must have a source defined. The source must be defined in " + f'"{RATapi.project.values_defined_in[f"resolutions.{resolution.type}.source"]}"', + ): + make_problem(test_project) + + +@pytest.mark.parametrize("test_project", ["standard_layers_project", "domains_project"]) +@pytest.mark.parametrize("field", ["thickness", "SLD", "roughness"]) +def test_undefined_layers(test_project, field, request): + """If the thickness, SLD, or roughness fields of a layer defined in the project are empty, we should raise an + error.""" + test_project = request.getfixturevalue(test_project) + setattr(test_project.layers[0], field, "") + + with pytest.raises( + ValueError, + match=f'In the input project, the "{field}" field of layer {test_project.layers[0].name} ' + f"does not have a value defined. A value must be supplied before running the project.", + ): + make_problem(test_project) + + def test_append_data_background(): """Test that background data is correctly added to contrast data.""" data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])