@@ -16,7 +16,7 @@
 
 package org.apache.spark.ml.atk.plugins
 
-import org.apache.spark.ml.regression.CoxModel
+import org.apache.spark.ml.regression.CoxPhModel
 import org.apache.spark.mllib.atk.plugins.MLLibJsonProtocol
 import org.apache.spark.mllib.atk.plugins.MLLibJsonProtocol.VectorFormat
 import org.trustedanalytics.atk.domain.DomainJsonProtocol._
@@ -52,9 +52,9 @@ object MLJsonProtocol {
     }
   }
 
-  implicit object CoxModelFormat extends JsonFormat[CoxModel] {
+  implicit object CoxModelFormat extends JsonFormat[CoxPhModel] {
 
-    override def write(obj: CoxModel): JsValue = {
+    override def write(obj: CoxPhModel): JsValue = {
       val beta = VectorFormat.write(obj.beta)
       val mean = VectorFormat.write(obj.meanVector)
       JsObject(
@@ -64,7 +64,7 @@
       )
     }
 
-    override def read(json: JsValue): org.apache.spark.ml.regression.CoxModel = {
+    override def read(json: JsValue): org.apache.spark.ml.regression.CoxPhModel = {
      val fields = json.asJsObject.fields
      val uid = getOrInvalid(fields, "uid").convertTo[String]
      val beta = fields.get("beta").map(v => {
@@ -75,7 +75,7 @@
        VectorFormat.read(v)
      }).get
 
-      new CoxModel(uid, beta, mean)
+      new CoxPhModel(uid, beta, mean)
    }
  }
 
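For reference, a minimal usage sketch of the renamed format with spray-json. This is not part of the change: the model values are made up, and it assumes the CoxPhModel constructor is reachable from the calling scope and that beta and the mean vector are mllib Vectors, as the read() method above suggests.

import org.apache.spark.ml.regression.CoxPhModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.atk.plugins.MLJsonProtocol._
import spray.json._

// Round-trip a model through the implicit CoxModelFormat (hypothetical values)
val model = new CoxPhModel("coxPh_example", Vectors.dense(0.5, -1.2), Vectors.dense(10.0, 3.5))
val json: JsValue = model.toJson                        // CoxModelFormat.write
val restored: CoxPhModel = json.convertTo[CoxPhModel]   // CoxModelFormat.read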