diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-04-12 22:38:27 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-12 22:38:27 -0700 |
commit | d3792f54974e16cbe8f10b3091d248e0bdd48986 (patch) | |
tree | 89d679f7a9f76599841f169239021f190968654b /mllib/src/test | |
parent | fc17661475443d9f0a8d28e3439feeb7a7bca67b (diff) | |
download | spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.tar.gz spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.tar.bz2 spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.zip |
[SPARK-4081] [mllib] VectorIndexer
**Ready for review!**
Since the original PR, I moved the code to the spark.ml API and renamed this to VectorIndexer.
This introduces a VectorIndexer class which does the following:
* VectorIndexer.fit(): collect statistics about how many values each feature in a dataset (RDD[Vector]) can take (limited by maxCategories)
* Feature which exceed maxCategories are declared continuous, and the Model will treat them as such.
* VectorIndexerModel.transform(): Convert categorical feature values to corresponding 0-based indices
Design notes:
* This maintains sparsity in vectors by ensuring that categorical feature value 0.0 gets index 0.
* This does not yet support transforming data with new (unknown) categorical feature values. That can be added later.
* This is necessary for DecisionTree and tree ensembles.
Reviewers: Please check my use of metadata and my unit tests for it; I'm not sure if I covered everything in the tests.
Other notes:
* This also adds a public toMetadata method to AttributeGroup (for simpler construction of metadata).
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #3000 from jkbradley/indexer and squashes the following commits:
5956d91 [Joseph K. Bradley] minor cleanups
f5c57a8 [Joseph K. Bradley] added Java test suite
643b444 [Joseph K. Bradley] removed FeatureTests
02236c3 [Joseph K. Bradley] Updated VectorIndexer, ready for PR
286d221 [Joseph K. Bradley] Reworked DatasetIndexer for spark.ml API, and renamed it to VectorIndexer
12e6cf2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into indexer
6d8f3f1 [Joseph K. Bradley] Added partly done DatasetIndexer to spark.ml
6a2f553 [Joseph K. Bradley] Updated TODO for allowUnknownCategories
3f041f8 [Joseph K. Bradley] Final cleanups for DatasetIndexer
038b9e3 [Joseph K. Bradley] DatasetIndexer now maintains sparsity in SparseVector
3a4a0bd [Joseph K. Bradley] Added another test for DatasetIndexer
2006923 [Joseph K. Bradley] DatasetIndexer now passes tests
f409987 [Joseph K. Bradley] partly done with DatasetIndexerSuite
5e7c874 [Joseph K. Bradley] working on DatasetIndexer
Diffstat (limited to 'mllib/src/test')
5 files changed, 394 insertions, 6 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 0000000000..161100134c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List<FeatureData> points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 3fb6e2ec46..0dcfe5a200 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite { intercept[NoSuchElementException] { group("abc") } - assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField())) + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField)) } test("attribute group without attributes") { @@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite { assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) - assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index a18c335952..9d09f24709 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -private case class DataSet(features: Vector) class NormalizerSuite extends FunSuite with MLlibTestSparkContext { @@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { ) val sqlContext = new SQLContext(sc) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") @@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } } + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 0000000000..61c46c85a7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + assert(categoryMaps.keys.toSet === categoricalFeatures) // Chose correct categorical features + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. + TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 0000000000..c44cb61b34 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and check to ensure the + * extra metadata have not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} |