aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-05-12 16:42:30 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-12 16:42:30 -0700
commit96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8 (patch)
tree498bb4bda7cb35ecf99e5a81da52f32952af732a /mllib
parentf0c1bc3472a7422ae5649634f29c88e161f5ecaf (diff)
downloadspark-96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8.tar.gz
spark-96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8.tar.bz2
spark-96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8.zip
[SPARK-7573] [ML] OneVsRest cleanups
Minor cleanups discussed with [~mengxr]: * move OneVsRest from reduction to classification sub-package * make model constructor private Some doc cleanups too CC: harsha2010 Could you please verify this looks OK? Thanks! Author: Joseph K. Bradley <joseph@databricks.com> Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits: 4ecd48d [Joseph K. Bradley] org imports 430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes 9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala (renamed from mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala)32
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java (renamed from mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java)13
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala)9
3 files changed, 23 insertions, 31 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 0a6728ef1f..afb8d75d57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.reduction
+package org.apache.spark.ml.classification
import java.util.UUID
@@ -24,7 +24,6 @@ import scala.language.existentials
import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.linalg.Vector
@@ -57,20 +56,21 @@ private[ml] trait OneVsRestParams extends PredictorParams {
}
/**
+ * :: AlphaComponent ::
+ *
* Model produced by [[OneVsRest]].
- * Stores the models resulting from training k different classifiers:
- * one for each class.
- * Each example is scored against all k models and the model with highest score
+ * This stores the models resulting from training k binary classifiers: one for each class.
+ * Each example is scored against all k models, and the model with the highest score
* is picked to label the example.
- * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
- * @param parent
+ *
* @param labelMetadata Metadata of label column if it exists, or Nominal attribute
* representing the number of classes in training dataset otherwise.
- * @param models the binary classification models for reduction.
- * The i-th model is produced by testing the i-th class vs the rest.
+ * @param models The binary classification models for the reduction.
+ * The i-th model is produced by testing the i-th class (taking label 1) vs the rest
+ * (taking label 0).
*/
@AlphaComponent
-class OneVsRestModel(
+class OneVsRestModel private[ml] (
override val parent: OneVsRest,
labelMetadata: Metadata,
val models: Array[_ <: ClassificationModel[_,_]])
@@ -90,7 +90,7 @@ class OneVsRestModel(
// add an accumulator column to store predictions of all the models
val accColName = "mbc$acc" + UUID.randomUUID().toString
val init: () => Map[Int, Double] = () => {Map()}
- val mapType = MapType(IntegerType, DoubleType, false)
+ val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
// persist if underlying dataset is not persistent.
@@ -101,7 +101,7 @@ class OneVsRestModel(
// update the accumulator column with the result of prediction of models
val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
- case (df, (model, index)) => {
+ case (df, (model, index)) =>
val rawPredictionCol = model.getRawPredictionCol
val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
@@ -110,7 +110,7 @@ class OneVsRestModel(
val update: (Map[Int, Double], Vector) => Map[Int, Double] =
(predictions: Map[Int, Double], prediction: Vector) => {
predictions + ((index, prediction(1)))
- }
+ }
val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
val transformedDataset = model.transform(df).select(columns:_*)
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
@@ -118,7 +118,6 @@ class OneVsRestModel(
// switch out the intermediate column with the accumulator column
updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
- }
}
if (handlePersistence) {
@@ -149,8 +148,8 @@ class OneVsRestModel(
final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
/** @group setParam */
- // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
def setClassifier(value: Classifier[_,_,_]): this.type = {
+ // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
set(classifier, value.asInstanceOf[ClassifierType])
}
@@ -201,9 +200,8 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
// extract label metadata from label column if present, or create a nominal attribute
// to output the number of labels
val labelAttribute = Attribute.fromStructField(labelSchema) match {
- case _: NumericAttribute | UnresolvedAttribute => {
+ case _: NumericAttribute | UnresolvedAttribute =>
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
- }
case attr: Attribute => attr
}
copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 40a90ae9de..a1ee554152 100644
--- a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -15,21 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.ml.reduction;
+package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
+import static scala.collection.JavaConversions.seqAsJavaList;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import static scala.collection.JavaConversions.seqAsJavaList;
-
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegression;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -48,10 +47,8 @@ public class JavaOneVsRestSuite implements Serializable {
jsql = new SQLContext(jsc);
int nPoints = 3;
- /**
- * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
- * As a result, we are actually drawing samples from probability distribution of built model.
- */
+ // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
double[] weights = {
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index ebec7c68e8..e65ffae918 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.ml.reduction
+package org.apache.spark.ml.classification
import org.scalatest.FunSuite
import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -42,10 +41,8 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
sqlContext = new SQLContext(sc)
val nPoints = 1000
- /**
- * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
- * As a result, we are actually drawing samples from probability distribution of built model.
- */
+ // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+ // As a result, we are drawing samples from probability distribution of an actual model.
val weights = Array(
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682)