diff --git a/.gitignore b/.gitignore index a8de19a..7882e9a 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,10 @@ build/* docs/html/* examples/out/* examples/out_json_ref/* + +# python eggs +python_src/*.egg-info +**/__pycache__ +**/*.pyc + + diff --git a/.gitmodules b/.gitmodules index 0bc74f1..05dfe02 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,12 @@ -[submodule "Catch2"] - path = submodules/Catch2 - url = https://github.com/catchorg/Catch2 -[submodule "msgpack-c"] +[submodule "submodules/pybind11"] + path = submodules/pybind11 + url = https://github.com/pybind/pybind11.git +[submodule "submodules/msgpack-c"] path = submodules/msgpack-c - url = https://github.com/msgpack/msgpack-c -[submodule "mmtf_spec"] + url = https://github.com/msgpack/msgpack-c.git +[submodule "submodules/Catch2"] + path = submodules/Catch2 + url = https://github.com/catchorg/Catch2.git +[submodule "submodules/mmtf_spec"] path = submodules/mmtf_spec - url = https://github.com/rcsb/mmtf + url = https://github.com/rcsb/mmtf.git diff --git a/.travis.yml b/.travis.yml index 144daea..70fd8ea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ +--- language: cpp sudo: false dist: trusty @@ -27,8 +28,25 @@ linux64_cpp17addons: sources: - ubuntu-toolchain-r-test +linux64_cpp17addons: + addons: &linux64cpp17 + apt: + sources: + - ubuntu-toolchain-r-test + + +linux64_cpp17addons_py: + addons: &linux64cpp17py + apt: + sources: + - ubuntu-toolchain-r-test + - gcc-7 + - g++-7 + +python_test_command_sub: &python_test_command TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_python_tests.sh CC=gcc # Set empty values for allow_failures to work env: TEST_COMMAND=$TRAVIS_BUILD_DIR/ci/build_and_run_tests.sh +python: matrix: fast_finish: true @@ -54,7 +72,27 @@ matrix: compiler: gcc addons: *linux64cpp17 dist: bionic - + - os: linux + compiler: gcc + addons: *linux64cpp17py + dist: bionic + env: *python_test_command + python: 3.8 + language: python + - os: linux + compiler: gcc + addons: *linux64cpp17py + dist: bionic + env: *python_test_command + python: 3.11 + language: python + - os: linux + compiler: gcc + addons: *linux64cpp17py + dist: bionic + env: *python_test_command + python: 3.12 + language: python before_install: # Setting environement diff --git a/CMakeLists.txt b/CMakeLists.txt index 7acffed..696715e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,9 +19,7 @@ if (mmtf_build_local) add_library(msgpackc INTERFACE) target_include_directories(msgpackc INTERFACE ${MSGPACKC_INCLUDE_DIR}) if (BUILD_TESTS) - set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/submodules/Catch2/single_include) - add_library(Catch INTERFACE) - target_include_directories(Catch INTERFACE ${CATCH_INCLUDE_DIR}) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/submodules/Catch2) endif() endif() @@ -40,6 +38,34 @@ if (mmtf_build_examples) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples) endif() + +if (build_py) + + if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) + endif() + set(CMAKE_CXX_FLAGS "-Wall -Wextra") + set(CMAKE_CXX_FLAGS_DEBUG "-g") + set(CMAKE_CXX_FLAGS_RELEASE "-O3") + add_library(mmtf_bindings SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/python_src/bindings.cpp + ) + + set(MSGPACKC_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/submodules/msgpack-c/include) + set(PYBIND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/submodules/pybind11/include) + + target_include_directories(mmtf_bindings PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) + target_include_directories(mmtf_bindings PUBLIC ${MSGPACKC_INCLUDE_DIR}) + target_include_directories(mmtf_bindings PUBLIC ${PYBIND_INCLUDE_DIR}) + target_include_directories(mmtf_bindings PUBLIC ${python_include_A}) + target_include_directories(mmtf_bindings PUBLIC ${python_include_B}) + + set_target_properties(mmtf_bindings PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(mmtf_bindings PROPERTIES PREFIX "") + set_target_properties(mmtf_bindings PROPERTIES SUFFIX ".so") + set_target_properties(mmtf_bindings PROPERTIES CXX_STANDARD 17) +endif() + install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION "include" diff --git a/README.md b/README.md index 5679e52..03a4a3b 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,53 @@ Here, `` and `` are the paths to the For your more complicated projects, a `CMakeLists.txt` is included for you. + +### Python bindings + +The C++ MMTF library now can build python bindings using pybind11. To use them +you must have A) a c++11 compatible compiler and B) python >= 3.6 + +to install, it is as simple as `pip install .` + +(in the future possible `pip install mmtf-cpp`) + +```python +from mmtf_cppy import StructureData +import numpy as np +import math + + +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + from https://stackoverflow.com/a/6802723 + """ + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +theta = 1.2 +axis = [0, 0, 1] + +sd = StructureData("my_favorite_structure.mmtf") +sd.atomProperties["pymol_colorList"] = [1 if x % 2 == 0 else 5 for x in sd.xCoordList] +xyz = np.column_stack((sd.xCoordList, sd.yCoordList, sd.zCoordList)) +xyz_rot = rotation_matrix(axis, theta).dot(xyz.T).T +sd.xCoordList, sd.yCoordList, sd.zCoordList = np.hsplit(xyz_rot, 3) +sd.write_to_file("my_favorite_structure_rot.mmtf") + +``` + + + ## Installation You can also perform a system wide installation with `cmake` and `ninja` (or `make`). To do so: diff --git a/ci/build_and_run_python_tests.sh b/ci/build_and_run_python_tests.sh new file mode 100755 index 0000000..e2704c9 --- /dev/null +++ b/ci/build_and_run_python_tests.sh @@ -0,0 +1,10 @@ + +python3 --version +pip3 --version +pip3 install -r requirements.txt +pip3 install -r requirements-dev.txt +cd $TRAVIS_BUILD_DIR +pip3 install . +pytest python_src/tests/tests.py -s -vv + + diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index de8e67c..c56538b 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -1,5 +1,7 @@ +#!/usr/bin/env bash + set -e -cd $TRAVIS_BUILD_DIR +cd "$TRAVIS_BUILD_DIR" mkdir build && cd build $CMAKE_CONFIGURE cmake $CMAKE_ARGS $CMAKE_EXTRA .. make -j2 diff --git a/include/mmtf/structure_data.hpp b/include/mmtf/structure_data.hpp index 0d37e32..c16cfe6 100644 --- a/include/mmtf/structure_data.hpp +++ b/include/mmtf/structure_data.hpp @@ -163,7 +163,7 @@ struct StructureData { std::string title; std::string depositionDate; std::string releaseDate; - std::vector > ncsOperatorList; + std::vector> ncsOperatorList; std::vector bioAssemblyList; std::vector entityList; std::vector experimentalMethods; diff --git a/python_src/bindings.cpp b/python_src/bindings.cpp new file mode 100644 index 0000000..55914d2 --- /dev/null +++ b/python_src/bindings.cpp @@ -0,0 +1,515 @@ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace py = pybind11; + +/// CPP -> PY FUNCTIONS + +/* Notes + * We destory original data because it is much faster to apply move + * than it is to copy the data. + */ + +// This destroys the original data +template< typename T > +py::array +array1d_from_vector(std::vector & m) { + if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(std::move(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array_t( + ptr->size(), // shape of array + ptr->data(), // c-style contiguous strides for Sequence + capsule // numpy array references this parent + ); +} + + +template<> +py::array +array1d_from_vector(std::vector & m) { + //if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(std::move(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array( + py::dtype("size()}, // shape of array + {}, + ptr->data(), // c-style contiguous strides for Sequence + capsule // numpy array references this parent + ); +} + +template< > +py::array +array1d_from_vector(std::vector & m) { + return py::array(py::cast(std::move(m))); +} + +template +std::vector +flatten2D(std::vector> const & v) { + std::size_t total_size = 0; + for (auto const & x : v) + total_size += x.size(); + std::vector result; + result.reserve(total_size); + for (auto const & subv : v) + result.insert(result.end(), subv.begin(), subv.end()); + return result; +} + + +// would be nice if this was faster +template< typename T > +py::array +array2D_from_vector(std::vector> const & m) { + if (m.empty()) return py::array_t(); + std::vector* ptr = new std::vector(flatten2D(m)); + auto capsule = py::capsule(ptr, [](void* p) { + delete reinterpret_cast*>(p); + }); + return py::array_t( + {m.size(), m.at(0).size()}, // shape of array + {m.at(0).size()*sizeof(T), sizeof(T)}, // c-style contiguous strides + ptr->data(), + capsule); +} + +// This destroys the original data +py::list +dump_bio_assembly_list(mmtf::StructureData & sd) { + py::object py_ba_class = py::module::import("mmtf_cppy").attr("BioAssembly"); + py::object py_t_class = py::module::import("mmtf_cppy").attr("Transform"); + py::list bal; + for (mmtf::BioAssembly & cba : sd.bioAssemblyList) { + py::list transform_list; + for (mmtf::Transform & trans : cba.transformList) { + std::vector matrix(std::begin(trans.matrix), std::end(trans.matrix)); + transform_list.append( + py_t_class( + array1d_from_vector(trans.chainIndexList), + array1d_from_vector(matrix) + ) + ); + } + bal.append( + py_ba_class( + transform_list, + py::str(cba.name) + ) + ); + } + return bal; +} + +// This destroys the original data +py::list +dump_entity_list(std::vector & cpp_el) { + py::object entity = py::module::import("mmtf_cppy").attr("Entity"); + py::list el; + for (mmtf::Entity & e : cpp_el) { + el.append( + entity( + array1d_from_vector(e.chainIndexList), + e.description, + e.type, + e.sequence) + ); + } + return el; +} + +py::bytes +raw_properties(mmtf::StructureData const & sd) { + std::stringstream bytes; + std::map< std::string, std::map< std::string, msgpack::object > > objs({ + {"bondProperties", sd.bondProperties}, + {"atomProperties", sd.atomProperties}, + {"groupProperties", sd.groupProperties}, + {"chainProperties", sd.chainProperties}, + {"modelProperties", sd.modelProperties}, + {"extraProperties", sd.extraProperties}}); + msgpack::pack(bytes, objs); + return py::bytes(bytes.str()); +} + + +std::vector +make_transformList(py::list const & l) { + std::vector tl; + for (auto const & trans : l) { + mmtf::Transform t; + t.chainIndexList = trans.attr("chainIndexList").cast>(); + py::list pymatrix(trans.attr("matrix")); + std::size_t count(0); + for (auto const & x : pymatrix) { + t.matrix[count] = x.cast(); + ++count; + } + tl.push_back(t); + } + return tl; +} + + +void +set_bioAssemblyList(py::list const & obj, mmtf::StructureData & sd) { + std::vector bioAs; + for (auto const & py_bioAssembly : obj ) { + mmtf::BioAssembly bioA; + bioA.name = py::str(py_bioAssembly.attr("name")); + py::list py_transform_list(py_bioAssembly.attr("transformList")); + std::vector transform_list = make_transformList(py_transform_list); + bioA.transformList = transform_list; + bioAs.push_back(bioA); + } + sd.bioAssemblyList = bioAs; +} + + +void +set_entityList(py::list const & obj, mmtf::StructureData & sd) { + std::vector entities; + for (auto const & py_entity : obj ) { + mmtf::Entity entity; + entity.chainIndexList = py_entity.attr("chainIndexList").cast>(); + entity.description = py_entity.attr("description").cast(); + entity.type = py_entity.attr("type").cast(); + entity.sequence = py_entity.attr("sequence").cast(); + entities.push_back(entity); + } + sd.entityList = entities; +} + + +void +set_groupList(py::list const & obj, mmtf::StructureData & sd) { + std::vector groups; + for (auto const & py_group : obj ) { + mmtf::GroupType group; + group.formalChargeList = py_group.attr("formalChargeList").cast>(); + group.atomNameList = py_group.attr("atomNameList").cast>(); + group.elementList = py_group.attr("elementList").cast>(); + group.bondAtomList = py_group.attr("bondAtomList").cast>(); + group.bondOrderList = py_group.attr("bondOrderList").cast>(); + group.bondResonanceList = py_group.attr("bondResonanceList").cast>(); + group.groupName = py_group.attr("groupName").cast(); + group.singleLetterCode = py_group.attr("singleLetterCode").cast(); + group.chemCompType = py_group.attr("chemCompType").cast(); + groups.push_back(group); + } + sd.groupList = groups; +} + + +// This destroys the original data +py::list +dump_group_list(std::vector & gtl) { + py::object py_gt_class = py::module::import("mmtf_cppy").attr("GroupType"); + py::list gl; + for (mmtf::GroupType & gt : gtl) { + gl.append( + py_gt_class( + array1d_from_vector(gt.formalChargeList), + gt.atomNameList, + gt.elementList, + array1d_from_vector(gt.bondAtomList), + array1d_from_vector(gt.bondOrderList), + array1d_from_vector(gt.bondResonanceList), + gt.groupName, + std::string(1, gt.singleLetterCode), + gt.chemCompType + ) + ); + } + return gl; +} + +template< typename T> +std::vector +py_array_to_vector(py::array_t const & array_in) { + std::vector vec_array(array_in.size()); + std::memcpy(vec_array.data(), array_in.data(), array_in.size()*sizeof(T)); + return vec_array; +} + +template<> +std::vector +py_array_to_vector(py::array_t const & array_in) { + std::string tmpstr(array_in.data(), array_in.size()); + std::vector vec_array(tmpstr.begin(), tmpstr.end()); + return vec_array; +} + +/* This isn't really necessary, but lets make the interface anyway + */ +py::bytes +py_encodeInt8ToByte(py::array_t const & array_in) { + std::vector cpp_vec(mmtf::encodeInt8ToByte(py_array_to_vector(array_in))); + return py::bytes(std::string(cpp_vec.begin(), cpp_vec.end())); +} + +py::bytes +py_encodeFourByteInt(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeFourByteInt(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthChar(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthChar(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthDeltaInt(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthDeltaInt(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeDeltaRecursiveFloat(py::array_t const & array_in, int32_t const multiplier = 1000) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeDeltaRecursiveFloat(cpp_vec, multiplier)); + return py::bytes(encoded.data(), encoded.size()); +} + +py::bytes +py_encodeRunLengthFloat(py::array_t const & array_in, int32_t const multiplier = 1000) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthFloat(cpp_vec, multiplier)); + return py::bytes(encoded.data(), encoded.size()); +} + + + +py::bytes +py_encodeRunLengthInt8(py::array_t const & array_in) { + std::vector const cpp_vec(py_array_to_vector(array_in)); + std::vector encoded(mmtf::encodeRunLengthInt8(cpp_vec)); + return py::bytes(encoded.data(), encoded.size()); +} + +// TODO pyarray POD types to numpy array? seems hard +//py::bytes +//py_encodeStringVector4(py::array_t> const & array_in) { +// using np_str_t = std::array; +// pybind11::array_t cstring_array(vector.size()); +// const char * data = +// np_str_t* array_of_cstr_ptr = reinterpret_cast(cstring_array.request().ptr); +// +// +///* std::vector tobuild; */ +///* std::vector const cpp_vec(py_array_to_vector(array_in)); */ +///* std::vector encoded(mmtf::encodeStringVector(cpp_vec, max_string_size)); */ +///* return py::bytes(encoded.data(), encoded.size()); */ +//} + + +std::vector +char_vector_to_string_vector(std::vector const & cvec) { + std::vector ret(cvec.size()); + for (std::size_t i=0; i > tmp_target; + tmp_object.convert(tmp_target); + sd.bondProperties = tmp_target["bondProperties"]; + sd.atomProperties = tmp_target["atomProperties"]; + sd.groupProperties = tmp_target["groupProperties"]; + sd.chainProperties = tmp_target["chainProperties"]; + sd.modelProperties = tmp_target["modelProperties"]; + sd.extraProperties = tmp_target["extraProperties"]; +} + + +py::array +binary_decode_int32(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_int16(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_int8(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + + +py::array +binary_decode_char(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + +py::array +binary_decode_float(py::bytes const & bytes) { + using namespace pybind11::literals; + std::string const tmpstr(bytes); + std::vector tmp; + mmtf::BinaryDecoder bd(tmpstr); + bd.decode(tmp); + return array1d_from_vector(tmp); +} + + +PYBIND11_MODULE(mmtf_bindings, m) { + m.def("decode_int32", &binary_decode_int32, "decode array[int32_t]"); + m.def("decode_int16", &binary_decode_int16, "decode array[int16_t]"); + m.def("decode_int8", &binary_decode_int8, "decode array[int8_t]"); + m.def("decode_char", &binary_decode_char, "decode array[char]"); + m.def("decode_float", &binary_decode_float, "decode array[float]"); + // new stuff here + py::class_(m, "CPPStructureData") + .def( pybind11::init( [](){ return new mmtf::StructureData(); } ) ) + .def( pybind11::init( [](mmtf::StructureData const &o){ return new mmtf::StructureData(o); } ) ) + .def_readwrite("mmtfVersion", &mmtf::StructureData::mmtfVersion) + .def_readwrite("mmtfProducer", &mmtf::StructureData::mmtfProducer) + .def("unitCell", [](mmtf::StructureData &m){return array1d_from_vector(m.unitCell);}) + .def_readwrite("unitCell_io", &mmtf::StructureData::unitCell) + .def_readwrite("spaceGroup", &mmtf::StructureData::spaceGroup) + .def_readwrite("structureId", &mmtf::StructureData::structureId) + .def_readwrite("title", &mmtf::StructureData::title) + .def_readwrite("depositionDate", &mmtf::StructureData::depositionDate) + .def_readwrite("releaseDate", &mmtf::StructureData::releaseDate) + .def("ncsOperatorList", [](mmtf::StructureData &m){return array2D_from_vector(m.ncsOperatorList);}) + .def_readwrite("ncsOperatorList_io", &mmtf::StructureData::ncsOperatorList, py::return_value_policy::move) + .def("bioAssemblyList", [](mmtf::StructureData &m){return dump_bio_assembly_list(m);}) + .def_readwrite("bioAssemblyList_io", &mmtf::StructureData::bioAssemblyList) + .def("entityList", [](mmtf::StructureData &m){return dump_entity_list(m.entityList);}) + .def_readwrite("entityList_io", &mmtf::StructureData::entityList) + .def_readwrite("experimentalMethods", &mmtf::StructureData::experimentalMethods) + .def_readwrite("resolution", &mmtf::StructureData::resolution) + .def_readwrite("rFree", &mmtf::StructureData::rFree) + .def_readwrite("rWork", &mmtf::StructureData::rWork) + .def_readwrite("numBonds", &mmtf::StructureData::numBonds) + .def_readwrite("numAtoms", &mmtf::StructureData::numAtoms) + .def_readwrite("numGroups", &mmtf::StructureData::numGroups) + .def_readwrite("numChains", &mmtf::StructureData::numChains) + .def_readwrite("numModels", &mmtf::StructureData::numModels) + .def("groupList", [](mmtf::StructureData &m){return dump_group_list(m.groupList);}) + .def_readwrite("groupList_io", &mmtf::StructureData::groupList) + .def("unitCell", [](mmtf::StructureData &m){return array1d_from_vector(m.unitCell);}) + .def_readwrite("unitCell_io", &mmtf::StructureData::unitCell) + .def("bondAtomList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondAtomList);}) + .def_readwrite("bondAtomList_io", &mmtf::StructureData::bondAtomList) + .def("bondOrderList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondOrderList);}) + .def_readwrite("bondOrderList_io", &mmtf::StructureData::bondOrderList) + .def("bondResonanceList", [](mmtf::StructureData &m){return array1d_from_vector(m.bondResonanceList);}) + .def_readwrite("bondResonanceList_io", &mmtf::StructureData::bondResonanceList) + .def("xCoordList", [](mmtf::StructureData &m){return array1d_from_vector(m.xCoordList);}) + .def_readwrite("xCoordList_io", &mmtf::StructureData::xCoordList) + .def("yCoordList", [](mmtf::StructureData &m){return array1d_from_vector(m.yCoordList);}) + .def_readwrite("yCoordList_io", &mmtf::StructureData::yCoordList) + .def("zCoordList", [](mmtf::StructureData &m){return array1d_from_vector(m.zCoordList);}) + .def_readwrite("zCoordList_io", &mmtf::StructureData::zCoordList) + .def("bFactorList", [](mmtf::StructureData &m){return array1d_from_vector(m.bFactorList);}) + .def_readwrite("bFactorList_io", &mmtf::StructureData::bFactorList) + .def("atomIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.atomIdList);}) + .def_readwrite("atomIdList_io", &mmtf::StructureData::atomIdList) + .def("altLocList", [](mmtf::StructureData &m) { + return array1d_from_vector(m.altLocList); + /* std::vector tmp(char_vector_to_string_vector(m.altLocList)); */ + /* return array1d_from_vector(tmp); */ + }) + .def("set_altLocList", [](mmtf::StructureData &m, py::array_t const & st) { + m.altLocList = std::vector(st.data(), st.data()+st.size()); + }) + .def("occupancyList", [](mmtf::StructureData &m){return array1d_from_vector(m.occupancyList);}) + .def_readwrite("occupancyList_io", &mmtf::StructureData::occupancyList) + .def("groupIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.groupIdList);}) + .def_readwrite("groupIdList_io", &mmtf::StructureData::groupIdList) + .def("groupTypeList", [](mmtf::StructureData &m){return array1d_from_vector(m.groupTypeList);}) + .def_readwrite("groupTypeList_io", &mmtf::StructureData::groupTypeList) + .def("secStructList", [](mmtf::StructureData &m){return array1d_from_vector(m.secStructList);}) + .def_readwrite("secStructList_io", &mmtf::StructureData::secStructList) + .def("insCodeList", [](mmtf::StructureData &m) { + std::vector tmp(char_vector_to_string_vector(m.insCodeList)); + return array1d_from_vector(tmp); + }) + .def("set_insCodeList", [](mmtf::StructureData &m, py::array_t const & st) { + m.insCodeList = std::vector(st.data(), st.data()+st.size()); + }) + .def("sequenceIndexList", [](mmtf::StructureData &m){return array1d_from_vector(m.sequenceIndexList);}) + .def_readwrite("sequenceIndexList_io", &mmtf::StructureData::sequenceIndexList) + .def("chainIdList", [](mmtf::StructureData &m){return array1d_from_vector(m.chainIdList);}) + .def_readwrite("chainIdList_io", &mmtf::StructureData::chainIdList) + .def("chainNameList", [](mmtf::StructureData &m){return array1d_from_vector(m.chainNameList);}) + .def_readwrite("chainNameList_io", &mmtf::StructureData::chainNameList) + .def("groupsPerChain", [](mmtf::StructureData &m){return array1d_from_vector(m.groupsPerChain);}) + .def_readwrite("groupsPerChain_io", &mmtf::StructureData::groupsPerChain) + .def("chainsPerModel", [](mmtf::StructureData &m){return array1d_from_vector(m.chainsPerModel);}) + .def_readwrite("chainsPerModel_io", &mmtf::StructureData::chainsPerModel) + .def("set_properties", [](mmtf::StructureData & sd, py::bytes const & bytes_in){set_properties(sd, bytes_in);}) + .def("raw_properties", [](mmtf::StructureData const &m){return raw_properties(m);}); + + // I think it would be ideal to not pass in the sd, but it is still very + // fast this way. + m.def("set_bioAssemblyList", [](py::list const & i, mmtf::StructureData & sd){return set_bioAssemblyList(i, sd);}); + m.def("set_entityList", [](py::list const & i, mmtf::StructureData & sd){return set_entityList(i, sd);}); + m.def("set_groupList", [](py::list const & i, mmtf::StructureData & sd){return set_groupList(i, sd);}); + m.def("decodeFromFile", &mmtf::decodeFromFile, "decode a mmtf::StructureData from a file"); + m.def("decodeFromBuffer", &mmtf::decodeFromBuffer, "decode a mmtf::StructureData from bytes"); + m.def("encodeToFile", [](mmtf::StructureData const &m, std::string const & fn){mmtf::encodeToFile(m, fn);}); + m.def("encodeToStream", [](mmtf::StructureData const &m){ + std::stringstream ss; + mmtf::encodeToStream(m, ss); + return py::bytes(ss.str()); + }); + // encoders + m.def("encodeInt8ToByte", &py_encodeInt8ToByte); + m.def("encodeFourByteInt", &py_encodeFourByteInt); + m.def("encodeRunLengthChar", &py_encodeRunLengthChar); + m.def("encodeRunLengthDeltaInt", &py_encodeRunLengthDeltaInt); + m.def("encodeDeltaRecursiveFloat", &py_encodeDeltaRecursiveFloat); + m.def("encodeRunLengthFloat", &py_encodeRunLengthFloat); + m.def("encodeRunLengthInt8", &py_encodeRunLengthInt8); + //m.def("encodeStringVector", &py_encodeStringVector); +} diff --git a/python_src/mmtf_cppy/__init__.py b/python_src/mmtf_cppy/__init__.py new file mode 100644 index 0000000..92be8c5 --- /dev/null +++ b/python_src/mmtf_cppy/__init__.py @@ -0,0 +1,31 @@ +from .structure_data import ( + Entity, + GroupType, + Transform, + BioAssembly, + StructureData, +) + + +from .mmtf_bindings import ( + CPPStructureData, + decode_int8, + encodeRunLengthInt8, + decodeFromBuffer, + encodeDeltaRecursiveFloat, + encodeToFile, + decodeFromFile, + encodeFourByteInt, + encodeToStream, + decode_char, + encodeInt8ToByte, + set_bioAssemblyList, + decode_float, + encodeRunLengthChar, + set_entityList, + decode_int16, + encodeRunLengthDeltaInt, + set_groupList, + decode_int32, + encodeRunLengthFloat, +) diff --git a/python_src/mmtf_cppy/structure_data.py b/python_src/mmtf_cppy/structure_data.py new file mode 100644 index 0000000..f8ccb52 --- /dev/null +++ b/python_src/mmtf_cppy/structure_data.py @@ -0,0 +1,440 @@ +import numpy as np +import msgpack +from typing import List, Dict + +from . import mmtf_bindings +from .mmtf_bindings import CPPStructureData as CPPSD, decodeFromFile, decodeFromBuffer, encodeToFile, encodeToStream + + +class Entity: + def __init__(self, chainIndexList, description, type_, sequence): + self.chainIndexList = chainIndexList + self.description = description + self.type = type_ + self.sequence = sequence + + def __repr__(self): + return ( + f"chainIndexList: {self.chainIndexList}" + f"description: {self.description}" + f"type: {self.type}" + f"sequence: {self.sequence}" + ) + + def __eq__(self, other: "Entity"): + return ( + (self.chainIndexList == other.chainIndexList).all() + and self.description == other.description + and self.type == other.type + and self.sequence == other.sequence + ) + + +class GroupType: + def __init__( + self, + formalChargeList: np.ndarray, + atomNameList, + elementList, + bondAtomList, + bondOrderList, + bondResonanceList, + groupName, + singleLetterCode, + chemCompType, + ): + self.formalChargeList = formalChargeList + self.atomNameList = atomNameList + self.elementList = elementList + self.bondAtomList = bondAtomList + self.bondOrderList = bondOrderList + self.bondResonanceList = bondResonanceList + self.groupName = groupName + self.singleLetterCode = singleLetterCode + self.chemCompType = chemCompType + + def __repr__(self): + return ( + f"formalChargeList: {self.formalChargeList}" + f" atomNameList: {self.atomNameList}" + f" elementList: {self.elementList}" + f" bondAtomList: {self.bondAtomList}" + f" bondOrderList: {self.bondOrderList}" + f" bondResonanceList: {self.bondResonanceList}" + f" groupName: {self.groupName}" + f" singleLetterCode: {self.singleLetterCode}" + f" chemCompType: {self.chemCompType}" + ) + + def __eq__(self, other: "GroupType"): + return ( + (self.formalChargeList == other.formalChargeList).all() + and self.atomNameList == other.atomNameList + and self.elementList == other.elementList + and (self.bondAtomList == other.bondAtomList).all() + and (self.bondOrderList == other.bondOrderList).all() + and (self.bondResonanceList == other.bondResonanceList).all() + and self.groupName == other.groupName + and self.singleLetterCode == other.singleLetterCode + and self.chemCompType == other.chemCompType + ) + + +class Transform: + def __init__(self, chainIndexList: List[int], matrix: np.ndarray): + self.chainIndexList = chainIndexList + self.matrix = matrix + + def __eq__(self, other: "Transform"): + return (self.chainIndexList == other.chainIndexList).all() and (self.matrix == other.matrix).all() + + +class BioAssembly: + def __init__(self, transformList: List[Transform], name: str): + self.transformList = transformList + self.name = name + + def __eq__(self, other: "BioAssembly"): + return self.transformList == other.transformList and self.name == other.name + + +def cppSD_from_SD(sd: "StructureData"): + cppsd = CPPSD() + cppsd.mmtfVersion = sd.mmtfVersion + cppsd.mmtfProducer = sd.mmtfProducer + cppsd.unitCell_io = sd.unitCell + cppsd.spaceGroup = sd.spaceGroup + cppsd.structureId = sd.structureId + cppsd.title = sd.title + cppsd.depositionDate = sd.depositionDate + cppsd.releaseDate = sd.releaseDate + cppsd.ncsOperatorList_io = sd.ncsOperatorList + mmtf_bindings.set_bioAssemblyList(sd.bioAssemblyList, cppsd) + mmtf_bindings.set_entityList(sd.entityList, cppsd) + cppsd.experimentalMethods = sd.experimentalMethods + cppsd.resolution = sd.resolution + cppsd.rFree = sd.rFree + cppsd.rWork = sd.rWork + cppsd.numBonds = sd.numBonds + cppsd.numAtoms = sd.numAtoms + cppsd.numGroups = sd.numGroups + cppsd.numChains = sd.numChains + cppsd.numModels = sd.numModels + mmtf_bindings.set_groupList(sd.groupList, cppsd) + cppsd.bondAtomList_io = sd.bondAtomList + cppsd.bondOrderList_io = sd.bondOrderList + cppsd.bondResonanceList_io = sd.bondResonanceList + cppsd.xCoordList_io = sd.xCoordList + cppsd.yCoordList_io = sd.yCoordList + cppsd.zCoordList_io = sd.zCoordList + cppsd.bFactorList_io = sd.bFactorList + cppsd.atomIdList_io = sd.atomIdList + tmp_altLocList = np.array([ord(x) if x else 0 for x in sd.altLocList], dtype=np.int8) + cppsd.set_altLocList(tmp_altLocList) + del tmp_altLocList + cppsd.occupancyList_io = sd.occupancyList + cppsd.groupIdList_io = sd.groupIdList + cppsd.groupTypeList_io = sd.groupTypeList + cppsd.secStructList_io = sd.secStructList + tmp_insCodeList = np.array([ord(x) if x else 0 for x in sd.insCodeList], dtype=np.int8) + cppsd.set_insCodeList(np.int8(tmp_insCodeList)) + del tmp_insCodeList + cppsd.sequenceIndexList_io = sd.sequenceIndexList + cppsd.chainIdList_io = sd.chainIdList + cppsd.chainNameList_io = sd.chainNameList + cppsd.groupsPerChain_io = sd.groupsPerChain + cppsd.chainsPerModel_io = sd.chainsPerModel + packed_data = msgpack.packb( + { + "bondProperties": sd.bondProperties, + "atomProperties": sd.atomProperties, + "groupProperties": sd.groupProperties, + "chainProperties": sd.chainProperties, + "modelProperties": sd.modelProperties, + "extraProperties": sd.extraProperties, + }, + use_bin_type=True, + ) + cppsd.set_properties(packed_data) + return cppsd + + +class StructureData: + def __init__(self, file_name=None, file_bytes=None, debuf_filemames=None): + """ + + Note: + file and bytes are separated because it will be faster + if you just let c++ handle the file (rather than have + python read the bytes itself and pass them to c++) + """ + if file_name: + self.init_from_file_name(file_name) + elif file_bytes: + self.init_from_bytes(file_bytes) + elif debuf_filemames: + self.init_from_file_names_debug(debuf_filemames) + else: + self.raw_init() + + def init_from_bytes(self, file_bytes: bytes): + cppsd = CPPSD() + decodeFromBuffer(cppsd, file_bytes, len(file_bytes)) + self.init_from_cppsd(cppsd) + + def init_from_file_name(self, file_name: str): + cppsd = CPPSD() + decodeFromFile(cppsd, file_name) + self.init_from_cppsd(cppsd) + + def init_from_cppsd(self, cppsd: "CPPStructureData"): + self.mmtfVersion = cppsd.mmtfVersion + self.mmtfProducer = cppsd.mmtfProducer + self.unitCell = cppsd.unitCell() + self.spaceGroup = cppsd.spaceGroup + self.structureId = cppsd.structureId + self.title = cppsd.title + self.depositionDate = cppsd.depositionDate + self.releaseDate = cppsd.releaseDate + self.ncsOperatorList = cppsd.ncsOperatorList() + self.bioAssemblyList = cppsd.bioAssemblyList() + self.entityList = cppsd.entityList() + self.experimentalMethods = cppsd.experimentalMethods + self.resolution = cppsd.resolution + self.rFree = cppsd.rFree + self.rWork = cppsd.rWork + self.numBonds = cppsd.numBonds + self.numAtoms = cppsd.numAtoms + self.numGroups = cppsd.numGroups + self.numChains = cppsd.numChains + self.numModels = cppsd.numModels + self.groupList = cppsd.groupList() + self.bondAtomList = cppsd.bondAtomList() + self.bondOrderList = cppsd.bondOrderList() + self.bondResonanceList = cppsd.bondResonanceList() + self.xCoordList = cppsd.xCoordList() + self.yCoordList = cppsd.yCoordList() + self.zCoordList = cppsd.zCoordList() + self.bFactorList = cppsd.bFactorList() + self.atomIdList = cppsd.atomIdList() + self.altLocList = cppsd.altLocList() # slow + self.occupancyList = cppsd.occupancyList() + self.groupIdList = cppsd.groupIdList() + self.groupTypeList = cppsd.groupTypeList() # slow + self.secStructList = cppsd.secStructList() + self.insCodeList = cppsd.insCodeList() # slow + self.sequenceIndexList = cppsd.sequenceIndexList() + self.chainIdList = cppsd.chainIdList() + self.chainNameList = cppsd.chainNameList() + self.groupsPerChain = cppsd.groupsPerChain() + self.chainsPerModel = cppsd.chainsPerModel() + + raw_properties = cppsd.raw_properties() + raw_properties = msgpack.unpackb(raw_properties, raw=False) + self.bondProperties = raw_properties["bondProperties"] + self.atomProperties = raw_properties["atomProperties"] + self.groupProperties = raw_properties["groupProperties"] + self.chainProperties = raw_properties["chainProperties"] + self.modelProperties = raw_properties["modelProperties"] + self.extraProperties = raw_properties["extraProperties"] + + def raw_init(self): + self.mmtfVersion = None + self.mmtfProducer = None + self.unitCell = None + self.spaceGroup = None + self.structureId = None + self.title = None + self.depositionDate = None + self.releaseDate = None + self.ncsOperatorList = None + self.bioAssemblyList = None + self.entityList = None + self.experimentalMethods = None + self.resolution = None + self.rFree = None + self.rWork = None + self.numBonds = None + self.numAtoms = None + self.numGroups = None + self.numChains = None + self.numModels = None + self.groupList = None + self.bondAtomList = None + self.bondOrderList = None + self.bondResonanceList = None + self.xCoordList = None + self.yCoordList = None + self.zCoordList = None + self.bFactorList = None + self.atomIdList = None + self.altLocList = None + self.occupancyList = None + self.groupIdList = None + self.groupTypeList = None + self.secStructList = None + self.insCodeList = None + self.sequenceIndexList = None + self.chainIdList = None + self.chainNameList = None + self.groupsPerChain = None + self.chainsPerModel = None + self.bondProperties = None + self.atomProperties = None + self.groupProperties = None + self.chainProperties = None + self.modelProperties = None + self.extraProperties = None + + def write_to_file(self, filename: str): + cppsd = cppSD_from_SD(self) + encodeToFile(cppsd, filename) + + def write_to_bytes(self): + cppsd = cppSD_from_SD(self) + buff = encodeToStream(cppsd) + return buff + + def check_equals(self, other: "StructureData"): + if not (self.mmtfVersion == other.mmtfVersion): + print("NOT self.mmtfVersion == other.mmtfVersion") + if not (self.mmtfProducer == other.mmtfProducer): + print("NOT self.mmtfProducer == other.mmtfProducer") + if not ((self.unitCell == other.unitCell).all()): + print("NOT (self.unitCell == other.unitCell).all()") + if not (self.spaceGroup == other.spaceGroup): + print("NOT self.spaceGroup == other.spaceGroup") + if not (self.structureId == other.structureId): + print("NOT self.structureId == other.structureId") + if not (self.title == other.title): + print("NOT self.title == other.title") + if not (self.depositionDate == other.depositionDate): + print("NOT self.depositionDate == other.depositionDate") + if not (self.releaseDate == other.releaseDate): + print("NOT self.releaseDate == other.releaseDate") + if not ((self.ncsOperatorList == other.ncsOperatorList).all()): + print("NOT (self.ncsOperatorList == other.ncsOperatorList).all()") + if not (self.bioAssemblyList == other.bioAssemblyList): + print("NOT self.bioAssemblyList == other.bioAssemblyList") + if not (self.entityList == other.entityList): + print("NOT self.entityList == other.entityList") + if not (self.experimentalMethods == other.experimentalMethods): + print("NOT self.experimentalMethods == other.experimentalMethods") + if not (self.resolution == other.resolution): + print("NOT self.resolution == other.resolution") + if not (self.rFree == other.rFree): + print("NOT self.rFree == other.rFree") + if not (self.rWork == other.rWork): + print("NOT self.rWork == other.rWork") + if not (self.numBonds == other.numBonds): + print("NOT self.numBonds == other.numBonds") + if not (self.numAtoms == other.numAtoms): + print("NOT self.numAtoms == other.numAtoms") + if not (self.numGroups == other.numGroups): + print("NOT self.numGroups == other.numGroups") + if not (self.numChains == other.numChains): + print("NOT self.numChains == other.numChains") + if not (self.numModels == other.numModels): + print("NOT self.numModels == other.numModels") + if not (self.groupList == other.groupList): + print("NOT self.groupList == other.groupList") + if not ((self.bondAtomList == other.bondAtomList).all()): + print("NOT (self.bondAtomList == other.bondAtomList).all()") + if not ((self.bondOrderList == other.bondOrderList).all()): + print("NOT (self.bondOrderList == other.bondOrderList).all()") + if not ((self.bondResonanceList == other.bondResonanceList).all()): + print("NOT (self.bondResonanceList == other.bondResonanceList).all()") + if not ((self.xCoordList == other.xCoordList).all()): + print("NOT (self.xCoordList == other.xCoordList).all()") + if not ((self.yCoordList == other.yCoordList).all()): + print("NOT (self.yCoordList == other.yCoordList).all()") + if not ((self.zCoordList == other.zCoordList).all()): + print("NOT (self.zCoordList == other.zCoordList).all()") + if not ((self.bFactorList == other.bFactorList).all()): + print("NOT (self.bFactorList == other.bFactorList).all()") + if not ((self.atomIdList == other.atomIdList).all()): + print("NOT (self.atomIdList == other.atomIdList).all()") + if not ((self.altLocList == other.altLocList).all()): + print("NOT (self.altLocList == other.altLocList).all()") + if not ((self.occupancyList == other.occupancyList).all()): + print("NOT (self.occupancyList == other.occupancyList).all()") + if not ((self.groupIdList == other.groupIdList).all()): + print("NOT (self.groupIdList == other.groupIdList).all()") + if not ((self.groupTypeList == other.groupTypeList).all()): + print("NOT (self.groupTypeList == other.groupTypeList).all()") + if not ((self.secStructList == other.secStructList).all()): + print("NOT (self.secStructList == other.secStructList).all()") + if not ((self.insCodeList == other.insCodeList).all()): + print("NOT (self.insCodeList == other.insCodeList).all()") + if not ((self.sequenceIndexList == other.sequenceIndexList).all()): + print("NOT (self.sequenceIndexList == other.sequenceIndexList).all()") + if not ((self.chainIdList == other.chainIdList).all()): + print("NOT (self.chainIdList == other.chainIdList).all()") + if not ((self.chainNameList == other.chainNameList).all()): + print("NOT (self.chainNameList == other.chainNameList).all()") + if not ((self.groupsPerChain == other.groupsPerChain).all()): + print("NOT (self.groupsPerChain == other.groupsPerChain).all()") + if not ((self.chainsPerModel == other.chainsPerModel).all()): + print("NOT (self.chainsPerModel == other.chainsPerModel).all()") + if not (self.bondProperties == other.bondProperties): + print("NOT self.bondProperties == other.bondProperties") + if not (self.atomProperties == other.atomProperties): + print("NOT self.atomProperties == other.atomProperties") + if not (self.groupProperties == other.groupProperties): + print("NOT self.groupProperties == other.groupProperties") + if not (self.chainProperties == other.chainProperties): + print("NOT self.chainProperties == other.chainProperties") + if not (self.modelProperties == other.modelProperties): + print("NOT self.modelProperties == other.modelProperties") + if not (self.extraProperties == other.extraProperties): + print("NOT self.extraProperties == other.extraProperties") + + def __eq__(self, other: "StructureData"): + return ( + self.mmtfVersion == other.mmtfVersion + and self.mmtfProducer == other.mmtfProducer + and (self.unitCell == other.unitCell).all() + and self.spaceGroup == other.spaceGroup + and self.structureId == other.structureId + and self.title == other.title + and self.depositionDate == other.depositionDate + and self.releaseDate == other.releaseDate + and (self.ncsOperatorList == other.ncsOperatorList).all() + and self.bioAssemblyList == other.bioAssemblyList + and self.entityList == other.entityList + and self.experimentalMethods == other.experimentalMethods + and self.resolution == other.resolution + and self.rFree == other.rFree + and self.rWork == other.rWork + and self.numBonds == other.numBonds + and self.numAtoms == other.numAtoms + and self.numGroups == other.numGroups + and self.numChains == other.numChains + and self.numModels == other.numModels + and self.groupList == other.groupList + and (self.bondAtomList == other.bondAtomList).all() + and (self.bondOrderList == other.bondOrderList).all() + and (self.bondResonanceList == other.bondResonanceList).all() + and (self.xCoordList == other.xCoordList).all() + and (self.yCoordList == other.yCoordList).all() + and (self.zCoordList == other.zCoordList).all() + and (self.bFactorList == other.bFactorList).all() + and (self.atomIdList == other.atomIdList).all() + and (self.altLocList == other.altLocList).all() + and (self.occupancyList == other.occupancyList).all() + and (self.groupIdList == other.groupIdList).all() + and (self.groupTypeList == other.groupTypeList).all() + and (self.secStructList == other.secStructList).all() + and (self.insCodeList == other.insCodeList).all() + and (self.sequenceIndexList == other.sequenceIndexList).all() + and (self.chainIdList == other.chainIdList).all() + and (self.chainNameList == other.chainNameList).all() + and (self.groupsPerChain == other.groupsPerChain).all() + and (self.chainsPerModel == other.chainsPerModel).all() + and self.bondProperties == other.bondProperties + and self.atomProperties == other.atomProperties + and self.groupProperties == other.groupProperties + and self.chainProperties == other.chainProperties + and self.modelProperties == other.modelProperties + and self.extraProperties == other.extraProperties + ) diff --git a/python_src/tests/conftest.py b/python_src/tests/conftest.py new file mode 100644 index 0000000..bfe6e8f --- /dev/null +++ b/python_src/tests/conftest.py @@ -0,0 +1,14 @@ +import os +from distutils import dir_util +import pytest + + +@pytest.fixture +def test_data_dir(tmpdir, request): + filename = request.module.__file__ + root_dir = os.path.dirname(filename) + input_dir = os.path.join(root_dir, "../../submodules", "mmtf_spec") + dir_util.copy_tree(input_dir, os.path.join(tmpdir.strpath, "mmtf_spec")) + input_dir = os.path.join(root_dir, "../../", "temporary_test_data") + dir_util.copy_tree(input_dir, os.path.join(tmpdir.strpath, "temporary_test_data")) + return tmpdir diff --git a/python_src/tests/tests.py b/python_src/tests/tests.py new file mode 100644 index 0000000..60ec577 --- /dev/null +++ b/python_src/tests/tests.py @@ -0,0 +1,192 @@ +import os +import mmtf_cppy +from mmtf_cppy import CPPStructureData, decodeFromFile, StructureData +import time +import pytest +import numpy as np + + +def test_eq_operator(test_data_dir): + s1 = StructureData(os.path.join(test_data_dir, "mmtf_spec/test-suite/mmtf/173D.mmtf")) + s2 = StructureData(os.path.join(test_data_dir, "mmtf_spec/test-suite/mmtf/173D.mmtf")) + s3 = StructureData(os.path.join(test_data_dir, "mmtf_spec/test-suite/mmtf/1AUY.mmtf")) + assert s1 == s2 + assert s1 != s3 + + +def test_roundtrip(test_data_dir): + files = [ + "mmtf_spec/test-suite/mmtf/173D.mmtf", + "mmtf_spec/test-suite/mmtf/1AA6.mmtf", + "mmtf_spec/test-suite/mmtf/1AUY.mmtf", + "mmtf_spec/test-suite/mmtf/1BNA.mmtf", + "mmtf_spec/test-suite/mmtf/1CAG.mmtf", + "mmtf_spec/test-suite/mmtf/1HTQ.mmtf", + "mmtf_spec/test-suite/mmtf/1IGT.mmtf", + "mmtf_spec/test-suite/mmtf/1L2Q.mmtf", + "mmtf_spec/test-suite/mmtf/1LPV.mmtf", + "mmtf_spec/test-suite/mmtf/1MSH.mmtf", + "mmtf_spec/test-suite/mmtf/1O2F.mmtf", + "mmtf_spec/test-suite/mmtf/1R9V.mmtf", + "mmtf_spec/test-suite/mmtf/1SKM.mmtf", + "mmtf_spec/test-suite/mmtf/3NJW.mmtf", + "mmtf_spec/test-suite/mmtf/3ZYB.mmtf", + "mmtf_spec/test-suite/mmtf/4CK4.mmtf", + "mmtf_spec/test-suite/mmtf/4CUP.mmtf", + "mmtf_spec/test-suite/mmtf/4OPJ.mmtf", + "mmtf_spec/test-suite/mmtf/4P3R.mmtf", + "mmtf_spec/test-suite/mmtf/4V5A.mmtf", + "mmtf_spec/test-suite/mmtf/4Y60.mmtf", + "mmtf_spec/test-suite/mmtf/5EMG.mmtf", + "mmtf_spec/test-suite/mmtf/5ESW.mmtf", + "mmtf_spec/test-suite/mmtf/empty-all0.mmtf", + "mmtf_spec/test-suite/mmtf/empty-numChains1.mmtf", + "mmtf_spec/test-suite/mmtf/empty-numModels1.mmtf", + "temporary_test_data/all_canoncial.mmtf", + "temporary_test_data/1PEF_with_resonance.mmtf", + ] + test_tmp_mmtf_filename = "test_mmtf.mmtf" + for filename in files: + s1 = StructureData(os.path.join(test_data_dir, filename)) + s1.write_to_file(test_tmp_mmtf_filename) + s2 = StructureData(test_tmp_mmtf_filename) + s1.check_equals(s2) + assert s1 == s2 + + +def test_bad_mmtf(test_data_dir): + doesnt_work = ["mmtf_spec/test-suite/mmtf/empty-mmtfVersion99999999.mmtf"] + for filename in doesnt_work: + with pytest.raises(RuntimeError) as einfo: + s1 = StructureData(os.path.join(test_data_dir, filename)) + + +def test_various_throws(test_data_dir): + working_mmtf_fn = os.path.join(test_data_dir, "mmtf_spec/test-suite/mmtf/173D.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.xCoordList = np.append(sd.xCoordList, 0.334) + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.yCoordList = np.append(sd.yCoordList, 0.334) + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.zCoordList = np.append(sd.zCoordList, 0.334) + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.bFactorList = np.append(sd.bFactorList, 0.334) + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.numAtoms = 20 + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.chainIdList = np.append(sd.chainIdList, "xsz") + with pytest.raises(RuntimeError) as einfo: + sd.write_to_file("wrk.mmtf") + + sd = StructureData(working_mmtf_fn) + sd.chainIdList = sd.chainIdList.astype(" c++ vector string +# def test_encodeStringVector(): +# encoded_data = b'\x00\x00\x00\x05\x00\x00\x00\x06\x00\x00\x00\x04B\x00\x00\x00A\x00\x00\x00C\x00\x00\x00A\x00\x00\x00A\x00\x00\x00A\x00\x00\x00' +# decoded_data = np.array(("B", "A", "C", "A", "A", "A")) +# ret = mmtf_cppy.encodeStringVector(decoded_data, 4) +# assert ret == encoded_data + + +def test_atomProperties(test_data_dir): + working_mmtf_fn = os.path.join(test_data_dir, "mmtf_spec/test-suite/mmtf/173D.mmtf") + sd = StructureData(working_mmtf_fn) + random_data = list(range(256)) + encoded_random_data = mmtf_cppy.encodeRunLengthDeltaInt(list((range(256)))) + sd.atomProperties["256_atomColorList"] = random_data + sd.atomProperties["256_atomColorList_encoded"] = encoded_random_data + sd.write_to_file("atomProperties_test.mmtf") + sd2 = StructureData("atomProperties_test.mmtf") + assert sd2.atomProperties["256_atomColorList"] == random_data + assert sd2.atomProperties["256_atomColorList_encoded"] == encoded_random_data + assert (mmtf_cppy.decode_int32(sd2.atomProperties["256_atomColorList_encoded"]) == np.array(random_data)).all() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..9955dec --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +pytest +pytest-cov diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1cf67a2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy +pybind11 +msgpack diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f95993c --- /dev/null +++ b/setup.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import os +import re +import sys +import sysconfig +import platform +import subprocess +import pybind11 + +from distutils.version import LooseVersion +from setuptools import setup, Extension, find_packages +from setuptools.command.build_ext import build_ext +from setuptools.command.test import test as TestCommand +from shutil import copyfile, copymode + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + try: + out = subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError( + "CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions) + ) + + if platform.system() == "Windows": + cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1)) + if cmake_version < "3.5.0": + raise RuntimeError("CMake >= 3.5.0 is required") + + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + cmake_args = [ + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + "-Dbuild_py=ON", + f"-Dpython_include_A={pybind11.get_include()}", + f"-Dpython_include_B={sysconfig.get_path('include')}", + ] + + cfg = "Debug" if self.debug else "Release" + build_args = ["--config", cfg] + + if platform.system() == "Windows": + cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)] + if sys.maxsize > 2 ** 32: + cmake_args += ["-A", "x64"] + build_args += ["--", "/m"] + else: + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] + build_args += ["--", "-j2"] + + env = os.environ.copy() + env["CXXFLAGS"] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get("CXXFLAGS", ""), self.distribution.get_version()) + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + cmake_set_args = ["cmake", ext.sourcedir] + cmake_args + cmake_build_args = ["cmake", "--build", "."] + build_args + subprocess.check_call(cmake_set_args, cwd=self.build_temp, env=env) + subprocess.check_call(cmake_build_args, cwd=self.build_temp) + + +setup( + name="mmtf_bindings", + version="0.1.0", + author="Daniel P. Farrell", + author_email="danpf@uw.edu", + url="https://github.com/rcsb/mmtf-cpp", + description="A hybrid Python/C++ test project", + long_description=open("README.md").read(), + packages=find_packages("python_src", exclude=["tests", "python_src/tests"]), + package_dir={"": "python_src"}, + ext_modules=[CMakeExtension("mmtf_cppy/")], + cmdclass=dict(build_ext=CMakeBuild), + test_suite="tests", + zip_safe=False, + tests_require=["pytest", "pytest-cov"], + install_requires=["numpy", "msgpack", "pybind11",], +) diff --git a/submodules/Catch2 b/submodules/Catch2 index cf4b7ee..6e79e68 160000 --- a/submodules/Catch2 +++ b/submodules/Catch2 @@ -1 +1 @@ -Subproject commit cf4b7eead92773932f32c7efd2612e9d27b07557 +Subproject commit 6e79e682b726f524310d55dec8ddac4e9c52fb5f diff --git a/submodules/mmtf_spec b/submodules/mmtf_spec index 8c88834..e4aaae5 160000 --- a/submodules/mmtf_spec +++ b/submodules/mmtf_spec @@ -1 +1 @@ -Subproject commit 8c8883457e54fb460908a57d801212c56a603aec +Subproject commit e4aaae5d2f273d073e5482db61f74ad93f6e8ab4 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 77fa418..519865e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,10 +7,18 @@ endif() add_executable(mmtf_tests mmtf_tests.cpp) target_compile_features(mmtf_tests PRIVATE cxx_auto_type) -if(WIN32) - target_link_libraries(mmtf_tests Catch msgpackc MMTFcpp ws2_32) +if(EMSCRIPTEN) + if(WIN32) + target_link_libraries(mmtf_tests Catch2::Catch2 msgpackc MMTFcpp ws2_32) + else() + target_link_libraries(mmtf_tests Catch2::Catch2 msgpackc MMTFcpp) + endif() else() - target_link_libraries(mmtf_tests Catch msgpackc MMTFcpp) + if(WIN32) + target_link_libraries(mmtf_tests Catch2::Catch2WithMain msgpackc MMTFcpp ws2_32) + else() + target_link_libraries(mmtf_tests Catch2::Catch2WithMain msgpackc MMTFcpp) + endif() endif() # test for multi-linking diff --git a/tests/mmtf_tests.cpp b/tests/mmtf_tests.cpp index d84a392..993262b 100644 --- a/tests/mmtf_tests.cpp +++ b/tests/mmtf_tests.cpp @@ -1,11 +1,9 @@ #ifdef __EMSCRIPTEN__ #define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS -#define CATCH_CONFIG_RUNNER -#else -#define CATCH_CONFIG_MAIN #endif -#include "catch.hpp" +#include +#include #include #include @@ -17,7 +15,7 @@ template bool approx_equal_vector(const T& a, const T& b, float eps = 0.00001) { if (a.size() != b.size()) return false; for (std::size_t i=0; i < a.size(); ++i) { - if (a[i] != Approx(b[i]).margin(eps)) return false; + if (a[i] != Catch::Approx(b[i]).margin(eps)) return false; } return true; }