aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-05-12 16:42:30 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-12 16:42:36 -0700
commit612247ff0593810a09a63905d0a88daecf821d10 (patch)
tree77060773b5a43a079e2ad0b78d81efa78683271a /mllib/src/main
parentd080df10bc6085435ef2cebdb217b6c52ed9fdf3 (diff)
downloadspark-612247ff0593810a09a63905d0a88daecf821d10.tar.gz
spark-612247ff0593810a09a63905d0a88daecf821d10.tar.bz2
spark-612247ff0593810a09a63905d0a88daecf821d10.zip
[SPARK-7573] [ML] OneVsRest cleanups
Minor cleanups discussed with [~mengxr]: * move OneVsRest from reduction to classification sub-package * make model constructor private Some doc cleanups too CC: harsha2010 Could you please verify this looks OK? Thanks! Author: Joseph K. Bradley <joseph@databricks.com> Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits: 4ecd48d [Joseph K. Bradley] org imports 430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes 9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package. (cherry picked from commit 96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala (renamed from mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala)32
1 file changed, 15 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 0a6728ef1f..afb8d75d57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.reduction
+package org.apache.spark.ml.classification
import java.util.UUID
@@ -24,7 +24,6 @@ import scala.language.existentials
import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
@@ -57,20 +56,21 @@ private[ml] trait OneVsRestParams extends PredictorParams {
}
/**
+ * :: AlphaComponent ::
+ *
* Model produced by [[OneVsRest]].
- * Stores the models resulting from training k different classifiers:
- * one for each class.
- * Each example is scored against all k models and the model with highest score
+ * This stores the models resulting from training k binary classifiers: one for each class.
+ * Each example is scored against all k models, and the model with the highest score
* is picked to label the example.
- * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
- * @param parent
+ *
* @param labelMetadata Metadata of label column if it exists, or Nominal attribute
* representing the number of classes in training dataset otherwise.
- * @param models the binary classification models for reduction.
- * The i-th model is produced by testing the i-th class vs the rest.
+ * @param models The binary classification models for the reduction.
+ * The i-th model is produced by testing the i-th class (taking label 1) vs the rest
+ * (taking label 0).
*/
@AlphaComponent
-class OneVsRestModel(
+class OneVsRestModel private[ml] (
override val parent: OneVsRest,
labelMetadata: Metadata,
val models: Array[_ <: ClassificationModel[_,_]])
@@ -90,7 +90,7 @@ class OneVsRestModel(
// add an accumulator column to store predictions of all the models
val accColName = "mbc$acc" + UUID.randomUUID().toString
val init: () => Map[Int, Double] = () => {Map()}
- val mapType = MapType(IntegerType, DoubleType, false)
+ val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
// persist if underlying dataset is not persistent.
@@ -101,7 +101,7 @@ class OneVsRestModel(
// update the accumulator column with the result of prediction of models
val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
- case (df, (model, index)) => {
+ case (df, (model, index)) =>
val rawPredictionCol = model.getRawPredictionCol
val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
@@ -110,7 +110,7 @@ class OneVsRestModel(
val update: (Map[Int, Double], Vector) => Map[Int, Double] =
(predictions: Map[Int, Double], prediction: Vector) => {
predictions + ((index, prediction(1)))
- }
+ }
val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
val transformedDataset = model.transform(df).select(columns:_*)
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
@@ -118,7 +118,6 @@ class OneVsRestModel(
// switch out the intermediate column with the accumulator column
updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
- }
}
if (handlePersistence) {
@@ -149,8 +148,8 @@ class OneVsRestModel(
final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
/** @group setParam */
- // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
def setClassifier(value: Classifier[_,_,_]): this.type = {
+ // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
set(classifier, value.asInstanceOf[ClassifierType])
}
@@ -201,9 +200,8 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
// extract label metadata from label column if present, or create a nominal attribute
// to output the number of labels
val labelAttribute = Attribute.fromStructField(labelSchema) match {
- case _: NumericAttribute | UnresolvedAttribute => {
+ case _: NumericAttribute | UnresolvedAttribute =>
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
- }
case attr: Attribute => attr
}
copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))