Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ class NaiveBayesPredictPlugin extends SparkCommandPlugin[NaiveBayesPredictArgs,
override def name: String = "model:naive_bayes/predict"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: NaiveBayesPredictArgs)(implicit invocation: Invocation) = 9

/**
* Get the predictions for observations in a test frame
*
Expand All @@ -84,6 +76,7 @@ class NaiveBayesPredictPlugin extends SparkCommandPlugin[NaiveBayesPredictArgs,
val model: Model = arguments.model

// Loading model
require(!frame.rdd.isEmpty(), "Predict Frame is empty. Please predict on a non-empty Frame.")
val naiveBayesJsObject = model.dataOption.getOrElse(
throw new RuntimeException("This model has not been trained yet. Please train before trying to predict")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ class NaiveBayesPublishPlugin extends CommandPlugin[ModelPublishArgs, ExportMeta
override def name: String = "model:naive_bayes/publish"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ModelPublishArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,6 @@ class NaiveBayesTestPlugin extends SparkCommandPlugin[NaiveBayesTestArgs, Classi
override def name: String = "model:naive_bayes/test"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: NaiveBayesTestArgs)(implicit invocation: Invocation) = 9
/**
* Get the predictions for observations in a test frame
*
Expand All @@ -95,6 +88,7 @@ class NaiveBayesTestPlugin extends SparkCommandPlugin[NaiveBayesTestArgs, Classi
val model: Model = arguments.model
val frame: SparkFrame = arguments.frame

require(!frame.rdd.isEmpty(), "Test Frame is empty. Please test on a non-empty Frame.")
//Extracting the model and data to run on
val naiveBayesData = model.data.convertTo[NaiveBayesData]
val naiveBayesModel = naiveBayesData.naiveBayesModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ class NaiveBayesTrainPlugin extends SparkCommandPlugin[NaiveBayesTrainArgs, Unit

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/
override def numberOfJobs(arguments: NaiveBayesTrainArgs)(implicit invocation: Invocation) = 109

/**
* Run MLLib's NaiveBayes() on the training frame and create a Model for it.
*
Expand All @@ -84,6 +78,7 @@ class NaiveBayesTrainPlugin extends SparkCommandPlugin[NaiveBayesTrainArgs, Unit
val model: Model = arguments.model

//create RDD from the frame
require(!frame.rdd.isEmpty(), "Train Frame is empty. Please train on a non-empty Frame.")
val labeledTrainRdd: RDD[LabeledPoint] = frame.rdd.toLabeledPointRDD(arguments.labelColumn, arguments.observationColumns)

//Running MLLib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ class RandomForestClassifierPredictPlugin extends SparkCommandPlugin[RandomFores

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: RandomForestClassifierPredictArgs)(implicit invocation: Invocation) = 9

/**
* Get the predictions for observations in a test frame
*
Expand All @@ -82,6 +75,7 @@ class RandomForestClassifierPredictPlugin extends SparkCommandPlugin[RandomFores
val frame: SparkFrame = arguments.frame

//Running MLLib
require(!frame.rdd.isEmpty(), "Predict Frame is empty. Please predict on a non-empty Frame.")
val rfData = model.readFromStorage().convertTo[RandomForestClassifierData]
val rfModel = rfData.randomForestModel
if (arguments.observationColumns.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ class RandomForestClassifierPublishPlugin extends CommandPlugin[ModelPublishArgs
override def name: String = "model:random_forest_classifier/publish"

override def apiMaturityTag = Some(ApiMaturityTag.Beta)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ModelPublishArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ class RandomForestClassifierTestPlugin extends SparkCommandPlugin[RandomForestCl

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: RandomForestClassifierTestArgs)(implicit invocation: Invocation) = 9
/**
* Get the predictions for observations in a test frame
*
Expand All @@ -97,6 +91,7 @@ class RandomForestClassifierTestPlugin extends SparkCommandPlugin[RandomForestCl
val frame: SparkFrame = arguments.frame

//Extracting the model and data to run on
require(!frame.rdd.isEmpty(), "Test Frame is empty. Please test on a non-empty Frame.")
val rfData = model.readFromStorage().convertTo[RandomForestClassifierData]
val rfModel = rfData.randomForestModel
if (arguments.observationColumns.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ class RandomForestClassifierTrainPlugin extends SparkCommandPlugin[RandomForestC

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/
override def numberOfJobs(arguments: RandomForestClassifierTrainArgs)(implicit invocation: Invocation) = 109

/**
* Run MLLib's RandomForest classifier on the training frame and create a Model for it.
*
Expand All @@ -83,6 +77,7 @@ class RandomForestClassifierTrainPlugin extends SparkCommandPlugin[RandomForestC
val model: Model = arguments.model

//create RDD from the frame
require(!frame.rdd.isEmpty(), "Train Frame is empty. Please train on a non-empty Frame.")
val labeledTrainRdd: RDD[LabeledPoint] = frame.rdd.toLabeledPointRDD(arguments.labelColumn, arguments.observationColumns)
val randomForestModel = RandomForest.trainClassifier(labeledTrainRdd, arguments.numClasses, arguments.getCategoricalFeaturesInfo, arguments.numTrees,
arguments.getFeatureSubsetCategory, arguments.impurity, arguments.maxDepth, arguments.maxBins, arguments.seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ class SVMWithSGDPredictPlugin extends SparkCommandPlugin[ClassificationWithSGDPr

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ClassificationWithSGDPredictArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,6 @@ class SVMWithSGDPublishPlugin extends CommandPlugin[ModelPublishArgs, ExportMeta
override def name: String = "model:svm/publish"

override def apiMaturityTag = Some(ApiMaturityTag.Beta)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ModelPublishArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,6 @@ class SVMWithSGDTestPlugin extends SparkCommandPlugin[ClassificationWithSGDTestA
override def name: String = "model:svm/test"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ClassificationWithSGDTestArgs)(implicit invocation: Invocation) = 1
/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,6 @@ class SVMWithSGDTrainPlugin extends SparkCommandPlugin[ClassificationWithSGDTrai
override def name: String = "model:svm/train"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/
override def numberOfJobs(arguments: ClassificationWithSGDTrainArgs)(implicit invocation: Invocation) = 1

/**
* Run MLLib's SVMWithSGD() on the training frame and create a Model for it.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ class LogisticRegressionPredictPlugin extends SparkCommandPlugin[ClassificationW

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ClassificationWithSGDPredictArgs)(implicit invocation: Invocation) = 9

/**
* Get the predictions for observations in a test frame
*
Expand All @@ -66,6 +59,7 @@ class LogisticRegressionPredictPlugin extends SparkCommandPlugin[ClassificationW
val frame: SparkFrame = arguments.frame
val model: Model = arguments.model

require(!frame.rdd.isEmpty(), "Predict Frame is empty. Please predict on a non-empty Frame.")
//Running MLLib
val logRegData = model.data.convertTo[LogisticRegressionData]
val logRegModel = logRegData.logRegModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ class LogisticRegressionTestPlugin extends SparkCommandPlugin[ClassificationWith
override def name: String = "model:logistic_regression/test"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)
/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: ClassificationWithSGDTestArgs)(implicit invocation: Invocation) = 9
/**
* Get the predictions for observations in a test frame
*
Expand All @@ -77,6 +71,7 @@ class LogisticRegressionTestPlugin extends SparkCommandPlugin[ClassificationWith
val frame: SparkFrame = arguments.frame
val model: Model = arguments.model

require(!frame.rdd.isEmpty(), "Test Frame is empty. Please test on a non-empty Frame.")
val logRegData = model.data.convertTo[LogisticRegressionData]
val logRegModel = logRegData.logRegModel
if (arguments.observationColumns.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ class LogisticRegressionTrainPlugin extends SparkCommandPlugin[LogisticRegressio
override def name: String = "model:logistic_regression/train"

override def apiMaturityTag = Some(ApiMaturityTag.Alpha)

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/
override def numberOfJobs(arguments: LogisticRegressionTrainArgs)(implicit invocation: Invocation) = arguments.numIterations + 5
/**
* Run MLLib's LogisticRegressionWithSGD() on the training frame and create a Model for it.
*
Expand All @@ -77,6 +71,7 @@ class LogisticRegressionTrainPlugin extends SparkCommandPlugin[LogisticRegressio
val model: Model = arguments.model

//create RDD from the frame
require(!frame.rdd.isEmpty(), "Train Frame is empty. Please train on a non-empty Frame.")
val labeledTrainRdd = frame.rdd.toLabeledPointRDDWithFrequency(arguments.labelColumn,
arguments.observationColumns, arguments.frequencyColumn)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import MLLibJsonProtocol._
predicted_cluster : int
Integer containing the cluster assignment.""")
class GMMPredictPlugin extends SparkCommandPlugin[GMMPredictArgs, FrameReference] {

/**
* The name of the command.
*
Expand All @@ -55,19 +54,6 @@ class GMMPredictPlugin extends SparkCommandPlugin[GMMPredictArgs, FrameReference

override def apiMaturityTag = Some(ApiMaturityTag.Beta)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: GMMPredictArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,6 @@ class GMMTrainPlugin extends SparkCommandPlugin[GMMTrainArgs, GMMTrainReturn] {
override def name: String = "model:gmm/train"

override def apiMaturityTag = Some(ApiMaturityTag.Beta)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
*
* (this configuration is used to prevent multiple progress bars in Python client)
*/
override def numberOfJobs(arguments: GMMTrainArgs)(implicit invocation: Invocation) = 1

