diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 179 |
1 files changed, 162 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c41a611f1c..4de1b877b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,22 +21,24 @@ import java.util.UUID import scala.language.existentials +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject, _} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -/** - * Params for [[OneVsRest]]. - */ -private[ml] trait OneVsRestParams extends PredictorParams { - +private[ml] trait ClassifierTypeTrait { // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F @@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams { type E <: Classifier[F, E, M] } // scalastyle:on structural.type +} + +/** + * Params for [[OneVsRest]]. + */ +private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams { def getClassifier: ClassifierType = $(classifier) } +private[ml] object OneVsRestParams extends ClassifierTypeTrait { + + def validateParams(instance: OneVsRestParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("OneVsRest write will fail " + + s" because it contains $name which does not implement MLWritable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + + instance match { + case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model")) + case _ => // no need to check OneVsRest here + } + + checkElement(instance.getClassifier, "classifier") + } + + def saveImpl( + path: String, + instance: OneVsRestParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + + val params = instance.extractParamMap().toSeq + val jsonParams = render(params + .filter { case ParamPair(p, v) => p.name != "classifier" } + .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) } + .toList) + + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val classifierPath = new Path(path, "classifier").toString + instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath) + } + + def loadImpl( + path: String, + sc: SparkContext, + expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + val classifierPath = new Path(path, "classifier").toString + val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc) + (metadata, estimator) + } +} + /** * :: Experimental :: * Model produced by [[OneVsRest]]. @@ -73,18 +130,18 @@ private[ml] trait OneVsRestParams extends PredictorParams { @Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - @Since("1.4.0") override val uid: String, - @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") override val uid: String, + private[ml] val labelMetadata: Metadata, @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) - extends Model[OneVsRestModel] with OneVsRestParams { + extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] ( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) copyValues(copied, extra).setParent(parent) } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this) +} + +@Since("2.0.0") +object OneVsRestModel extends MLReadable[OneVsRestModel] { + + @Since("2.0.0") + override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader + + @Since("2.0.0") + override def load(path: String): OneVsRestModel = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRestModel]] */ + private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~ + ("numClasses" -> instance.models.length) + OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson)) + instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) => + val modelPath = new Path(path, s"model_$idx").toString + model.save(modelPath) + } + } + } + + private class OneVsRestModelReader extends MLReader[OneVsRestModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRestModel].getName + + override def load(path: String): OneVsRestModel = { + implicit val format = DefaultFormats + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String]) + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val models = Range(0, numClasses).toArray.map { idx => + val modelPath = new Path(path, s"model_$idx").toString + DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) + } + val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) + DefaultParamsReader.getAndSetParams(ovrModel, metadata) + ovrModel.set("classifier", classifier) + ovrModel + } + } } /** @@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] ( @Experimental final class OneVsRest @Since("1.4.0") ( @Since("1.4.0") override val uid: String) - extends Estimator[OneVsRestModel] with OneVsRestParams { + extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) @@ -186,12 +293,14 @@ final class OneVsRest @Since("1.4.0") ( validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } - @Since("1.4.0") - override def fit(dataset: DataFrame): OneVsRestModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): OneVsRestModel = { + transformSchema(dataset.schema) + // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) val computeNumClasses: () => Int = () => { - val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head() + val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head() // classes are assumed to be numbered from 0,...,maxLabelIndex maxLabelIndex.toInt + 1 } @@ -243,4 +352,40 @@ final class OneVsRest @Since("1.4.0") ( } copied } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRest.OneVsRestWriter(this) +} + +@Since("2.0.0") +object OneVsRest extends MLReadable[OneVsRest] { + + @Since("2.0.0") + override def read: MLReader[OneVsRest] = new OneVsRestReader + + @Since("2.0.0") + override def load(path: String): OneVsRest = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRest]] */ + private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + OneVsRestParams.saveImpl(path, instance, sc) + } + } + + private class OneVsRestReader extends MLReader[OneVsRest] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRest].getName + + override def load(path: String): OneVsRest = { + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val ovr = new OneVsRest(metadata.uid) + DefaultParamsReader.getAndSetParams(ovr, metadata) + ovr.setClassifier(classifier) + } + } } |