@@ -130,21 +130,19 @@ static void fit(final Dataset dataset,
final int numIterations = parseInt(params.get(NUM_ITERATIONS_PARAMETER_NAME));
logger.debug("LightGBM model trainParams: {}", trainParams);

final SWIGTrainData swigTrainData = new SWIGTrainData(
numFeatures,
instancesPerChunk);
final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster();

/// Create LightGBM dataset
createTrainDataset(dataset, numFeatures, trainParams, swigTrainData);
try (final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster();
final SWIGTrainData swigTrainData = new SWIGTrainData(numFeatures, instancesPerChunk)){
/// Create LightGBM dataset
createTrainDataset(dataset, numFeatures, trainParams, swigTrainData);

/// Create Booster from dataset
createBoosterStructure(swigTrainBooster, swigTrainData, trainParams);
trainBooster(swigTrainBooster.swigBoosterHandle, numIterations);
/// Create Booster from dataset
createBoosterStructure(swigTrainBooster, swigTrainData, trainParams);
trainBooster(swigTrainBooster.swigBoosterHandle, numIterations);

/// Save model
saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath);
swigTrainBooster.close(); // Explicitly release C++ resources right away. They're no longer needed.
/// Save model
saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath);
}
}
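For context on the change above: try-with-resources closes its resources in reverse declaration order even when the body throws, so both SWIG wrappers are released without the explicit swigTrainBooster.close() call the old code needed. A minimal, self-contained illustration of that behavior (toy classes, not part of this PR):

public final class TryWithResourcesDemo {

    /** Toy AutoCloseable that logs when it is opened and released. */
    static final class Resource implements AutoCloseable {
        private final String name;

        Resource(final String name) {
            this.name = name;
            System.out.println("open  " + name);
        }

        @Override
        public void close() {
            System.out.println("close " + name);
        }
    }

    public static void main(final String[] args) {
        try (final Resource booster = new Resource("booster");
             final Resource trainData = new Resource("trainData")) {
            // Simulate a failure mid-"training":
            throw new IllegalStateException("training failed");
        } catch (final IllegalStateException e) {
            // Prints: open booster, open trainData, close trainData, close booster, caught: training failed
            System.out.println("caught: " + e.getMessage());
        }
    }
}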

/**
@@ -231,28 +229,30 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s

/// First generate the array that has the chunk sizes for `LGBM_DatasetCreateFromMats`.
final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numFeatures);

/// Now create the LightGBM Dataset itself from the chunks:
logger.debug("Creating LGBM_Dataset from chunked data...");
final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats(
(int) swigTrainData.swigFeaturesChunkedArray.get_chunks_count(), // numChunks
swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**)
lightgbmlibConstants.C_API_DTYPE_FLOAT64,
swigChunkSizesArray,
numFeatures,
1, // rowMajor.
trainParams, // parameters.
null, // No alignment with other datasets.
swigTrainData.swigOutDatasetHandlePtr // Output LGBM Dataset
);
if (returnCodeLGBM == -1) {
logger.error("Could not create LightGBM dataset.");
throw new LightGBMException();
try {
/// Now create the LightGBM Dataset itself from the chunks:
logger.debug("Creating LGBM_Dataset from chunked data...");
final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats(
(int) swigTrainData.swigFeaturesChunkedArray.get_chunks_count(), // numChunks
swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**)
lightgbmlibConstants.C_API_DTYPE_FLOAT64,
swigChunkSizesArray,
numFeatures,
1, // rowMajor.
trainParams, // parameters.
null, // No alignment with other datasets.
swigTrainData.swigOutDatasetHandlePtr // Output LGBM Dataset
);
if (returnCodeLGBM == -1) {
logger.error("Could not create LightGBM dataset.");
throw new LightGBMException();
}
// FIXME is this init necessary?
Review comment (Contributor):

Yes this is necessary. This is where we go from a void ** to an LGBM DatasetHandle (https://lightgbm.readthedocs.io/en/latest/C-API.html#c.DatasetHandle) so we can pass it around with SWIG.
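A minimal sketch of what that conversion can look like, assuming the SWIG helper lightgbmlib.voidpp_value is available in these bindings and that SWIGTrainData keeps the result in a swigDatasetHandle field (hypothetical names, not verified against this PR):

// Sketch only: dereference the void** output argument of LGBM_DatasetCreateFromMats so the
// resulting DatasetHandle (a void*) can be passed to subsequent C API calls through SWIG.
void initSwigDatasetHandle() {
    this.swigDatasetHandle = lightgbmlib.voidpp_value(this.swigOutDatasetHandlePtr);
}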

swigTrainData.initSwigDatasetHandle();
} finally {
lightgbmlib.delete_intArray(swigChunkSizesArray);
}

swigTrainData.initSwigDatasetHandle();
swigTrainData.destroySwigTrainFeaturesChunkedDataArray();
Review comment:

@AlbertoEAF is there any reason for calling destroySwigTrainFeaturesChunkedDataArray() here? It looks to be an array that lives for the entire lifecycle of a SWIGTrainData instance, and it is released in the close method.

If that is the case, then to simplify and be safer I would remove this destroySwigTrainFeaturesChunkedDataArray() call, because it may not be necessary and can cause confusion.

Review comment (@AlbertoEAF, Contributor, Jun 2, 2022):

Yes, it's wrong to delete this line. It'll unnecessarily duplicate RAM usage on the heaviest part of the job. Maybe add an inline comment ("// Critical to save RAM. Do not remove!") along with a comment above explaining why it's here in the first place.

Transferring train data to a LightGBM Dataset has several stages:

  1. Transfer data from a Pulse (OpenML) Dataset iterator (which doesn't know how large the dataset is) instance by instance to our ChunkedArrays. A ChunkedArray works as a dynamically growing array of buffers, automatically managed on the C++ side.
  2. Create the LGBM Dataset data structure from such ChunkedArrays (buffers). // In this stage we achieve peak RAM usage (2x copies of the dataset => ~2x memory usage)
  3. Delete the buffers as soon as possible (destroySwigTrainFeaturesChunkedDataArray()). Now we only have the train data on a LightGBM Dataset. (we're back to only one copy of the train dataset 👍 )

Finally we can train the LGBM model from the LGBM Dataset, which can still use ample memory during training, but won't ever reach the 2x memory-usage factor. A sketch of this ordering follows below.
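A hedged sketch of stages 2 and 3 back to back, with the inline comment suggested above. In this PR the concrete calls live inside initializeLightGBMTrainDatasetFeatures (via createTrainDataset), so this is illustrative of the ordering only:

// Stage 2: build the LightGBM Dataset from the ChunkedArray buffers (peak RAM: ~2x the dataset size).
createTrainDataset(dataset, numFeatures, trainParams, swigTrainData);

// Stage 3: release the chunked buffers immediately afterwards.
// Critical to save RAM. Do not remove! Keeping the buffers alive would hold a second
// full copy of the train data in memory for the rest of training.
swigTrainData.destroySwigTrainFeaturesChunkedDataArray();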

lightgbmlib.delete_intArray(swigChunkSizesArray);
}

/**
@@ -313,7 +313,6 @@ private static void setLightGBMDatasetLabelData(final SWIGTrainDat
throw new LightGBMException();
}

swigTrainData.destroySwigTrainLabelDataArray();
Review comment:

I think we shouldn't remove this line; instead we should do:

 swigTrainData.initSwigTrainLabelDataArray(); // Init and copy from chunked data.
 try {
    (...)
 } finally {
    swigTrainData.destroySwigTrainLabelDataArray();
 }

The reason is that initSwigTrainLabelDataArray() actually creates a new array and clones the chunks, and if I understood the code well, calling this method twice without calling destroySwigTrainLabelDataArray will leak an array of data.

We can think of an improvement by having a closeable class for this case too.

Review comment (@AlbertoEAF, Contributor, Jun 2, 2022):

Yes this is also the same situation as above @mlobofeedzai and should not be removed for the same reasons. I'd say, put a comment in the same manner as in the other case so there are no doubts as to why it's there.

I agree with @gandola: changing this to a try/finally is more idiomatic and better behaved in case of prior exceptions 👌.
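A sketch of the "closeable class" idea from the earlier comment (hypothetical class, not in this PR; it uses the double array helpers seen elsewhere in this diff and assumes the com.microsoft.ml.lightgbm SWIG package, while the real label buffer may use the float variants):

import com.microsoft.ml.lightgbm.SWIGTYPE_p_double;
import com.microsoft.ml.lightgbm.doubleChunkedArray;
import com.microsoft.ml.lightgbm.lightgbmlib;

/** Owns a native array coalesced from a ChunkedArray and frees it on close(). */
final class CoalescedSwigArray implements AutoCloseable {

    private SWIGTYPE_p_double data;

    CoalescedSwigArray(final doubleChunkedArray chunkedData) {
        // Like initSwigTrainLabelDataArray(): allocate one contiguous native array and clone the chunks into it.
        this.data = lightgbmlib.new_doubleArray(chunkedData.get_add_count());
        chunkedData.coalesce_to(this.data);
    }

    SWIGTYPE_p_double handle() {
        return data;
    }

    @Override
    public void close() {
        if (data != null) {
            lightgbmlib.delete_doubleArray(data); // mirrors destroySwigTrainLabelDataArray()
            data = null;
        }
    }
}

Call sites could then wrap the label upload in try (final CoalescedSwigArray labels = new CoalescedSwigArray(...)) { ... } and get the cleanup of a finally block for free.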

}

/**
@@ -105,6 +105,11 @@ class SWIGResources implements AutoCloseable {
*/
private Integer boosterNumFeatures;

/**
* Number of classes in the trained LGBM.
*/
private Integer boosterNumClasses = null;

/**
* Names of features in the trained LightGBM boosting model.
* Whilst not a swig resource, it is automatically retrieved during model loading,
@@ -252,19 +257,21 @@ private void initBoosterFastContributionsHandle(final String LightGBMParameters)
* Assumes the model was already loaded from file.
* Initializes the remaining SWIG resources needed to use the model.
*
* The size of {@link #swigOutContributionsPtr} is computed according to
* https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow
*
* @throws LightGBMException in case there's an error in the C++ core library.
*/
private void initAuxiliaryModelResources() throws LightGBMException {

this.boosterNumFeatures = computeBoosterNumFeaturesFromModel();
computeBoosterNumFeaturesFromModel();
logger.debug("Loaded LightGBM Model has {} features.", this.boosterNumFeatures);

this.boosterFeatureNames = computeBoosterFeatureNamesFromModel();

this.swigOutLengthInt64Ptr = lightgbmlibJNI.new_int64_tp();
this.swigInstancePtr = lightgbmlibJNI.new_doubleArray(getBoosterNumFeatures());
this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(BINARY_LGBM_NUM_CLASSES);
this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumFeatures);
this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumClasses);
this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray((long) this.boosterNumClasses * (this.boosterNumFeatures + 1));
}
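For reference, the contributions output of LGBM_BoosterPredictForMatSingleRow holds (numFeatures + 1) doubles per class, the last value of each class block being that class's base (expected) value, which is why the buffer above is sized numClasses * (numFeatures + 1). A sketch of reading that buffer back with the SWIG array helpers (hypothetical helper class, assuming the com.microsoft.ml.lightgbm package and the SWIGTYPE_p_double variant of the buffer):

import com.microsoft.ml.lightgbm.SWIGTYPE_p_double;
import com.microsoft.ml.lightgbm.lightgbmlib;

final class ContributionsReader {

    /**
     * Copies the per-class SHAP contributions for one instance out of the native buffer.
     * Index layout: class block cls starts at cls * (numFeatures + 1); the last entry of
     * each block is that class's base value.
     */
    static double[][] read(final SWIGTYPE_p_double contributionsBuffer,
                           final int numClasses,
                           final int numFeatures) {
        final int valuesPerClass = numFeatures + 1;
        final double[][] perClass = new double[numClasses][valuesPerClass];
        for (int cls = 0; cls < numClasses; ++cls) {
            for (int i = 0; i < valuesPerClass; ++i) {
                perClass[cls][i] = lightgbmlib.doubleArray_getitem(contributionsBuffer, cls * valuesPerClass + i);
            }
        }
        return perClass;
    }
}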

/**
@@ -302,6 +309,14 @@ private void releaseInitializedSWIGResources() throws LightGBMException {
lightgbmlibJNI.delete_intp(this.swigOutIntPtr);
this.swigOutIntPtr = null;
}
if (this.boosterNumFeatures != null) {
lightgbmlibJNI.delete_intp(this.boosterNumFeatures);
this.boosterNumFeatures = null;
}
if (this.boosterNumClasses != null) {
lightgbmlibJNI.delete_intp(this.boosterNumClasses);
this.boosterNumClasses = null;
}
if (this.swigOutContributionsPtr != null) {
lightgbmlibJNI.delete_doubleArray(this.swigOutContributionsPtr);
this.swigOutContributionsPtr = null;
@@ -373,17 +388,31 @@ public String[] getBoosterFeatureNames() {
* Computes the number of features in the model and returns it.
*
* @throws LightGBMException when there is a LightGBM C++ error.
* @returns int with the number of Booster features.
*/
private Integer computeBoosterNumFeaturesFromModel() throws LightGBMException {

final int returnCodeLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature(
private void computeBoosterNumFeaturesFromModel() throws LightGBMException {
final int returnCodeNumFeatsLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature(
this.swigBoosterHandle,
this.swigOutIntPtr);
if (returnCodeLGBM == -1)
if (returnCodeNumFeatsLGBM == -1)
throw new LightGBMException();

return lightgbmlibJNI.intp_value(this.swigOutIntPtr);

if (this.boosterNumFeatures != null) {
lightgbmlibJNI.delete_intp(this.boosterNumFeatures);
this.boosterNumFeatures = null;
}
this.boosterNumFeatures = lightgbmlibJNI.intp_value(this.swigOutIntPtr);

final int returnCodeNumClassesLGBM = lightgbmlibJNI.LGBM_BoosterGetNumClasses(
this.swigBoosterHandle,
this.swigOutIntPtr);
if (returnCodeNumClassesLGBM == -1)
throw new LightGBMException();
if (this.boosterNumClasses != null) {
lightgbmlibJNI.delete_intp(this.boosterNumClasses);
this.boosterNumClasses = null;
}
this.boosterNumClasses = lightgbmlibJNI.intp_value(this.swigOutIntPtr);
}

/**
@@ -64,24 +64,27 @@ public void doubleChunkedArrayValuesOK() {
final int max_i = 10;
final long chunk_size = 3;
final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunk_size);
try {
for (int i = 1; i <= max_i; ++i) {
chunkedArray.add(i * 1.1);
}

for (int i = 1; i <= max_i; ++i) {
chunkedArray.add(i * 1.1);
}

int chunk = 0;
int pos = 0;
for (int i = 0; i < max_i; ++i) {
final double ref_value = (i+1) * 1.1;
assertThat(chunkedArray.getitem(chunk, pos, -1))
.as("Value at chunk %d, position %d", chunk, pos)
.isCloseTo(ref_value, Offset.offset(1e-3));

++pos;
if (pos == chunk_size) {
++chunk;
pos = 0;
int chunk = 0;
int pos = 0;
for (int i = 0; i < max_i; ++i) {
final double ref_value = (i + 1) * 1.1;
assertThat(chunkedArray.getitem(chunk, pos, -1))
.as("Value at chunk %d, position %d", chunk, pos)
.isCloseTo(ref_value, Offset.offset(1e-3));

++pos;
if (pos == chunk_size) {
++chunk;
pos = 0;
}
}
} finally {
chunkedArray.delete();
}
}

@@ -94,14 +97,18 @@ public void doubleChunkedArrayOutOfBoundsError() {
final double on_fail_sentinel_value = -1;
final doubleChunkedArray chunkedArray = new doubleChunkedArray(3);

// Test out of bounds chunk (only 1 exists, not 11):
assertThat(chunkedArray.getitem(10, 0, on_fail_sentinel_value))
.as("out-of-bounds return sentinel value")
.isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3));
// Test out of bounds on first chunk:
assertThat(chunkedArray.getitem(0, 10, on_fail_sentinel_value))
.as("out-of-bounds return sentinel value")
.isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3));
try {
// Test out of bounds chunk (only 1 exists, not 11):
assertThat(chunkedArray.getitem(10, 0, on_fail_sentinel_value))
.as("out-of-bounds return sentinel value")
.isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3));
// Test out of bounds on first chunk:
assertThat(chunkedArray.getitem(0, 10, on_fail_sentinel_value))
.as("out-of-bounds return sentinel value")
.isCloseTo(on_fail_sentinel_value, Offset.offset(1e-3));
} finally {
chunkedArray.delete();
}
}

/**
@@ -112,19 +119,27 @@ public void doubleChunkedArrayOutOfBoundsError() {
@Test
public void ChunkedArrayCoalesceTo() {
final int numFeatures = 3;
final int chunkSize = 2*numFeatures; // Must be multiple
final int chunkSize = 2 * numFeatures; // Must be multiple
final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunkSize);
// Fill 1 chunk + some part of other
for (int i = 0; i < chunkSize + 1; ++i) {
chunkedArray.add(i);
}
final SWIGTYPE_p_double swigArr = lightgbmlib.new_doubleArray(chunkedArray.get_add_count());
try {
// Fill 1 chunk + some part of other
for (int i = 0; i < chunkSize + 1; ++i) {
chunkedArray.add(i);
}
final SWIGTYPE_p_double swigArr = lightgbmlib.new_doubleArray(chunkedArray.get_add_count());
try {

chunkedArray.coalesce_to(swigArr);
chunkedArray.coalesce_to(swigArr);

for (int i = 0; i < chunkedArray.get_add_count(); ++i) {
double v = lightgbmlib.doubleArray_getitem(swigArr, i);
assertThat(v).as("coalescedArray[%d]", i).isCloseTo(i, Offset.offset(1e-3));
for (int i = 0; i < chunkedArray.get_add_count(); ++i) {
double v = lightgbmlib.doubleArray_getitem(swigArr, i);
assertThat(v).as("coalescedArray[%d]", i).isCloseTo(i, Offset.offset(1e-3));
}
} finally {
lightgbmlib.delete_doubleArray(swigArr);
}
} finally {
chunkedArray.delete();
}
}

@@ -135,35 +150,46 @@ public void ChunkedArrayCoalesceTo() {
@Test
public void LGBM_DatasetCreateFromMatsFromChunkedArray() {
final int numFeatures = 3;
final int chunkSize = 2*numFeatures; // Must be multiple
final int chunkSize = 2 * numFeatures; // Must be multiple
final doubleChunkedArray chunkedArray = new doubleChunkedArray(chunkSize);
// Fill 1 chunk + some part of other
for (int i = 0; i < chunkSize + 1; ++i) {
chunkedArray.add(i);
}
try {
// Fill 1 chunk + some part of other
for (int i = 0; i < chunkSize + 1; ++i) {
chunkedArray.add(i);
}

final long numChunks = chunkedArray.get_chunks_count();
SWIGTYPE_p_int chunkSizes = lightgbmlib.new_intArray(numChunks);
for (int i = 0; i < numChunks - 1; ++i) {
lightgbmlib.intArray_setitem(chunkSizes, i, chunkSize);
final long numChunks = chunkedArray.get_chunks_count();
final SWIGTYPE_p_int chunkSizes = lightgbmlib.new_intArray(numChunks);
try {
for (int i = 0; i < numChunks - 1; ++i) {
lightgbmlib.intArray_setitem(chunkSizes, i, chunkSize);
}
lightgbmlib.intArray_setitem(chunkSizes, numChunks - 1, (int) chunkedArray.get_current_chunk_added_count());

final SWIGTYPE_p_p_void swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle();
try {
final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats(
(int) chunkedArray.get_chunks_count(),
chunkedArray.data_as_void(),
lightgbmlibConstants.C_API_DTYPE_FLOAT64,
chunkSizes,
numFeatures,
1,
"", // parameters
null,
swigOutDatasetHandlePtr
);

assertThat(returnCodeLGBM).as("LightGBM return code").isEqualTo(0);
} finally {
lightgbmlib.delete_voidpp(swigOutDatasetHandlePtr);
}
} finally {
lightgbmlib.delete_intArray(chunkSizes);
}
} finally {
chunkedArray.delete();
}
lightgbmlib.intArray_setitem(chunkSizes, numChunks-1, (int)chunkedArray.get_current_chunk_added_count());

final SWIGTYPE_p_p_void swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle();;

final int returnCodeLGBM = lightgbmlib.LGBM_DatasetCreateFromMats(
(int)chunkedArray.get_chunks_count(),
chunkedArray.data_as_void(),
lightgbmlibConstants.C_API_DTYPE_FLOAT64,
chunkSizes,
numFeatures,
1,
"", // parameters
null,
swigOutDatasetHandlePtr
);

assertThat(returnCodeLGBM).as("LightGBM return code").isEqualTo(0);
}

}
@@ -18,6 +18,7 @@
package com.feedzai.openml.provider.lightgbm;

import com.feedzai.openml.provider.exception.ModelLoadingException;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

@@ -60,6 +61,15 @@ public static void setupFixture() throws ModelLoadingException, URISyntaxExcepti
defaultSwig = new SWIGResources(modelPath.toString(), "");
}

/**
* Releases the SWIG resources after all tests have run.
*/
@AfterClass
public static void afterClass() {
defaultSwig.close();
}


/**
* Test SWIGResources() - all public members should be initialized.
*/
@@ -91,7 +101,7 @@ public void constructorThrowsModelLoadingExceptionOnInvalidModelPath() {
public void closeResetsAllPublicMembers() throws ModelLoadingException {

// Generate a new SWIGResources instance as it will be modified:
SWIGResources swig = new SWIGResources(modelPath.toString(), "");
final SWIGResources swig = new SWIGResources(modelPath.toString(), "");
swig.close();

assertThat(swig.swigBoosterHandle).as("swigBoosterHandle").isNull();