From 96c4846db89802f5a81dca5dcfa3f2a0f72b5cb8 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 12 May 2015 16:42:30 -0700 Subject: [SPARK-7573] [ML] OneVsRest cleanups Minor cleanups discussed with [~mengxr]: * move OneVsRest from reduction to classification sub-package * make model constructor private Some doc cleanups too CC: harsha2010 Could you please verify this looks OK? Thanks! Author: Joseph K. Bradley Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits: 4ecd48d [Joseph K. Bradley] org imports 430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification. small java doc style fixes 9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest. Made model constructor private to ml package. --- .../ml/classification/JavaOneVsRestSuite.java | 82 +++++++++++++++ .../spark/ml/reduction/JavaOneVsRestSuite.java | 85 ---------------- .../spark/ml/classification/OneVsRestSuite.scala | 110 ++++++++++++++++++++ .../apache/spark/ml/reduction/OneVsRestSuite.scala | 113 --------------------- 4 files changed, 192 insertions(+), 198 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java delete mode 100644 mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala (limited to 'mllib/src/test') diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java new file mode 100644 index 0000000000..a1ee554152 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import static scala.collection.JavaConversions.seqAsJavaList; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaOneVsRestSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); + jsql = new SQLContext(jsc); + int nPoints = 3; + + // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. + double[] weights = { + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; + + double[] xMean = {5.843, 3.057, 3.758, 1.199}; + double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; + List points = seqAsJavaList(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42)); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void oneVsRestDefaultParams() { + OneVsRest ova = new OneVsRest(); + ova.setClassifier(new LogisticRegression()); + Assert.assertEquals(ova.getLabelCol() , "label"); + Assert.assertEquals(ova.getPredictionCol() , "prediction"); + OneVsRestModel ovaModel = ova.fit(dataset); + DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); + predictions.collectAsList(); + Assert.assertEquals(ovaModel.getLabelCol(), "label"); + Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java deleted file mode 100644 index 40a90ae9de..0000000000 --- a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.reduction; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import static scala.collection.JavaConversions.seqAsJavaList; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegression; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class JavaOneVsRestSuite implements Serializable { - - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient DataFrame dataset; - private transient JavaRDD datasetRDD; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); - jsql = new SQLContext(jsc); - int nPoints = 3; - - /** - * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2. - * As a result, we are actually drawing samples from probability distribution of built model. - */ - double[] weights = { - -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, - -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; - - double[] xMean = {5.843, 3.057, 3.758, 1.199}; - double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = seqAsJavaList(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42)); - datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } - - @Test - public void oneVsRestDefaultParams() { - OneVsRest ova = new OneVsRest(); - ova.setClassifier(new LogisticRegression()); - Assert.assertEquals(ova.getLabelCol() , "label"); - Assert.assertEquals(ova.getPredictionCol() , "prediction"); - OneVsRestModel ovaModel = ova.fit(dataset); - DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); - predictions.collectAsList(); - Assert.assertEquals(ovaModel.getLabelCol(), "label"); - Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala new file mode 100644 index 0000000000..e65ffae918 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + +class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: DataFrame = _ + @transient var rdd: RDD[LabeledPoint] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + val nPoints = 1000 + + // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + dataset = sqlContext.createDataFrame(rdd) + } + + test("one-vs-rest: default params") { + val numClasses = 3 + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression) + assert(ova.getLabelCol === "label") + assert(ova.getPredictionCol === "prediction") + val ovaModel = ova.fit(dataset) + assert(ovaModel.models.size === numClasses) + val transformedDataset = ovaModel.transform(dataset) + + // check for label metadata in prediction col + val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) + assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) + + val ovaResults = transformedDataset + .select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) + lr.optimizer.setRegParam(0.1).setNumIterations(100) + + val model = lr.run(rdd) + val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // determine the #confusion matrix in each class. + // bound how much error we allow compared to multinomial logistic regression. + val expectedMetrics = new MulticlassMetrics(results) + val ovaMetrics = new MulticlassMetrics(ovaResults) + assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400) + } + + test("one-vs-rest: pass label metadata correctly during train") { + val numClasses = 3 + val ova = new OneVsRest() + ova.setClassifier(new MockLogisticRegression) + + val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata()) + val features = dataset("features").as("features") + val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) + ova.fit(datasetWithLabelMetadata) + } +} + +private class MockLogisticRegression extends LogisticRegression { + + setMaxIter(1) + + override protected def train(dataset: DataFrame): LogisticRegressionModel = { + val labelSchema = dataset.schema($(labelCol)) + // check for label attribute propagation. + assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) + super.train(dataset) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala deleted file mode 100644 index ebec7c68e8..0000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/reduction/OneVsRestSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.reduction - -import org.scalatest.FunSuite - -import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} -import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} - -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { - - @transient var sqlContext: SQLContext = _ - @transient var dataset: DataFrame = _ - @transient var rdd: RDD[LabeledPoint] = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - val nPoints = 1000 - - /** - * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2. - * As a result, we are actually drawing samples from probability distribution of built model. - */ - val weights = Array( - -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, - -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) - - val xMean = Array(5.843, 3.057, 3.758, 1.199) - val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - rdd = sc.parallelize(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 2) - dataset = sqlContext.createDataFrame(rdd) - } - - test("one-vs-rest: default params") { - val numClasses = 3 - val ova = new OneVsRest() - ova.setClassifier(new LogisticRegression) - assert(ova.getLabelCol === "label") - assert(ova.getPredictionCol === "prediction") - val ovaModel = ova.fit(dataset) - assert(ovaModel.models.size === numClasses) - val transformedDataset = ovaModel.transform(dataset) - - // check for label metadata in prediction col - val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) - assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - - val ovaResults = transformedDataset - .select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) - - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) - lr.optimizer.setRegParam(0.1).setNumIterations(100) - - val model = lr.run(rdd) - val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label)) - // determine the #confusion matrix in each class. - // bound how much error we allow compared to multinomial logistic regression. - val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400) - } - - test("one-vs-rest: pass label metadata correctly during train") { - val numClasses = 3 - val ova = new OneVsRest() - ova.setClassifier(new MockLogisticRegression) - - val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) - val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata()) - val features = dataset("features").as("features") - val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) - ova.fit(datasetWithLabelMetadata) - } -} - -private class MockLogisticRegression extends LogisticRegression { - - setMaxIter(1) - - override protected def train(dataset: DataFrame): LogisticRegressionModel = { - val labelSchema = dataset.schema($(labelCol)) - // check for label attribute propagation. - assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) - super.train(dataset) - } -} -- cgit v1.2.3