aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorRam Sriharsha <rsriharsha@hw11853.local>2015-05-12 13:35:12 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-12 13:35:12 -0700
commit595a67589a42f8025d3e5fd4da413b1faa2e14bf (patch)
tree54073754a09b6ff793ba03fd4711dfcb16c7ad42 /mllib
parent5438f49ccf374fed16bc2b7fc1556e4c0095b14c (diff)
downloadspark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.tar.gz
spark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.tar.bz2
spark-595a67589a42f8025d3e5fd4da413b1faa2e14bf.zip
[SPARK-7015] [MLLIB] [WIP] Multiclass to Binary Reduction: One Against All
initial cut of one against all. test code is a scaffolding , not fully implemented. This WIP is to gather early feedback. Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #5830 from harsha2010/reduction and squashes the following commits: 5f4b495 [Ram Sriharsha] Fix Test 386e98b [Ram Sriharsha] Style fix 49b4a17 [Ram Sriharsha] Simplify the test 02279cc [Ram Sriharsha] Output Label Metadata in Prediction Col bc78032 [Ram Sriharsha] Code Review Updates 8ce4845 [Ram Sriharsha] Merge with Master 2a807be [Ram Sriharsha] Merge branch 'master' into reduction e21bfcc [Ram Sriharsha] Style Fix 5614f23 [Ram Sriharsha] Style Fix c75583a [Ram Sriharsha] Cleanup 7a5f136 [Ram Sriharsha] Fix TODOs 804826b [Ram Sriharsha] Merge with Master 1448a5f [Ram Sriharsha] Style Fix 6e47807 [Ram Sriharsha] Style Fix d63e46b [Ram Sriharsha] Incorporate Code Review Feedback ced68b5 [Ram Sriharsha] Refactor OneVsAll to implement Predictor 78fa82a [Ram Sriharsha] extra line 0dfa1fb [Ram Sriharsha] Fix inexhaustive match cases that may arise from UnresolvedAttribute a59a4f4 [Ram Sriharsha] @Experimental 4167234 [Ram Sriharsha] Merge branch 'master' into reduction 868a4fd [Ram Sriharsha] @Experimental 041d905 [Ram Sriharsha] Code Review Fixes df188d8 [Ram Sriharsha] Style fix 612ec48 [Ram Sriharsha] Style Fix 6ef43d3 [Ram Sriharsha] Prefer Unresolved Attribute to Option: Java APIs are cleaner 6bf6bff [Ram Sriharsha] Update OneHotEncoder to new API e29cb89 [Ram Sriharsha] Merge branch 'master' into reduction 1c7fa44 [Ram Sriharsha] Fix Tests ca83672 [Ram Sriharsha] Incorporate Code Review Feedback + Rename to OneVsRestClassifier 221beeed [Ram Sriharsha] Upgrade to use Copy method for cloning Base Classifiers 26f1ddb [Ram Sriharsha] Merge with SPARK-5956 API changes 9738744 [Ram Sriharsha] Merge branch 'master' into reduction 1a3e375 [Ram Sriharsha] More efficient Implementation: Use withColumn to generate label column dynamically 32e0189 [Ram Sriharsha] Restrict reduction to Margin Based Classifiers ff272da [Ram Sriharsha] Style fix 28771f5 [Ram Sriharsha] Add Tests for Multiclass to Binary Reduction b60f874 [Ram Sriharsha] Fix Style issues in Test 3191cdf [Ram Sriharsha] Remove this test, accidental commit 23f056c [Ram Sriharsha] Fix Headers for test 1b5e929 [Ram Sriharsha] Fix Style issues and add Header 8752863 [Ram Sriharsha] [SPARK-7015][MLLib][WIP] Multiclass to Binary Reduction: One Against All
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Predictor.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala37
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala211
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java85
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala113
10 files changed, 471 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 0e53877de9..f6a5f27425 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -113,7 +113,8 @@ abstract class Predictor[
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
- protected def featuresDataType: DataType = new VectorUDT
+ @DeveloperApi
+ private[ml] def featuresDataType: DataType = new VectorUDT
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, featuresDataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index d7dee8fed2..f5f37aa779 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -123,6 +123,7 @@ class AttributeGroup private (
nominalMetadata += nominal.toMetadataImpl(withType = false)
case binary: BinaryAttribute =>
binaryMetadata += binary.toMetadataImpl(withType = false)
+ case UnresolvedAttribute =>
}
val attrBldr = new MetadataBuilder
if (numericMetadata.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
index 65e7e43d5a..a83febd7de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
@@ -43,6 +43,12 @@ object AttributeType {
Binary
}
+ /** Unresolved type. */
+ val Unresolved: AttributeType = {
+ case object Unresolved extends AttributeType("unresolved")
+ Unresolved
+ }
+
/**
* Gets the [[AttributeType]] object from its name.
* @param name attribute type name: "numeric", "nominal", or "binary"
@@ -54,6 +60,8 @@ object AttributeType {
Nominal
} else if (name == Binary.name) {
Binary
+ } else if (name == Unresolved.name) {
+ Unresolved
} else {
throw new IllegalArgumentException(s"Cannot recognize type $name.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 5717d6ec2e..e8f7f15278 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -125,7 +125,13 @@ private[attribute] trait AttributeFactory {
*/
def fromStructField(field: StructField): Attribute = {
require(field.dataType == DoubleType)
- fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name)
+ val metadata = field.metadata
+ val mlAttr = AttributeKeys.ML_ATTR
+ if (metadata.contains(mlAttr)) {
+ fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+ } else {
+ UnresolvedAttribute
+ }
}
}
@@ -535,3 +541,32 @@ object BinaryAttribute extends AttributeFactory {
new BinaryAttribute(name, index, values)
}
}
+
+/**
+ * An unresolved attribute.
+ */
+object UnresolvedAttribute extends Attribute {
+
+ override def attrType: AttributeType = AttributeType.Unresolved
+
+ override def withIndex(index: Int): Attribute = this
+
+ override def isNumeric: Boolean = false
+
+ override def withoutIndex: Attribute = this
+
+ override def isNominal: Boolean = false
+
+ override def name: Option[String] = None
+
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
+ Metadata.empty
+ }
+
+ override def withoutName: Attribute = this
+
+ override def index: Option[Int] = None
+
+ override def withName(name: String): Attribute = this
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 07ea579d69..2e6313ac14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
@@ -375,6 +375,8 @@ class VectorIndexerModel private[ml] (
}
case (origAttr: Attribute, featAttr: NumericAttribute) =>
origAttr.withIndex(featAttr.index.get)
+ case (origAttr: Attribute, _) =>
+ origAttr
}
} else {
partialFeatureAttributes
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
new file mode 100644
index 0000000000..0a6728ef1f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction
+
+import java.util.UUID
+
+import scala.language.existentials
+
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams {
+
+ type ClassifierType = Classifier[F, E, M] forSome {
+ type F
+ type M <: ClassificationModel[F, M]
+ type E <: Classifier[F, E, M]
+ }
+
+ /**
+ * param for the base binary classifier that we reduce multiclass classification into.
+ * @group param
+ */
+ val classifier: Param[ClassifierType] =
+ new Param(this, "classifier", "base binary classifier ")
+
+ /** @group getParam */
+ def getClassifier: ClassifierType = $(classifier)
+
+}
+
+/**
+ * Model produced by [[OneVsRest]].
+ * Stores the models resulting from training k different classifiers:
+ * one for each class.
+ * Each example is scored against all k models and the model with highest score
+ * is picked to label the example.
+ * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
+ * @param parent
+ * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
+ * representing the number of classes in training dataset otherwise.
+ * @param models the binary classification models for reduction.
+ * The i-th model is produced by testing the i-th class vs the rest.
+ */
+@AlphaComponent
+class OneVsRestModel(
+ override val parent: OneVsRest,
+ labelMetadata: Metadata,
+ val models: Array[_ <: ClassificationModel[_,_]])
+ extends Model[OneVsRestModel] with OneVsRestParams {
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ // Check schema
+ transformSchema(dataset.schema, logging = true)
+
+ // determine the input columns: these need to be passed through
+ val origCols = dataset.schema.map(f => col(f.name))
+
+ // add an accumulator column to store predictions of all the models
+ val accColName = "mbc$acc" + UUID.randomUUID().toString
+ val init: () => Map[Int, Double] = () => {Map()}
+ val mapType = MapType(IntegerType, DoubleType, false)
+ val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+
+ // persist if underlying dataset is not persistent.
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // update the accumulator column with the result of prediction of models
+ val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
+ case (df, (model, index)) => {
+ val rawPredictionCol = model.getRawPredictionCol
+ val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
+
+ // add temporary column to store intermediate scores and update
+ val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
+ val update: (Map[Int, Double], Vector) => Map[Int, Double] =
+ (predictions: Map[Int, Double], prediction: Vector) => {
+ predictions + ((index, prediction(1)))
+ }
+ val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+ val transformedDataset = model.transform(df).select(columns:_*)
+ val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+ val newColumns = origCols ++ List(col(tmpColName))
+
+ // switch out the intermediate column with the accumulator column
+ updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
+ }
+ }
+
+ if (handlePersistence) {
+ newDataset.unpersist()
+ }
+
+ // output the index of the classifier with highest confidence as prediction
+ val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+ predictions.maxBy(_._2)._1.toDouble
+ }
+
+ // output label and label metadata as prediction
+ val labelUdf = callUDF(label, DoubleType, col(accColName))
+ aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Reduction of Multiclass Classification to Binary Classification.
+ * Performs reduction using one against all strategy.
+ * For a multiclass classification with k classes, train k models (one per class).
+ * Each example is scored against all k models and the model with highest score
+ * is picked to label the example.
+ */
+@Experimental
+final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+ /** @group setParam */
+ // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
+ def setClassifier(value: Classifier[_,_,_]): this.type = {
+ set(classifier, value.asInstanceOf[ClassifierType])
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
+ }
+
+ override def fit(dataset: DataFrame): OneVsRestModel = {
+ // determine number of classes either from metadata if provided, or via computation.
+ val labelSchema = dataset.schema($(labelCol))
+ val computeNumClasses: () => Int = () => {
+ val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ // classes are assumed to be numbered from 0,...,maxLabelIndex
+ maxLabelIndex.toInt + 1
+ }
+ val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+
+ val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+
+ // persist if underlying dataset is not persistent.
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // create k columns, one for each binary classifier.
+ val models = Range(0, numClasses).par.map { index =>
+
+ val label: Double => Double = (label: Double) => {
+ if (label.toInt == index) 1.0 else 0.0
+ }
+
+ // generate new label metadata for the binary problem.
+ // TODO: use when ... otherwise after SPARK-7321 is merged
+ val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
+ val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
+ val labelColName = "mc2b$" + index
+ val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+ val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+ val classifier = getClassifier
+ classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+ }.toArray[ClassificationModel[_,_]]
+
+ if (handlePersistence) {
+ multiclassLabeled.unpersist()
+ }
+
+ // extract label metadata from label column if present, or create a nominal attribute
+ // to output the number of labels
+ val labelAttribute = Attribute.fromStructField(labelSchema) match {
+ case _: NumericAttribute | UnresolvedAttribute => {
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ }
+ case attr: Attribute => attr
+ }
+ copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index c84c8b4eb7..56075c9a6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
import scala.collection.immutable.HashMap
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
- NumericAttribute}
+import org.apache.spark.ml.attribute._
import org.apache.spark.sql.types.StructField
@@ -39,9 +38,9 @@ object MetadataUtils {
*/
def getNumClasses(labelSchema: StructField): Option[Int] = {
Attribute.fromStructField(labelSchema) match {
- case numAttr: NumericAttribute => None
case binAttr: BinaryAttribute => Some(2)
case nomAttr: NominalAttribute => nomAttr.getNumValues
+ case _: NumericAttribute | UnresolvedAttribute => None
}
}
@@ -65,7 +64,7 @@ object MetadataUtils {
Iterator()
} else {
attr match {
- case numAttr: NumericAttribute => Iterator()
+ case _: NumericAttribute | UnresolvedAttribute => Iterator()
case binAttr: BinaryAttribute => Iterator(idx -> 2)
case nomAttr: NominalAttribute =>
nomAttr.getNumValues match {
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
new file mode 100644
index 0000000000..40a90ae9de
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
+ jsql = new SQLContext(jsc);
+ int nPoints = 3;
+
+ /**
+ * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
+ * As a result, we are actually drawing samples from probability distribution of built model.
+ */
+ double[] weights = {
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+ double[] xMean = {5.843, 3.057, 3.758, 1.199};
+ double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+ List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42));
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void oneVsRestDefaultParams() {
+ OneVsRest ova = new OneVsRest();
+ ova.setClassifier(new LogisticRegression());
+ Assert.assertEquals(ova.getLabelCol() , "label");
+ Assert.assertEquals(ova.getPredictionCol() , "prediction");
+ OneVsRestModel ovaModel = ova.fit(dataset);
+ DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ predictions.collectAsList();
+ Assert.assertEquals(ovaModel.getLabelCol(), "label");
+ Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 3e1a7196e3..ec9b717e41 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute
import org.scalatest.FunSuite
-import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+import org.apache.spark.sql.types._
class AttributeSuite extends FunSuite {
@@ -209,4 +209,12 @@ class AttributeSuite extends FunSuite {
intercept[IllegalArgumentException](attr.withName(""))
intercept[IllegalArgumentException](attr.withIndex(-1))
}
+
+ test("attribute from struct field") {
+ val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata()
+ val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty)
+ assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
+ val fldWithMeta = new StructField("x", DoubleType, false, metadata)
+ assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
new file mode 100644
index 0000000000..ebec7c68e8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+ @transient var rdd: RDD[LabeledPoint] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ val nPoints = 1000
+
+ /**
+ * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
+ * As a result, we are actually drawing samples from probability distribution of built model.
+ */
+ val weights = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ rdd = sc.parallelize(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42), 2)
+ dataset = sqlContext.createDataFrame(rdd)
+ }
+
+ test("one-vs-rest: default params") {
+ val numClasses = 3
+ val ova = new OneVsRest()
+ ova.setClassifier(new LogisticRegression)
+ assert(ova.getLabelCol === "label")
+ assert(ova.getPredictionCol === "prediction")
+ val ovaModel = ova.fit(dataset)
+ assert(ovaModel.models.size === numClasses)
+ val transformedDataset = ovaModel.transform(dataset)
+
+ // check for label metadata in prediction col
+ val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
+ assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
+
+ val ovaResults = transformedDataset
+ .select("prediction", "label")
+ .map(row => (row.getDouble(0), row.getDouble(1)))
+
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
+ lr.optimizer.setRegParam(0.1).setNumIterations(100)
+
+ val model = lr.run(rdd)
+ val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+ // determine the #confusion matrix in each class.
+ // bound how much error we allow compared to multinomial logistic regression.
+ val expectedMetrics = new MulticlassMetrics(results)
+ val ovaMetrics = new MulticlassMetrics(ovaResults)
+ assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
+ }
+
+ test("one-vs-rest: pass label metadata correctly during train") {
+ val numClasses = 3
+ val ova = new OneVsRest()
+ ova.setClassifier(new MockLogisticRegression)
+
+ val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
+ val features = dataset("features").as("features")
+ val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
+ ova.fit(datasetWithLabelMetadata)
+ }
+}
+
+private class MockLogisticRegression extends LogisticRegression {
+
+ setMaxIter(1)
+
+ override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+ val labelSchema = dataset.schema($(labelCol))
+ // check for label attribute propagation.
+ assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
+ super.train(dataset)
+ }
+}