diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-05-12 16:42:30 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-12 16:42:36 -0700 |
commit | 612247ff0593810a09a63905d0a88daecf821d10 (patch) | |
tree | 77060773b5a43a079e2ad0b78d81efa78683271a /mllib/src/main | |
parent | d080df10bc6085435ef2cebdb217b6c52ed9fdf3 (diff) | |
download | spark-612247ff0593810a09a63905d0a88daecf821d10.tar.gz spark-612247ff0593810a09a63905d0a88daecf821d10.tar.bz2 spark-612247ff0593810a09a63905d0a88daecf821d10.zip |
[SPARK-7573] [ML] OneVsRest cleanups
Minor cleanups discussed with [~mengxr]:
* move OneVsRest from reduction to classification sub-package
* make model constructor private
Some doc cleanups too
CC: harsha2010 Could you please verify this looks OK? Thanks!
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits:
4ecd48d [Joseph K. Bradley] org imports
430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes
9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package.
(cherry picked from commit 96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala (renamed from mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala) | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 0a6728ef1f..afb8d75d57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.reduction +package org.apache.spark.ml.classification import java.util.UUID @@ -24,7 +24,6 @@ import scala.language.existentials import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.classification.{ClassificationModel, Classifier} import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector @@ -57,20 +56,21 @@ private[ml] trait OneVsRestParams extends PredictorParams { } /** + * :: AlphaComponent :: + * * Model produced by [[OneVsRest]]. - * Stores the models resulting from training k different classifiers: - * one for each class. - * Each example is scored against all k models and the model with highest score + * This stores the models resulting from training k binary classifiers: one for each class. + * Each example is scored against all k models, and the model with the highest score * is picked to label the example. - * TODO: API may need to change when we introduce a ClassificationModel trait as the public API - * @param parent + * * @param labelMetadata Metadata of label column if it exists, or Nominal attribute * representing the number of classes in training dataset otherwise. - * @param models the binary classification models for reduction. - * The i-th model is produced by testing the i-th class vs the rest. + * @param models The binary classification models for the reduction. + * The i-th model is produced by testing the i-th class (taking label 1) vs the rest + * (taking label 0). */ @AlphaComponent -class OneVsRestModel( +class OneVsRestModel private[ml] ( override val parent: OneVsRest, labelMetadata: Metadata, val models: Array[_ <: ClassificationModel[_,_]]) @@ -90,7 +90,7 @@ class OneVsRestModel( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString val init: () => Map[Int, Double] = () => {Map()} - val mapType = MapType(IntegerType, DoubleType, false) + val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) val newDataset = dataset.withColumn(accColName, callUDF(init, mapType)) // persist if underlying dataset is not persistent. @@ -101,7 +101,7 @@ class OneVsRestModel( // update the accumulator column with the result of prediction of models val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) { - case (df, (model, index)) => { + case (df, (model, index)) => val rawPredictionCol = model.getRawPredictionCol val columns = origCols ++ List(col(rawPredictionCol), col(accColName)) @@ -110,7 +110,7 @@ class OneVsRestModel( val update: (Map[Int, Double], Vector) => Map[Int, Double] = (predictions: Map[Int, Double], prediction: Vector) => { predictions + ((index, prediction(1))) - } + } val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) val transformedDataset = model.transform(df).select(columns:_*) val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) @@ -118,7 +118,6 @@ class OneVsRestModel( // switch out the intermediate column with the accumulator column updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName) - } } if (handlePersistence) { @@ -149,8 +148,8 @@ class OneVsRestModel( final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { /** @group setParam */ - // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed. def setClassifier(value: Classifier[_,_,_]): this.type = { + // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed set(classifier, value.asInstanceOf[ClassifierType]) } @@ -201,9 +200,8 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { // extract label metadata from label column if present, or create a nominal attribute // to output the number of labels val labelAttribute = Attribute.fromStructField(labelSchema) match { - case _: NumericAttribute | UnresolvedAttribute => { + case _: NumericAttribute | UnresolvedAttribute => NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) - } case attr: Attribute => attr } copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models)) |