Skip to content

Commit 072f954

Browse files
authored
Modifies cells to make function handles pickleable (#82)
1 parent dfef174 commit 072f954

File tree

3 files changed

+74
-55
lines changed

3 files changed

+74
-55
lines changed

RATapi/inputs.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,52 @@ def check_indices(problem: ProblemDefinition) -> None:
279279
)
280280

281281

282+
class FileHandles:
283+
"""Class to defer creation of custom file handles.
284+
285+
Parameters
286+
----------
287+
files : ClassList[CustomFile]
288+
A list of custom file models.
289+
"""
290+
291+
def __init__(self, files):
292+
self.index = 0
293+
self.files = [*files]
294+
295+
def __iter__(self):
296+
self.index = 0
297+
return self
298+
299+
def get_handle(self, index):
300+
"""Returns file handle for a given custom file.
301+
302+
Parameters
303+
----------
304+
index : int
305+
The index of the custom file.
306+
307+
"""
308+
custom_file = self.files[index]
309+
full_path = os.path.join(custom_file.path, custom_file.filename)
310+
if custom_file.language == Languages.Python:
311+
file_handle = get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path)
312+
elif custom_file.language == Languages.Matlab:
313+
file_handle = RATapi.wrappers.MatlabWrapper(full_path).getHandle()
314+
elif custom_file.language == Languages.Cpp:
315+
file_handle = RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle()
316+
317+
return file_handle
318+
319+
def __next__(self):
320+
if self.index < len(self.files):
321+
custom_file = self.get_handle(self.index)
322+
self.index += 1
323+
return custom_file
324+
else:
325+
raise StopIteration
326+
327+
282328
def make_cells(project: RATapi.Project) -> Cells:
283329
"""Constructs the cells input required for the compiled RAT code.
284330
@@ -344,16 +390,6 @@ def make_cells(project: RATapi.Project) -> Cells:
344390
else:
345391
simulation_limits.append([0.0, 0.0])
346392

347-
file_handles = []
348-
for custom_file in project.custom_files:
349-
full_path = os.path.join(custom_file.path, custom_file.filename)
350-
if custom_file.language == Languages.Python:
351-
file_handles.append(get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path))
352-
elif custom_file.language == Languages.Matlab:
353-
file_handles.append(RATapi.wrappers.MatlabWrapper(full_path).getHandle())
354-
elif custom_file.language == Languages.Cpp:
355-
file_handles.append(RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle())
356-
357393
# Populate the set of cells
358394
cells = Cells()
359395
cells.f1 = [[0, 1]] * len(project.contrasts) # This is marked as "to do" in RAT
@@ -369,7 +405,7 @@ def make_cells(project: RATapi.Project) -> Cells:
369405
cells.f11 = [param.name for param in project.bulk_in]
370406
cells.f12 = [param.name for param in project.bulk_out]
371407
cells.f13 = [param.name for param in project.resolution_parameters]
372-
cells.f14 = file_handles
408+
cells.f14 = FileHandles(project.custom_files)
373409
cells.f15 = [param.type for param in project.backgrounds]
374410
cells.f16 = [param.type for param in project.resolutions]
375411

cpp/rat.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ struct Cells {
446446
py::list f11;
447447
py::list f12;
448448
py::list f13;
449-
py::list f14;
449+
py::object f14;
450450
py::list f15;
451451
py::list f16;
452452
py::list f17;
@@ -844,12 +844,13 @@ coder::array<RAT::cell_wrap_6, 2U> pyListToRatCellWrap6(py::list values)
844844
return result;
845845
}
846846

847-
coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::list values)
847+
coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::object values)
848848
{
849+
auto handles = py::cast<py::list>(values);
849850
coder::array<RAT::cell_wrap_6, 2U> result;
850-
result.set_size(1, values.size());
851+
result.set_size(1, handles.size());
851852
int32_T idx {0};
852-
for (py::handle array: values)
853+
for (py::handle array: handles)
853854
{
854855
auto func = py::cast<py::function>(array);
855856
std::string func_ptr = convertPtr2String<CallbackInterface>(new Library(func));
@@ -1585,7 +1586,7 @@ PYBIND11_MODULE(rat_core, m) {
15851586
cell.f11 = t[10].cast<py::list>();
15861587
cell.f12 = t[11].cast<py::list>();
15871588
cell.f13 = t[12].cast<py::list>();
1588-
cell.f14 = t[13].cast<py::list>();
1589+
cell.f14 = t[13].cast<py::object>();
15891590
cell.f15 = t[14].cast<py::list>();
15901591
cell.f16 = t[15].cast<py::list>();
15911592
cell.f17 = t[16].cast<py::list>();

tests/test_inputs.py

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -624,25 +624,7 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr
624624
"domainRatio",
625625
]
626626

627-
mocked_matlab_future = mock.MagicMock()
628-
mocked_engine = mock.MagicMock()
629-
mocked_matlab_future.result.return_value = mocked_engine
630-
631-
with mock.patch.object(
632-
RATapi.wrappers.MatlabWrapper,
633-
"loader",
634-
mocked_matlab_future,
635-
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
636-
RATapi.inputs,
637-
"get_python_handle",
638-
mock.MagicMock(return_value=dummy_function),
639-
), mock.patch.object(
640-
RATapi.wrappers.MatlabWrapper,
641-
"getHandle",
642-
mock.MagicMock(return_value=dummy_function),
643-
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
644-
problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())
645-
627+
problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())
646628
problem = pickle.loads(pickle.dumps(problem))
647629
check_problem_equal(problem, test_problem)
648630
cells = pickle.loads(pickle.dumps(cells))
@@ -768,25 +750,7 @@ def test_make_cells(test_project, test_cells, request) -> None:
768750
"""The cells object should be populated according to the input project object."""
769751
test_project = request.getfixturevalue(test_project)
770752
test_cells = request.getfixturevalue(test_cells)
771-
772-
mocked_matlab_future = mock.MagicMock()
773-
mocked_engine = mock.MagicMock()
774-
mocked_matlab_future.result.return_value = mocked_engine
775-
with mock.patch.object(
776-
RATapi.wrappers.MatlabWrapper,
777-
"loader",
778-
mocked_matlab_future,
779-
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
780-
RATapi.inputs,
781-
"get_python_handle",
782-
mock.MagicMock(return_value=dummy_function),
783-
), mock.patch.object(
784-
RATapi.wrappers.MatlabWrapper,
785-
"getHandle",
786-
mock.MagicMock(return_value=dummy_function),
787-
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
788-
cells = make_cells(test_project)
789-
753+
cells = make_cells(test_project)
790754
check_cells_equal(cells, test_cells)
791755

792756

@@ -865,7 +829,25 @@ def check_cells_equal(actual_cells, expected_cells) -> None:
865829
"NaN" if np.isnan(el) else el for entry in actual_cells.f6 for el in entry
866830
] == ["NaN" if np.isnan(el) else el for entry in expected_cells.f6 for el in entry]
867831

868-
for index in chain(range(3, 6), range(7, 21)):
832+
mocked_matlab_future = mock.MagicMock()
833+
mocked_engine = mock.MagicMock()
834+
mocked_matlab_future.result.return_value = mocked_engine
835+
with mock.patch.object(
836+
RATapi.wrappers.MatlabWrapper,
837+
"loader",
838+
mocked_matlab_future,
839+
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
840+
RATapi.inputs,
841+
"get_python_handle",
842+
mock.MagicMock(return_value=dummy_function),
843+
), mock.patch.object(
844+
RATapi.wrappers.MatlabWrapper,
845+
"getHandle",
846+
mock.MagicMock(return_value=dummy_function),
847+
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
848+
assert list(actual_cells.f14) == expected_cells.f14
849+
850+
for index in chain(range(3, 6), range(7, 14), range(15, 21)):
869851
field = f"f{index}"
870852
assert getattr(actual_cells, field) == getattr(expected_cells, field)
871853

0 commit comments

Comments
 (0)