From 96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 12 May 2015 16:42:30 -0700
Subject: [SPARK-7573] [ML] OneVsRest cleanups

Minor cleanups discussed with [~mengxr]:
* move OneVsRest from reduction to classification sub-package
* make model constructor private

Some doc cleanups too

CC: harsha2010 Could you please verify this looks OK? Thanks!

Author: Joseph K. Bradley

Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits:

4ecd48d [Joseph K. Bradley] org imports
430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes
9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package.
---
 .../apache/spark/ml/classification/OneVsRest.scala | 209 ++++++++++++++++++++
 .../org/apache/spark/ml/reduction/OneVsRest.scala  | 211 ---------------------
 .../ml/classification/JavaOneVsRestSuite.java      |  82 ++++++++
 .../spark/ml/reduction/JavaOneVsRestSuite.java     |  85 ---------
 .../spark/ml/classification/OneVsRestSuite.scala   | 110 +++++++++++
 .../apache/spark/ml/reduction/OneVsRestSuite.scala | 113 -----------
 6 files changed, 401 insertions(+), 409 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
 delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
 create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
 delete mode 100644 mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
 delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
new file mode 100644
index 0000000000..afb8d75d57
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import java.util.UUID
+
+import scala.language.existentials
+
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams {
+
+  type ClassifierType = Classifier[F, E, M] forSome {
+    type F
+    type M <: ClassificationModel[F, M]
+    type E <: Classifier[F, E, M]
+  }
+
+  /**
+   * param for the base binary classifier that we reduce multiclass classification into.
+   * @group param
+   */
+  val classifier: Param[ClassifierType] =
+    new Param(this, "classifier", "base binary classifier ")
+
+  /** @group getParam */
+  def getClassifier: ClassifierType = $(classifier)
+
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by [[OneVsRest]].
+ * This stores the models resulting from training k binary classifiers: one for each class.
+ * Each example is scored against all k models, and the model with the highest score
+ * is picked to label the example.
+ *
+ * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
+ *                      representing the number of classes in training dataset otherwise.
+ * @param models The binary classification models for the reduction.
+ *               The i-th model is produced by testing the i-th class (taking label 1) vs the rest
+ *               (taking label 0).
+ */
+@AlphaComponent
+class OneVsRestModel private[ml] (
+    override val parent: OneVsRest,
+    labelMetadata: Metadata,
+    val models: Array[_ <: ClassificationModel[_,_]])
+  extends Model[OneVsRestModel] with OneVsRestParams {
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    // Check schema
+    transformSchema(dataset.schema, logging = true)
+
+    // determine the input columns: these need to be passed through
+    val origCols = dataset.schema.map(f => col(f.name))
+
+    // add an accumulator column to store predictions of all the models
+    val accColName = "mbc$acc" + UUID.randomUUID().toString
+    val init: () => Map[Int, Double] = () => {Map()}
+    val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
+    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+
+    // persist if underlying dataset is not persistent.
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // update the accumulator column with the result of prediction of models
+    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
+      case (df, (model, index)) =>
+        val rawPredictionCol = model.getRawPredictionCol
+        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
+
+        // add temporary column to store intermediate scores and update
+        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
+        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
+          (predictions: Map[Int, Double], prediction: Vector) => {
+            predictions + ((index, prediction(1)))
+          }
+        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+        val transformedDataset = model.transform(df).select(columns:_*)
+        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+        val newColumns = origCols ++ List(col(tmpColName))
+
+        // switch out the intermediate column with the accumulator column
+        updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
+    }
+
+    if (handlePersistence) {
+      newDataset.unpersist()
+    }
+
+    // output the index of the classifier with highest confidence as prediction
+    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+      predictions.maxBy(_._2)._1.toDouble
+    }
+
+    // output label and label metadata as prediction
+    val labelUdf = callUDF(label, DoubleType, col(accColName))
+    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+  }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Reduction of Multiclass Classification to Binary Classification.
+ * Performs reduction using one against all strategy.
+ * For a multiclass classification with k classes, train k models (one per class).
+ * Each example is scored against all k models and the model with highest score
+ * is picked to label the example.
+ */
+@Experimental
+final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+  /** @group setParam */
+  def setClassifier(value: Classifier[_,_,_]): this.type = {
+    // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
+    set(classifier, value.asInstanceOf[ClassifierType])
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
+  }
+
+  override def fit(dataset: DataFrame): OneVsRestModel = {
+    // determine number of classes either from metadata if provided, or via computation.
+    val labelSchema = dataset.schema($(labelCol))
+    val computeNumClasses: () => Int = () => {
+      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+      // classes are assumed to be numbered from 0,...,maxLabelIndex
+      maxLabelIndex.toInt + 1
+    }
+    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+
+    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+
+    // persist if underlying dataset is not persistent.
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // create k columns, one for each binary classifier.
+    val models = Range(0, numClasses).par.map { index =>
+
+      val label: Double => Double = (label: Double) => {
+        if (label.toInt == index) 1.0 else 0.0
+      }
+
+      // generate new label metadata for the binary problem.
+      // TODO: use when ... otherwise after SPARK-7321 is merged
+      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
+      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
+      val labelColName = "mc2b$" + index
+      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+      val classifier = getClassifier
+      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+    }.toArray[ClassificationModel[_,_]]
+
+    if (handlePersistence) {
+      multiclassLabeled.unpersist()
+    }
+
+    // extract label metadata from label column if present, or create a nominal attribute
+    // to output the number of labels
+    val labelAttribute = Attribute.fromStructField(labelSchema) match {
+      case _: NumericAttribute | UnresolvedAttribute =>
+        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+      case attr: Attribute => attr
+    }
+    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
deleted file mode 100644
index 0a6728ef1f..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
+++ /dev/null
@@ -1,211 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction
-
-import java.util.UUID
-
-import scala.language.existentials
-
-import org.apache.spark.annotation.{AlphaComponent, Experimental}
-import org.apache.spark.ml._
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel
-
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
-  type ClassifierType = Classifier[F, E, M] forSome {
-    type F
-    type M <: ClassificationModel[F, M]
-    type E <: Classifier[F, E, M]
-  }
-
-  /**
-   * param for the base binary classifier that we reduce multiclass classification into.
-   * @group param
-   */
-  val classifier: Param[ClassifierType] =
-    new Param(this, "classifier", "base binary classifier ")
-
-  /** @group getParam */
-  def getClassifier: ClassifierType = $(classifier)
-
-}
-
-/**
- * Model produced by [[OneVsRest]].
- * Stores the models resulting from training k different classifiers:
- * one for each class.
- * Each example is scored against all k models and the model with highest score
- * is picked to label the example.
- * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
- * @param parent
- * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
- *                      representing the number of classes in training dataset otherwise.
- * @param models the binary classification models for reduction.
- *               The i-th model is produced by testing the i-th class vs the rest.
- */
-@AlphaComponent
-class OneVsRestModel(
-    override val parent: OneVsRest,
-    labelMetadata: Metadata,
-    val models: Array[_ <: ClassificationModel[_,_]])
-  extends Model[OneVsRestModel] with OneVsRestParams {
-
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
-  }
-
-  override def transform(dataset: DataFrame): DataFrame = {
-    // Check schema
-    transformSchema(dataset.schema, logging = true)
-
-    // determine the input columns: these need to be passed through
-    val origCols = dataset.schema.map(f => col(f.name))
-
-    // add an accumulator column to store predictions of all the models
-    val accColName = "mbc$acc" + UUID.randomUUID().toString
-    val init: () => Map[Int, Double] = () => {Map()}
-    val mapType = MapType(IntegerType, DoubleType, false)
-    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
-
-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) {
-      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-
-    // update the accumulator column with the result of prediction of models
-    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
-      case (df, (model, index)) => {
-        val rawPredictionCol = model.getRawPredictionCol
-        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
-
-        // add temporary column to store intermediate scores and update
-        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
-        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
-          (predictions: Map[Int, Double], prediction: Vector) => {
-            predictions + ((index, prediction(1)))
-          }
-        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
-        val transformedDataset = model.transform(df).select(columns:_*)
-        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
-        val newColumns = origCols ++ List(col(tmpColName))
-
-        // switch out the intermediate column with the accumulator column
-        updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
-      }
-    }
-
-    if (handlePersistence) {
-      newDataset.unpersist()
-    }
-
-    // output the index of the classifier with highest confidence as prediction
-    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
-      predictions.maxBy(_._2)._1.toDouble
-    }
-
-    // output label and label metadata as prediction
-    val labelUdf = callUDF(label, DoubleType, col(accColName))
-    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
-  }
-}
-
-/**
- * :: Experimental ::
- *
- * Reduction of Multiclass Classification to Binary Classification.
- * Performs reduction using one against all strategy.
- * For a multiclass classification with k classes, train k models (one per class).
- * Each example is scored against all k models and the model with highest score
- * is picked to label the example.
- */
-@Experimental
-final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
-
-  /** @group setParam */
-  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
-  def setClassifier(value: Classifier[_,_,_]): this.type = {
-    set(classifier, value.asInstanceOf[ClassifierType])
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
-  }
-
-  override def fit(dataset: DataFrame): OneVsRestModel = {
-    // determine number of classes either from metadata if provided, or via computation.
-    val labelSchema = dataset.schema($(labelCol))
-    val computeNumClasses: () => Int = () => {
-      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
-      // classes are assumed to be numbered from 0,...,maxLabelIndex
-      maxLabelIndex.toInt + 1
-    }
-    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
-
-    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
-
-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) {
-      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
-    }
-
-    // create k columns, one for each binary classifier.
-    val models = Range(0, numClasses).par.map { index =>
-
-      val label: Double => Double = (label: Double) => {
-        if (label.toInt == index) 1.0 else 0.0
-      }
-
-      // generate new label metadata for the binary problem.
-      // TODO: use when ... otherwise after SPARK-7321 is merged
-      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
-      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
-      val labelColName = "mc2b$" + index
-      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
-      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
-      val classifier = getClassifier
-      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
-    }.toArray[ClassificationModel[_,_]]
-
-    if (handlePersistence) {
-      multiclassLabeled.unpersist()
-    }
-
-    // extract label metadata from label column if present, or create a nominal attribute
-    // to output the number of labels
-    val labelAttribute = Attribute.fromStructField(labelSchema) match {
-      case _: NumericAttribute | UnresolvedAttribute => {
-        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
-      }
-      case attr: Attribute => attr
-    }
-    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
-  }
-}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
new file mode 100644
index 0000000000..a1ee554152
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
+    jsql = new SQLContext(jsc);
+    int nPoints = 3;
+
+    // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+    // As a result, we are drawing samples from probability distribution of an actual model.
+    double[] weights = {
+        -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+        -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+    double[] xMean = {5.843, 3.057, 3.758, 1.199};
+    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+    List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+        weights, xMean, xVariance, true, nPoints, 42));
+    datasetRDD = jsc.parallelize(points, 2);
+    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void oneVsRestDefaultParams() {
+    OneVsRest ova = new OneVsRest();
+    ova.setClassifier(new LogisticRegression());
+    Assert.assertEquals(ova.getLabelCol() , "label");
+    Assert.assertEquals(ova.getPredictionCol() , "prediction");
+    OneVsRestModel ovaModel = ova.fit(dataset);
+    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+    predictions.collectAsList();
+    Assert.assertEquals(ovaModel.getLabelCol(), "label");
+    Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
deleted file mode 100644
index 40a90ae9de..0000000000
--- a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction;
-
-import java.io.Serializable;
-import java.util.List;
-
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-import static scala.collection.JavaConversions.seqAsJavaList;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegression;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-public class JavaOneVsRestSuite implements Serializable {
-
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
-  private transient DataFrame dataset;
-  private transient JavaRDD<LabeledPoint> datasetRDD;
-
-  @Before
-  public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
-    jsql = new SQLContext(jsc);
-    int nPoints = 3;
-
-    /**
-     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
-     * As a result, we are actually drawing samples from probability distribution of built model.
-     */
-    double[] weights = {
-        -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-        -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
-
-    double[] xMean = {5.843, 3.057, 3.758, 1.199};
-    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
-    List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
-        weights, xMean, xVariance, true, nPoints, 42));
-    datasetRDD = jsc.parallelize(points, 2);
-    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
-  }
-
-  @After
-  public void tearDown() {
-    jsc.stop();
-    jsc = null;
-  }
-
-  @Test
-  public void oneVsRestDefaultParams() {
-    OneVsRest ova = new OneVsRest();
-    ova.setClassifier(new LogisticRegression());
-    Assert.assertEquals(ova.getLabelCol() , "label");
-    Assert.assertEquals(ova.getPredictionCol() , "prediction");
-    OneVsRestModel ovaModel = ova.fit(dataset);
-    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
-    predictions.collectAsList();
-    Assert.assertEquals(ovaModel.getLabelCol(), "label");
-    Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
-  }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
new file mode 100644
index 0000000000..e65ffae918
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: DataFrame = _
+  @transient var rdd: RDD[LabeledPoint] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    val nPoints = 1000
+
+    // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+    // As a result, we are drawing samples from probability distribution of an actual model.
+    val weights = Array(
+      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+    val xMean = Array(5.843, 3.057, 3.758, 1.199)
+    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+    rdd = sc.parallelize(generateMultinomialLogisticInput(
+      weights, xMean, xVariance, true, nPoints, 42), 2)
+    dataset = sqlContext.createDataFrame(rdd)
+  }
+
+  test("one-vs-rest: default params") {
+    val numClasses = 3
+    val ova = new OneVsRest()
+    ova.setClassifier(new LogisticRegression)
+    assert(ova.getLabelCol === "label")
+    assert(ova.getPredictionCol === "prediction")
+    val ovaModel = ova.fit(dataset)
+    assert(ovaModel.models.size === numClasses)
+    val transformedDataset = ovaModel.transform(dataset)
+
+    // check for label metadata in prediction col
+    val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
+    assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
+
+    val ovaResults = transformedDataset
+      .select("prediction", "label")
+      .map(row => (row.getDouble(0), row.getDouble(1)))
+
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
+    lr.optimizer.setRegParam(0.1).setNumIterations(100)
+
+    val model = lr.run(rdd)
+    val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+    // determine the #confusion matrix in each class.
+    // bound how much error we allow compared to multinomial logistic regression.
+    val expectedMetrics = new MulticlassMetrics(results)
+    val ovaMetrics = new MulticlassMetrics(ovaResults)
+    assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
+  }
+
+  test("one-vs-rest: pass label metadata correctly during train") {
+    val numClasses = 3
+    val ova = new OneVsRest()
+    ova.setClassifier(new MockLogisticRegression)
+
+    val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+    val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
+    val features = dataset("features").as("features")
+    val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
+    ova.fit(datasetWithLabelMetadata)
+  }
+}
+
+private class MockLogisticRegression extends LogisticRegression {
+
+  setMaxIter(1)
+
+  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+    val labelSchema = dataset.schema($(labelCol))
+    // check for label attribute propagation.
+    assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
+    super.train(dataset)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
deleted file mode 100644
index ebec7c68e8..0000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.reduction
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
-import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.classification.LogisticRegressionSuite._
-import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
-class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
-
-  @transient var sqlContext: SQLContext = _
-  @transient var dataset: DataFrame = _
-  @transient var rdd: RDD[LabeledPoint] = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-    val nPoints = 1000
-
-    /**
-     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
-     * As a result, we are actually drawing samples from probability distribution of built model.
-     */
-    val weights = Array(
-      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
-
-    val xMean = Array(5.843, 3.057, 3.758, 1.199)
-    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
-    rdd = sc.parallelize(generateMultinomialLogisticInput(
-      weights, xMean, xVariance, true, nPoints, 42), 2)
-    dataset = sqlContext.createDataFrame(rdd)
-  }
-
-  test("one-vs-rest: default params") {
-    val numClasses = 3
-    val ova = new OneVsRest()
-    ova.setClassifier(new LogisticRegression)
-    assert(ova.getLabelCol === "label")
-    assert(ova.getPredictionCol === "prediction")
-    val ovaModel = ova.fit(dataset)
-    assert(ovaModel.models.size === numClasses)
-    val transformedDataset = ovaModel.transform(dataset)
-
-    // check for label metadata in prediction col
-    val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
-    assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
-
-    val ovaResults = transformedDataset
-      .select("prediction", "label")
-      .map(row => (row.getDouble(0), row.getDouble(1)))
-
-    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
-    lr.optimizer.setRegParam(0.1).setNumIterations(100)
-
-    val model = lr.run(rdd)
-    val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label))
-    // determine the #confusion matrix in each class.
-    // bound how much error we allow compared to multinomial logistic regression.
-    val expectedMetrics = new MulticlassMetrics(results)
-    val ovaMetrics = new MulticlassMetrics(ovaResults)
-    assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
-  }
-
-  test("one-vs-rest: pass label metadata correctly during train") {
-    val numClasses = 3
-    val ova = new OneVsRest()
-    ova.setClassifier(new MockLogisticRegression)
-
-    val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
-    val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata())
-    val features = dataset("features").as("features")
-    val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
-    ova.fit(datasetWithLabelMetadata)
-  }
-}
-
-private class MockLogisticRegression extends LogisticRegression {
-
-  setMaxIter(1)
-
-  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
-    val labelSchema = dataset.schema($(labelCol))
-    // check for label attribute propagation.
-    assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
-    super.train(dataset)
-  }
-}
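For anyone trying out the relocated API, a minimal usage sketch against the Spark 1.4-era ML pipeline API follows. The `training` DataFrame name and the LogisticRegression hyperparameters are illustrative assumptions, not part of this patch; any DataFrame with the default "label" and "features" columns (for example, one built from an RDD[LabeledPoint] as in the suites above) works the same way.

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

    // Configure the reduction with a binary base classifier.
    val ova = new OneVsRest()
      .setClassifier(new LogisticRegression().setMaxIter(10).setRegParam(0.01))

    // fit() trains k binary models, one per class; transform() scores each row
    // against all k models and writes the index of the highest-scoring model
    // to the "prediction" column.
    val ovaModel = ova.fit(training)
    val predictions = ovaModel.transform(training).select("label", "prediction")

Since setClassifier returns this.type, it chains like the other ML param setters; the asInstanceOf cast inside it is what lets Java callers supply a Classifier without dealing with the existential ClassifierType.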