3 changes: 3 additions & 0 deletions .github/workflows/ci-tests-drafts.yaml
@@ -27,6 +27,9 @@ jobs:
pip install -e .
pip install -e .[test]
pip install pytest pytest-cov
+ - name: Register Jupyter Kernel
+   run: |
+     python -m ipykernel install --user --name python3
Comment on lines +30 to +32

@f-allian (Collaborator) commented on Dec 12, 2025:

@jmafoster1 What do we need this for? Is it for the Jupyter Notebook test?

@jmafoster1 (Collaborator, Author) commented on Dec 12, 2025:

Otherwise it says "no kernel found" when you try to run the notebooks.
- name: Test with pytest
run: |
pytest --cov=causal_testing --cov-report=xml
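To verify the registration locally, a minimal sketch — assuming only that jupyter_client (installed alongside ipykernel) is available:

```python
# List the kernelspecs Jupyter can see; the notebook tests fail with
# "no kernel found" when "python3" is missing from this mapping.
from jupyter_client.kernelspec import KernelSpecManager

specs = KernelSpecManager().find_kernel_specs()  # {kernel_name: resource_dir}
assert "python3" in specs, "run: python -m ipykernel install --user --name python3"
```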
3 changes: 3 additions & 0 deletions .github/workflows/ci-tests.yaml
@@ -32,6 +32,9 @@ jobs:
pip install -e .
pip install -e .[test]
pip install pytest pytest-cov
+ - name: Register Jupyter Kernel
+   run: |
+     python -m ipykernel install --user --name python3
- name: Test with pytest
run: |
pytest --cov=causal_testing --cov-report=xml
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
@@ -37,4 +37,10 @@ repos:
language: system
types: [python]
args: ['--rcfile=.pylintrc', '--max-line-length=120', '--max-positional-arguments=12', '--disable=W1401']
- files: ^causal_testing/
+ files: ^causal_testing/
+
+ - repo: https://github.com/jsh9/pydoclint
+   rev: 0.8.3
+   hooks:
+     - id: pydoclint
+       args: [--style=google, --check-return-types=False]
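For context, a generic Google-style docstring of the kind the new pydoclint hook checks — an illustrative function, not taken from this repository; with --check-return-types=False the Returns section needs only a description, not a type matching the annotation:

```python
def scale(values: list[float], factor: float = 1.0) -> list[float]:
    """Multiply every value by a constant factor.

    Args:
        values: Numbers to scale.
        factor: Multiplier applied to each element. Defaults to 1.0.

    Returns:
        A new list with each element scaled by ``factor``.
    """
    return [value * factor for value in values]
```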
31 changes: 14 additions & 17 deletions causal_testing/estimation/ipcw_estimator.py
@@ -21,6 +21,20 @@ class IPCWEstimator(Estimator):
"""
Class to perform Inverse Probability of Censoring Weighting (IPCW) estimation
for sequences of treatments over time-varying data.

+ :param: df: Input DataFrame containing time-varying data.
+ :param: timesteps_per_observation: Number of timesteps per observation.
+ :param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
+ :param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
+ :param: outcome: Name of the outcome column in the DataFrame.
+ :param: status_column: Name of the status column in the DataFrame, which should be True for operating normally,
+     False for a fault.
+ :param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
+ :param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
+ :param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
+ :param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
+ :param: total_time: Total time for the analysis. Defaults to one plus the length of the longest strategy
+     (control or treatment), multiplied by `timesteps_per_observation`.
"""

# pylint: disable=too-many-arguments
@@ -40,23 +54,6 @@ def __init__(
alpha: float = 0.05,
total_time: float = None,
):
-     """
-     Initialise IPCWEstimator.
-
-     :param: df: Input DataFrame containing time-varying data.
-     :param: timesteps_per_observation: Number of timesteps per observation.
-     :param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
-     :param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
-     :param: outcome: Name of the outcome column in the DataFrame.
-     :param: status_column: Name of the status column in the DataFrame, which should be True for operating normally,
-         False for a fault.
-     :param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
-     :param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
-     :param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
-     :param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
-     :param: total_time: Total time for the analysis. Defaults to one plus the length of of the strategy (control or
-         treatment) with the most elements multiplied by `timesteps_per_observation`.
-     """
super().__init__(
base_test_case=BaseTestCase(None, outcome),
treatment_value=[val for _, _, val in treatment_strategy],
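To make the relocated docstring concrete, a hypothetical construction — the data file, column names, and formulas are invented; only the parameter shapes follow the docstring above:

```python
import pandas as pd

from causal_testing.estimation.ipcw_estimator import IPCWEstimator

# Hypothetical time-varying data; "dose", "age", "time", "recovery", and "ok"
# are invented column names. Strategy entries are (timestep, variable, value).
df = pd.read_csv("time_varying_runs.csv")

estimator = IPCWEstimator(
    df=df,
    timesteps_per_observation=1,
    control_strategy=[(t, "dose", 0) for t in range(3)],
    treatment_strategy=[(t, "dose", 1) for t in range(3)],
    outcome="recovery",
    status_column="ok",  # True while operating normally, False on a fault
    fit_bl_switch_formula="dose ~ age",
    fit_bltd_switch_formula="dose ~ age + time",
)
```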
32 changes: 12 additions & 20 deletions causal_testing/main.py
@@ -15,14 +15,13 @@
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
from causal_testing.specification.causal_dag import CausalDAG
- from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.causal_effect import Negative, NoEffect, Positive, SomeEffect
+ from causal_testing.testing.causal_test_adequacy import DataAdequacy
from causal_testing.testing.causal_test_case import CausalTestCase
from causal_testing.testing.causal_test_result import CausalTestResult
- from causal_testing.testing.causal_test_adequacy import DataAdequacy

logger = logging.getLogger(__name__)

@@ -106,7 +105,6 @@ def __init__(self, paths: CausalTestingPaths, ignore_cycles: bool = False, query
self.data: Optional[pd.DataFrame] = None
self.variables: Dict[str, Any] = {"inputs": {}, "outputs": {}, "metas": {}}
self.scenario: Optional[Scenario] = None
- self.causal_specification: Optional[CausalSpecification] = None
self.test_cases: Optional[List[CausalTestCase]] = None

def setup(self) -> None:
@@ -130,8 +128,11 @@ def setup(self) -> None:
# Create variables from DAG
self.create_variables()

- # Create scenario and specification
- self.create_scenario_and_specification()
+ # Create scenario
+ self.scenario = Scenario(
+     list(self.variables["inputs"].values()) + list(self.variables["outputs"].values()),
+     {self.query} if self.query else None,
+ )

logger.info("Setup completed successfully")

@@ -187,18 +188,6 @@ def create_variables(self) -> None:
if self.dag.in_degree(node_name) > 0:
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)

- def create_scenario_and_specification(self) -> None:
-     """Create scenario and causal specification objects from loaded data."""
-     # Create scenario
-     all_variables = list(self.variables["inputs"].values()) + list(self.variables["outputs"].values())
-     self.scenario = Scenario(variables=all_variables)
-
-     # Set up treatment variables
-     self.scenario.setup_treatment_variables()
-
-     # Create causal specification
-     self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.dag)
-
def load_tests(self) -> None:
"""
Load and prepare test configurations from file.
@@ -316,7 +305,10 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
base_test_case=base_test,
treatment_value=test.get("treatment_value"),
control_value=test.get("control_value"),
- adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
+ adjustment_set=test.get(
+     "adjustment_set",
+     self.dag.identification(base_test, self.scenario.hidden_variables()),
+ ),
df=filtered_df,
effect_modifiers=None,
formula=test.get("formula"),
@@ -346,7 +338,7 @@ def run_tests_in_batches(
:param silent: Whether to suppress errors
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
- (defaults to 100)
+ (defaults to 100)
:return: List of all test results
:raises: ValueError if no tests are loaded
"""
@@ -403,7 +395,7 @@ def run_tests(
:param silent: Whether to suppress errors
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
- (defaults to 100)
+ (defaults to 100)

:return: List of CausalTestResult objects
:raises: ValueError if no tests are loaded
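The inline Scenario construction above, in miniature — variable names and the query string are invented; the two-argument call (variables, then an optional set holding the query) mirrors setup():

```python
from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Input, Output

# Invented stand-ins for the DAG-derived inputs and outputs.
variables = [Input(name="dose", datatype=float), Output(name="recovery", datatype=float)]

# Second argument mirrors `{self.query} if self.query else None` in setup().
scenario = Scenario(variables, {"dose >= 0"})
```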
39 changes: 7 additions & 32 deletions causal_testing/specification/causal_dag.py
@@ -10,8 +10,7 @@

from causal_testing.testing.base_test_case import BaseTestCase

- from .scenario import Scenario
- from .variable import Output
+ from .variable import Variable

Node = Union[str, int] # Node type hint: A node is a string or an int

@@ -489,38 +488,12 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG:
backdoor_graph.add_edges_from(filter(lambda x: x not in outgoing_edges, self.edges))
return backdoor_graph

- def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
-     """Check whether a given node in a given scenario is or depends on a
-     model output in the given scenario. That is, whether or not the model
-     needs to be run to determine its value.
-
-     NOTE: The graph must be acyclic for this to terminate.
-
-     :param node: The node in the DAG representing the variable of interest.
-     :param scenario: The modelling scenario.
-
-     :return: Whether the given variable is or depends on an output.
-     """
-     if isinstance(scenario.variables[node], Output):
-         return True
-     return any((self.depends_on_outputs(n, scenario) for n in self.predecessors(node)))
-
- @staticmethod
- def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
-     """Remove variables labelled as hidden from adjustment set(s)
-
-     :param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
-     :param scenario: The modelling scenario which informs the variables that are hidden
-     """
-     return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]

- def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
+ def identification(self, base_test_case: BaseTestCase, avoid_variables: set[Variable] = None):
"""Identify and return the minimum adjustment set

:param base_test_case: A base test case instance containing the outcome_variable and the
treatment_variable required for identification.
- :param scenario: The modelling scenario relating to the tests
-
+ :param avoid_variables: Variables not to be adjusted for (e.g. hidden variables).
:return: The smallest set of variables which can be adjusted for to obtain a causal
estimate as opposed to a purely associational estimate.
"""
@@ -539,8 +512,10 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
else:
raise ValueError("Causal effect should be 'total' or 'direct'")

- if scenario is not None:
-     minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
+ if avoid_variables is not None:
+     minimal_adjustment_sets = [
+         adj for adj in minimal_adjustment_sets if not {x.name for x in avoid_variables}.intersection(adj)
+     ]

minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=set())
return set(minimal_adjustment_set)
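A sketch of the new call shape — the DOT file, variable names, and datatypes here are invented; only identification's signature and the filtering behaviour follow the diff above:

```python
from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.base_test_case import BaseTestCase

# Hypothetical DAG: X -> Y confounded by Z (Z -> X, Z -> Y).
dag = CausalDAG("dag.dot")
base_test = BaseTestCase(Input(name="X", datatype=float), Output(name="Y", datatype=float))

# Passing Z as a variable to avoid discards every candidate adjustment set
# containing it — the same filtering main.py now applies with
# self.scenario.hidden_variables().
adjustment_set = dag.identification(base_test, avoid_variables={Input(name="Z", datatype=float)})
```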
22 changes: 0 additions & 22 deletions causal_testing/specification/causal_specification.py

This file was deleted.
