aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala179
1 files changed, 162 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c41a611f1c..4de1b877b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,22 +21,24 @@ import java.util.UUID
import scala.language.existentials
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
@@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams {
type E <: Classifier[F, E, M]
}
// scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
def getClassifier: ClassifierType = $(classifier)
}
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+ def validateParams(instance: OneVsRestParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException("OneVsRest write will fail " +
+ s" because it contains $name which does not implement MLWritable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+
+ instance match {
+ case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+ case _ => // no need to check OneVsRest here
+ }
+
+ checkElement(instance.getClassifier, "classifier")
+ }
+
+ def saveImpl(
+ path: String,
+ instance: OneVsRestParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => p.name != "classifier" }
+ .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+ .toList)
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val classifierPath = new Path(path, "classifier").toString
+ instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+ }
+
+ def loadImpl(
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+ val classifierPath = new Path(path, "classifier").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+ (metadata, estimator)
+ }
+}
+
/**
* :: Experimental ::
* Model produced by [[OneVsRest]].
@@ -73,18 +130,18 @@ private[ml] trait OneVsRestParams extends PredictorParams {
@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
- @Since("1.4.0") override val uid: String,
- @Since("1.4.0") labelMetadata: Metadata,
+ @Since("1.4.0") override val uid: String,
+ private[ml] val labelMetadata: Metadata,
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
- extends Model[OneVsRestModel] with OneVsRestParams {
+ extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Check schema
transformSchema(dataset.schema, logging = true)
@@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] (
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRestModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRestModel]] */
+ private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+ ("numClasses" -> instance.models.length)
+ OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+ instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ model.save(modelPath)
+ }
+ }
+ }
+
+ private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRestModel].getName
+
+ override def load(path: String): OneVsRestModel = {
+ implicit val format = DefaultFormats
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val models = Range(0, numClasses).toArray.map { idx =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+ }
+ val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+ DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+ ovrModel.set("classifier", classifier)
+ ovrModel
+ }
+ }
}
/**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
@Experimental
final class OneVsRest @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Estimator[OneVsRestModel] with OneVsRestParams {
+ extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -186,12 +293,14 @@ final class OneVsRest @Since("1.4.0") (
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def fit(dataset: DataFrame): OneVsRestModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): OneVsRestModel = {
+ transformSchema(dataset.schema)
+
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
- val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
// classes are assumed to be numbered from 0,...,maxLabelIndex
maxLabelIndex.toInt + 1
}
@@ -243,4 +352,40 @@ final class OneVsRest @Since("1.4.0") (
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRest = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRest]] */
+ private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ OneVsRestParams.saveImpl(path, instance, sc)
+ }
+ }
+
+ private class OneVsRestReader extends MLReader[OneVsRest] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRest].getName
+
+ override def load(path: String): OneVsRest = {
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val ovr = new OneVsRest(metadata.uid)
+ DefaultParamsReader.getAndSetParams(ovr, metadata)
+ ovr.setClassifier(classifier)
+ }
+ }
}