Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,23 @@ case class H2oModelData(modelName: String, pojo: String, labelColumn: String, ob
* @return Generated model
*/
def toGenModel: GenModel = {
H2oModelCache.getGenmodel(this)
H2oModelCache.getCachedGenmodel(this)
}

/**
* Get list of class files for generated model
* @return Class files
*/
def getModelClassFiles: List[File] = {
H2oModelCache.getClassFiles(this)
H2oModelCache.getCachedClassFiles(this)
}
}

object H2oModelCache {

import H2oModelClassLoader._
//Object with cached class
private case class GeneratedClass(genModel: GenModel, genModelDir: File)
private val maxCacheSize = 10
private val maxCacheSize = 100

/**
* A cache of generated classes.
Expand All @@ -80,6 +80,7 @@ object H2oModelCache {
*/
private val cache = CacheBuilder.newBuilder()
.maximumSize(maxCacheSize)
.weakValues()
.removalListener(new RemovalListener[H2oModelData, GeneratedClass] {
override def onRemoval(rm: RemovalNotification[H2oModelData, GeneratedClass]): Unit = {
deleteDir(rm.getValue.genModelDir)
Expand Down Expand Up @@ -110,15 +111,18 @@ object H2oModelCache {
/**
* Get generated H2O model from cache
*/
def getGenmodel(data: H2oModelData): GenModel = cache.get(data).genModel
def getCachedGenmodel(data: H2oModelData): GenModel = cache.get(data).genModel

def getCachedClassFiles(data: H2oModelData): List[File] = getClassFiles(cache.get(data).genModelDir)
}

object H2oModelClassLoader {
/**
* Get class files for generated H2O model from cache
*/
def getClassFiles(data: H2oModelData): List[File] = {
def getClassFiles(pojoDir: File): List[File] = {
var files = Array.empty[File]
try {
val pojoDir = cache.get(data).genModelDir
files = pojoDir.listFiles(new FilenameFilter() {
@Override
def accept(dir: File, name: String): Boolean = {
Expand Down Expand Up @@ -150,7 +154,7 @@ object H2oModelCache {
* Compile POJO
* @return Directory with compiled classes
*/
private def compilePojo(data: H2oModelData): File = {
def compilePojo(data: H2oModelData): File = {
val tmpDir = Files.createTempDirectory(data.modelName)

// write pojo to temporary file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class H2oRandomForestRegressorLocalTrainPlugin extends CommandPlugin[H2oRandomFo
val sparkLocalDir = tmpDir + "/" + appName + "/spark-local"
val h2oLogDir = tmpDir + "/" + appName + "/h2o-logs"
val conf = new SparkConf()
.setMaster("local")
.setMaster("local[*]")
.setAppName(this.getClass.getSimpleName + " " + new Date())
conf.set("spark.driver.allowMultipleContexts", "true")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
Expand Down