/**
* Run MLLib's GaussianMixtureModel() on the training frame and create a Model for it.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import scala.collection.mutable.ListBuffer
'k' columns : Each of the 'k' columns containing squared distance of that observation to the 'k'th cluster center
predicted_cluster column: The cluster assignment for the observation""")
class KMeansPredictPlugin extends SparkCommandPlugin[KMeansPredictArgs, FrameReference] {

/**
* The name of the command.
*
Expand All @@ -55,20 +54,6 @@ class KMeansPredictPlugin extends SparkCommandPlugin[KMeansPredictArgs, FrameRef
override def name: String = "model:k_means/predict"

override def apiMaturityTag = Some(ApiMaturityTag.Beta)

/**
* User documentation exposed in Python.
*
* [[http://docutils.sourceforge.net/rst.html ReStructuredText]]
*/

/**
* Number of Spark jobs that get created by running this command
* (this configuration is used to prevent multiple progress bars in Python client)
*/

override def numberOfJobs(arguments: KMeansPredictArgs)(implicit invocation: Invocation) = 1

/**
* Get the predictions for observations in a test frame
*
Expand All @@ -82,6 +67,7 @@ class KMeansPredictPlugin extends SparkCommandPlugin[KMeansPredictArgs, FrameRef
val frame: SparkFrame = arguments.frame
val model: Model = arguments.model

require(!frame.rdd.isEmpty(), "Predict Frame is empty. Please predict on a non-empty Frame.")
//Extracting the KMeansModel from the stored JsObject
val kmeansData = model.data.convertTo[KMeansData]
val kmeansModel = kmeansData.kMeansModel
Expand Down
Loading