diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java index 3a9366d7..1f01e573 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java @@ -130,21 +130,19 @@ static void fit(final Dataset dataset, final int numIterations = parseInt(params.get(NUM_ITERATIONS_PARAMETER_NAME)); logger.debug("LightGBM model trainParams: {}", trainParams); - final SWIGTrainData swigTrainData = new SWIGTrainData( - numFeatures, - instancesPerChunk); - final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster(); - /// Create LightGBM dataset - createTrainDataset(dataset, numFeatures, trainParams, swigTrainData); + try (final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster(); + final SWIGTrainData swigTrainData = new SWIGTrainData(numFeatures, instancesPerChunk)){ + /// Create LightGBM dataset + createTrainDataset(dataset, numFeatures, trainParams, swigTrainData); - /// Create Booster from dataset - createBoosterStructure(swigTrainBooster, swigTrainData, trainParams); - trainBooster(swigTrainBooster.swigBoosterHandle, numIterations); + /// Create Booster from dataset + createBoosterStructure(swigTrainBooster, swigTrainData, trainParams); + trainBooster(swigTrainBooster.swigBoosterHandle, numIterations); - /// Save model - saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath); - swigTrainBooster.close(); // Explicitly release C++ resources right away. They're no longer needed. 
+ /// Save model + saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath); + } } /** @@ -231,28 +229,30 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s /// First generate the array that has the chunk sizes for `LGBM_DatasetCreateFromMats`. final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numFeatures); - - /// Now create the LightGBM Dataset itself from the chunks: - logger.debug("Creating LGBM_Dataset from chunked data..."); - final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats( - (int) swigTrainData.swigFeaturesChunkedArray.get_chunks_count(), // numChunks - swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**) - lightgbmlibConstants.C_API_DTYPE_FLOAT64, - swigChunkSizesArray, - numFeatures, - 1, // rowMajor. - trainParams, // parameters. - null, // No alighment with other datasets. - swigTrainData.swigOutDatasetHandlePtr // Output LGBM Dataset - ); - if (returnCodeLGBM == -1) { - logger.error("Could not create LightGBM dataset."); - throw new LightGBMException(); + try { + /// Now create the LightGBM Dataset itself from the chunks: + logger.debug("Creating LGBM_Dataset from chunked data..."); + final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats( + (int) swigTrainData.swigFeaturesChunkedArray.get_chunks_count(), // numChunks + swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**) + lightgbmlibConstants.C_API_DTYPE_FLOAT64, + swigChunkSizesArray, + numFeatures, + 1, // rowMajor. + trainParams, // parameters. + null, // No alighment with other datasets. + swigTrainData.swigOutDatasetHandlePtr // Output LGBM Dataset + ); + if (returnCodeLGBM == -1) { + logger.error("Could not create LightGBM dataset."); + throw new LightGBMException(); + } + // FIXME is this init necessary? 
+ swigTrainData.initSwigDatasetHandle(); + } finally { + lightgbmlib.delete_intArray(swigChunkSizesArray); } - swigTrainData.initSwigDatasetHandle(); - swigTrainData.destroySwigTrainFeaturesChunkedDataArray(); - lightgbmlib.delete_intArray(swigChunkSizesArray); } /** @@ -313,7 +313,6 @@ private static void setLightGBMDatasetLabelData(final SWIGTrainData swigTrainDat throw new LightGBMException(); } - swigTrainData.destroySwigTrainLabelDataArray(); } /** diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGResources.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGResources.java index caa96def..9bab7e7f 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGResources.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGResources.java @@ -105,6 +105,11 @@ class SWIGResources implements AutoCloseable { */ private Integer boosterNumFeatures; + /** + * Number of classes in the trained LGBM. + */ + private Integer boosterNumClasses = null; + /** * Names of features in the trained LightGBM boosting model. * Whilst not a swig resource, it is automatically retrieved during model loading, @@ -252,19 +257,21 @@ private void initBoosterFastContributionsHandle(final String LightGBMParameters) * Assumes the model was already loaded from file. * Initializes the remaining SWIG resources needed to use the model. * + * The size of {@link #swigOutContributionsPtr} is computed according to + * https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow + * * @throws LightGBMException in case there's an error in the C++ core library.
*/ private void initAuxiliaryModelResources() throws LightGBMException { - - this.boosterNumFeatures = computeBoosterNumFeaturesFromModel(); + computeBoosterNumFeaturesFromModel(); logger.debug("Loaded LightGBM Model has {} features.", this.boosterNumFeatures); this.boosterFeatureNames = computeBoosterFeatureNamesFromModel(); this.swigOutLengthInt64Ptr = lightgbmlibJNI.new_int64_tp(); this.swigInstancePtr = lightgbmlibJNI.new_doubleArray(getBoosterNumFeatures()); - this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(BINARY_LGBM_NUM_CLASSES); - this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumFeatures); + this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumClasses); + this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray((long) this.boosterNumClasses * (this.boosterNumFeatures + 1)); } /** @@ -302,6 +309,10 @@ private void releaseInitializedSWIGResources() throws LightGBMException { lightgbmlibJNI.delete_intp(this.swigOutIntPtr); this.swigOutIntPtr = null; } + // boosterNumFeatures / boosterNumClasses are plain Java Integers read out of swigOutIntPtr, + // not native pointers: reset them, but never hand them to delete_intp. + this.boosterNumFeatures = null; + this.boosterNumClasses = null; if (this.swigOutContributionsPtr != null) { lightgbmlibJNI.delete_doubleArray(this.swigOutContributionsPtr); this.swigOutContributionsPtr = null; @@ -373,17 +384,25 @@ public String[] getBoosterFeatureNames() { - * Computes the number of features in the model and returns it. + * Reads the number of features and the number of classes from the loaded model + * and stores them in {@link #boosterNumFeatures} and {@link #boosterNumClasses}. * * @throws LightGBMException when there is a LightGBM C++ error. - * @returns int with the number of Booster features.
*/ - private Integer computeBoosterNumFeaturesFromModel() throws LightGBMException { - - final int returnCodeLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature( + private void computeBoosterNumFeaturesFromModel() throws LightGBMException { + final int returnCodeNumFeatsLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature( this.swigBoosterHandle, this.swigOutIntPtr); - if (returnCodeLGBM == -1) + if (returnCodeNumFeatsLGBM == -1) throw new LightGBMException(); - return lightgbmlibJNI.intp_value(this.swigOutIntPtr); + + // intp_value returns the pointed-to int by value, so the boxed field owns no native memory and needs no delete_intp. + this.boosterNumFeatures = lightgbmlibJNI.intp_value(this.swigOutIntPtr); + + final int returnCodeNumClassesLGBM = lightgbmlibJNI.LGBM_BoosterGetNumClasses( + this.swigBoosterHandle, + this.swigOutIntPtr); + if (returnCodeNumClassesLGBM == -1) + throw new LightGBMException(); + this.boosterNumClasses = lightgbmlibJNI.intp_value(this.swigOutIntPtr); } /** diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGChunkedArrayAPITest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGChunkedArrayAPITest.java index 4b94b6a2..04a4c68c 100644 --- a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGChunkedArrayAPITest.java +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGChunkedArrayAPITest.java @@ -64,24 +64,27 @@ public void doubleChunkedArrayValuesOK() { final int max_i = 10; final long chunk_size = 3; final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunk_size); + try { + for (int i = 1; i <= max_i; ++i) { + chunkedArray.add(i * 1.1); + } - for (int i = 1; i <= max_i; ++i) { - chunkedArray.add(i * 1.1); - } - - int
chunk = 0; - int pos = 0; - for (int i = 0; i < max_i; ++i) { - final double ref_value = (i+1) * 1.1; - assertThat(chunkedArray.getitem(chunk, pos, -1)) - .as("Value at chunk %d, position %d", chunk, pos) - .isCloseTo(ref_value, Offset.offset(1e-3)); - - ++pos; - if (pos == chunk_size) { - ++chunk; - pos = 0; + int chunk = 0; + int pos = 0; + for (int i = 0; i < max_i; ++i) { + final double ref_value = (i + 1) * 1.1; + assertThat(chunkedArray.getitem(chunk, pos, -1)) + .as("Value at chunk %d, position %d", chunk, pos) + .isCloseTo(ref_value, Offset.offset(1e-3)); + + ++pos; + if (pos == chunk_size) { + ++chunk; + pos = 0; + } } + } finally { + chunkedArray.delete(); } } @@ -94,14 +97,18 @@ public void doubleChunkedArrayOutOfBoundsError() { final double on_fail_sentinel_value = -1; final doubleChunkedArray chunkedArray = new doubleChunkedArray(3); - // Test out of bounds chunk (only 1 exists, not 11): - assertThat(chunkedArray.getitem(10, 0, on_fail_sentinel_value)) - .as("out-of-bounds return sentinel value") - .isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3)); - // Test out of bounds on first chunk: - assertThat(chunkedArray.getitem(0, 10, on_fail_sentinel_value)) - .as("out-of-bounds return sentinel value") - .isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3)); + try { + // Test out of bounds chunk (only 1 exists, not 11): + assertThat(chunkedArray.getitem(10, 0, on_fail_sentinel_value)) + .as("out-of-bounds return sentinel value") + .isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3)); + // Test out of bounds on first chunk: + assertThat(chunkedArray.getitem(0, 10, on_fail_sentinel_value)) + .as("out-of-bounds return sentinel value") + .isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3)); + } finally { + chunkedArray.delete(); + } } /** @@ -112,19 +119,27 @@ public void doubleChunkedArrayOutOfBoundsError() { @Test public void ChunkedArrayCoalesceTo() { final int numFeatures = 3; - final int chunkSize = 2*numFeatures; // Must be multiple + final 
int chunkSize = 2 * numFeatures; // Must be multiple final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunkSize); - // Fill 1 chunk + some part of other - for (int i = 0; i < chunkSize + 1; ++i) { - chunkedArray.add(i); - } - final SWIGTYPE_p_double swigArr = lightgbmlib.new_doubleArray(chunkedArray.get_add_count()); + try { + // Fill 1 chunk + some part of other + for (int i = 0; i < chunkSize + 1; ++i) { + chunkedArray.add(i); + } + final SWIGTYPE_p_double swigArr = lightgbmlib.new_doubleArray(chunkedArray.get_add_count()); + try { - chunkedArray.coalesce_to(swigArr); + chunkedArray.coalesce_to(swigArr); - for (int i = 0; i < chunkedArray.get_add_count(); ++i) { - double v = lightgbmlib.doubleArray_getitem(swigArr, i); - assertThat(v).as("coalescedArray[%d]", i).isCloseTo(i, Offset.offset(1e-3)); + for (int i = 0; i < chunkedArray.get_add_count(); ++i) { + double v = lightgbmlib.doubleArray_getitem(swigArr, i); + assertThat(v).as("coalescedArray[%d]", i).isCloseTo(i, Offset.offset(1e-3)); + } + } finally { + lightgbmlib.delete_doubleArray(swigArr); + } + } finally { + chunkedArray.delete(); } } @@ -135,35 +150,46 @@ public void ChunkedArrayCoalesceTo() { @Test public void LGBM_DatasetCreateFromMatsFromChunkedArray() { final int numFeatures = 3; - final int chunkSize = 2*numFeatures; // Must be multiple + final int chunkSize = 2 * numFeatures; // Must be multiple final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunkSize); - // Fill 1 chunk + some part of other - for (int i = 0; i < chunkSize + 1; ++i) { - chunkedArray.add(i); - } + try { + // Fill 1 chunk + some part of other + for (int i = 0; i < chunkSize + 1; ++i) { + chunkedArray.add(i); + } - final long numChunks = chunkedArray.get_chunks_count(); - SWIGTYPE_p_int chunkSizes = lightgbmlib.new_intArray(numChunks); - for (int i = 0; i < numChunks - 1; ++i) { - lightgbmlib.intArray_setitem(chunkSizes, i, chunkSize); + final long numChunks = chunkedArray.get_chunks_count(); + final 
SWIGTYPE_p_int chunkSizes = lightgbmlib.new_intArray(numChunks); + try { + for (int i = 0; i < numChunks - 1; ++i) { + lightgbmlib.intArray_setitem(chunkSizes, i, chunkSize); + } + lightgbmlib.intArray_setitem(chunkSizes, numChunks - 1, (int) chunkedArray.get_current_chunk_added_count()); + + final SWIGTYPE_p_p_void swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle(); + try { + final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats( + (int) chunkedArray.get_chunks_count(), + chunkedArray.data_as_void(), + lightgbmlibConstants.C_API_DTYPE_FLOAT64, + chunkSizes, + numFeatures, + 1, + "", // parameters + null, + swigOutDatasetHandlePtr + ); + + assertThat(returnCodeLGBM).as("LightGBM return code").isEqualTo(0); + } finally { + lightgbmlib.delete_voidpp(swigOutDatasetHandlePtr); + } + } finally { + lightgbmlib.delete_intArray(chunkSizes); + } + } finally { + chunkedArray.delete(); } - lightgbmlib.intArray_setitem(chunkSizes, numChunks-1, (int)chunkedArray.get_current_chunk_added_count()); - - final SWIGTYPE_p_p_void swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle();; - - final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats( - (int)chunkedArray.get_chunks_count(), - chunkedArray.data_as_void(), - lightgbmlibConstants.C_API_DTYPE_FLOAT64, - chunkSizes, - numFeatures, - 1, - "", // parameters - null, - swigOutDatasetHandlePtr - ); - - assertThat(returnCodeLGBM).as("LightGBM return code").isEqualTo(0); } } diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGResourcesTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGResourcesTest.java index 08fa143e..cc41686b 100644 --- a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGResourcesTest.java +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGResourcesTest.java @@ -18,6 +18,7 @@ package 
com.feedzai.openml.provider.lightgbm; import com.feedzai.openml.provider.exception.ModelLoadingException; +import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -60,6 +61,15 @@ public static void setupFixture() throws ModelLoadingException, URISyntaxExcepti defaultSwig = new SWIGResources(modelPath.toString(), ""); } + /** + * Release/close resources + */ + @AfterClass + public static void afterClass() { + defaultSwig.close(); + } + + /** * Test SWIGResources() - all public members should be initialized. */ @@ -91,7 +101,7 @@ public void constructorThrowsModelLoadingExceptionOnInvalidModelPath() { public void closeResetsAllPublicMembers() throws ModelLoadingException { // Generate a new SWIGResources instance as it will be modified: - SWIGResources swig = new SWIGResources(modelPath.toString(), ""); + final SWIGResources swig = new SWIGResources(modelPath.toString(), ""); swig.close(); assertThat(swig.swigBoosterHandle).as("swigBoosterHandle").isNull(); diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGTrainDataTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGTrainDataTest.java index a0ad9f16..59322623 100644 --- a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGTrainDataTest.java +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/SWIGTrainDataTest.java @@ -1,5 +1,6 @@ package com.feedzai.openml.provider.lightgbm; +import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; @@ -42,6 +43,14 @@ public void setupTest() { swigTrainData = new SWIGTrainData((int) NUM_FEATURES, NUM_INSTANCES_PER_CHUNK); } + /** + * Release/close resources. + */ + @After + public void tearDown() { + swigTrainData.close(); + } + /** * Assert the features ChunkedArray has the proper size. */