diff --git a/engine-plugins/model-plugins/src/main/scala/org/apache/spark/h2o/H2oModelData.scala b/engine-plugins/model-plugins/src/main/scala/org/apache/spark/h2o/H2oModelData.scala index e71f56d42f..0806bc26df 100644 --- a/engine-plugins/model-plugins/src/main/scala/org/apache/spark/h2o/H2oModelData.scala +++ b/engine-plugins/model-plugins/src/main/scala/org/apache/spark/h2o/H2oModelData.scala @@ -51,7 +51,7 @@ case class H2oModelData(modelName: String, pojo: String, labelColumn: String, ob * @return Generated model */ def toGenModel: GenModel = { - H2oModelCache.getGenmodel(this) + H2oModelCache.getCachedGenmodel(this) } /** @@ -59,15 +59,15 @@ case class H2oModelData(modelName: String, pojo: String, labelColumn: String, ob * @return Class files */ def getModelClassFiles: List[File] = { - H2oModelCache.getClassFiles(this) + H2oModelCache.getCachedClassFiles(this) } } object H2oModelCache { - + import H2oModelClassLoader._ //Object with cached class private case class GeneratedClass(genModel: GenModel, genModelDir: File) - private val maxCacheSize = 10 + private val maxCacheSize = 100 /** * A cache of generated classes. @@ -80,6 +80,7 @@ object H2oModelCache { */ private val cache = CacheBuilder.newBuilder() .maximumSize(maxCacheSize) + .weakValues() .removalListener(new RemovalListener[H2oModelData, GeneratedClass] { override def onRemoval(rm: RemovalNotification[H2oModelData, GeneratedClass]): Unit = { deleteDir(rm.getValue.genModelDir) @@ -110,15 +111,18 @@ object H2oModelCache { /** * Get generated H2O model from cache */ - def getGenmodel(data: H2oModelData): GenModel = cache.get(data).genModel + def getCachedGenmodel(data: H2oModelData): GenModel = cache.get(data).genModel + + def getCachedClassFiles(data: H2oModelData): List[File] = getClassFiles(cache.get(data).genModelDir) +} +object H2oModelClassLoader { /** * Get class files for generated H2O model from cache */ - def getClassFiles(data: H2oModelData): List[File] = { + def getClassFiles(pojoDir: File): List[File] = { var files = Array.empty[File] try { - val pojoDir = cache.get(data).genModelDir files = pojoDir.listFiles(new FilenameFilter() { @Override def accept(dir: File, name: String): Boolean = { @@ -150,7 +154,7 @@ object H2oModelCache { * Compile POJO * @return Directory with compiled classes */ - private def compilePojo(data: H2oModelData): File = { + def compilePojo(data: H2oModelData): File = { val tmpDir = Files.createTempDirectory(data.modelName) // write pojo to temporary file diff --git a/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/regression/H2oRandomForestRegressorTrainPlugin.scala b/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/regression/H2oRandomForestRegressorTrainPlugin.scala index 1133bae1f5..46f345b843 100644 --- a/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/regression/H2oRandomForestRegressorTrainPlugin.scala +++ b/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/regression/H2oRandomForestRegressorTrainPlugin.scala @@ -274,7 +274,7 @@ class H2oRandomForestRegressorLocalTrainPlugin extends CommandPlugin[H2oRandomFo val sparkLocalDir = tmpDir + "/" + appName + "/spark-local" val h2oLogDir = tmpDir + "/" + appName + "/h2o-logs" val conf = new SparkConf() - .setMaster("local") + .setMaster("local[*]") .setAppName(this.getClass.getSimpleName + " " + new Date()) conf.set("spark.driver.allowMultipleContexts", "true") conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")