diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-04-12 22:38:27 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-04-12 22:38:27 -0700 |
commit | d3792f54974e16cbe8f10b3091d248e0bdd48986 (patch) | |
tree | 89d679f7a9f76599841f169239021f190968654b /mllib | |
parent | fc17661475443d9f0a8d28e3439feeb7a7bca67b (diff) | |
download | spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.tar.gz spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.tar.bz2 spark-d3792f54974e16cbe8f10b3091d248e0bdd48986.zip |
[SPARK-4081] [mllib] VectorIndexer
**Ready for review!**
Since the original PR, I moved the code to the spark.ml API and renamed this to VectorIndexer.
This introduces a VectorIndexer class which does the following:
* VectorIndexer.fit(): collect statistics about how many values each feature in a dataset (RDD[Vector]) can take (limited by maxCategories)
* Feature which exceed maxCategories are declared continuous, and the Model will treat them as such.
* VectorIndexerModel.transform(): Convert categorical feature values to corresponding 0-based indices
Design notes:
* This maintains sparsity in vectors by ensuring that categorical feature value 0.0 gets index 0.
* This does not yet support transforming data with new (unknown) categorical feature values. That can be added later.
* This is necessary for DecisionTree and tree ensembles.
Reviewers: Please check my use of metadata and my unit tests for it; I'm not sure if I covered everything in the tests.
Other notes:
* This also adds a public toMetadata method to AttributeGroup (for simpler construction of metadata).
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #3000 from jkbradley/indexer and squashes the following commits:
5956d91 [Joseph K. Bradley] minor cleanups
f5c57a8 [Joseph K. Bradley] added Java test suite
643b444 [Joseph K. Bradley] removed FeatureTests
02236c3 [Joseph K. Bradley] Updated VectorIndexer, ready for PR
286d221 [Joseph K. Bradley] Reworked DatasetIndexer for spark.ml API, and renamed it to VectorIndexer
12e6cf2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into indexer
6d8f3f1 [Joseph K. Bradley] Added partly done DatasetIndexer to spark.ml
6a2f553 [Joseph K. Bradley] Updated TODO for allowUnknownCategories
3f041f8 [Joseph K. Bradley] Final cleanups for DatasetIndexer
038b9e3 [Joseph K. Bradley] DatasetIndexer now maintains sparsity in SparseVector
3a4a0bd [Joseph K. Bradley] Added another test for DatasetIndexer
2006923 [Joseph K. Bradley] DatasetIndexer now passes tests
f409987 [Joseph K. Bradley] partly done with DatasetIndexerSuite
5e7c874 [Joseph K. Bradley] working on DatasetIndexer
Diffstat (limited to 'mllib')
9 files changed, 818 insertions, 19 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index c4a3610330..a455341a1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -47,6 +47,9 @@ abstract class PipelineStage extends Serializable with Logging { /** * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ protected def transformSchema( schema: StructType, diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 970e6ad551..aa27a668f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -106,7 +106,7 @@ class AttributeGroup private ( def getAttr(attrIndex: Int): Attribute = this(attrIndex) /** Converts to metadata without name. */ - private[attribute] def toMetadata: Metadata = { + private[attribute] def toMetadataImpl: Metadata = { import AttributeKeys._ val bldr = new MetadataBuilder() if (attributes.isDefined) { @@ -142,17 +142,24 @@ class AttributeGroup private ( bldr.build() } - /** Converts to a StructField with some existing metadata. */ - def toStructField(existingMetadata: Metadata): StructField = { - val newMetadata = new MetadataBuilder() + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, toMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) .build() - StructField(name, new VectorUDT, nullable = false, newMetadata) + } + + /** Converts to ML metadata */ + def toMetadata: Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) } /** Converts to a StructField. */ - def toStructField(): StructField = toStructField(Metadata.empty) + def toStructField: StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 0000000000..8760960e19 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, + Attribute, AttributeGroup} +import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take." + + " If a feature is found to have > maxCategories values, then it is declared continuous.", + Some(20)) + + /** @group getParam */ + def getMaxCategories: Int = get(maxCategories) +} + +/** + * :: AlphaComponent :: + * + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categorical any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. + */ +@AlphaComponent +class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = { + require(value > 1, + s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") + set(maxCategories, value) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val firstRow = dataset.select(map(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val maxCats = map(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val map = this.paramMap ++ paramMap + val dataType = new VectorUDT + require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + checkInputColumn(schema, map(inputCol), dataType) + addOutputColumn(schema, map(outputCol), dataType) + } +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. + if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + while (i < dv.size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + while (vecIndex < sv.size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: AlphaComponent :: + * + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@AlphaComponent +class VectorIndexerModel private[ml] ( + override val parent: VectorIndexer, + override val fittingParamMap: ParamMap, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. + */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + private val transformFunc: Vector => Vector = { + val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val f: Vector => Vector = { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val newField = prepOutputField(dataset.schema, map) + val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) + // For now, just check the first row of inputCol for vector length. + val firstRow = dataset.select(map(inputCol)).take(1) + if (firstRow.length != 0) { + val actualNumFeatures = firstRow(0).getAs[Vector](0).size + require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length $actualNumFeatures") + } + dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val dataType = new VectorUDT + require(map.contains(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(map.contains(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + checkInputColumn(schema, map(inputCol), dataType) + + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema, map) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @param map Parameter map (with this class' embedded parameter map folded in) + * @return Output column field + */ + private def prepOutputField(schema: StructType, map: ParamMap): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 17ece897a6..7d5178d0ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -198,23 +198,31 @@ trait Params extends Identifiable with Serializable { /** * Check whether the given schema contains an input column. - * @param colName Parameter name for the input column. - * @param dataType SQL DataType of the input column. + * @param colName Input column name + * @param dataType Input column DataType */ protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" + + s" but was actually $actualDataType. Column param description: ${getParam(colName)}") } + /** + * Add an output column to the given schema. + * This fails if the given output column already exists. + * @param schema Initial schema (not modified) + * @param colName Output column name. If this column name is an empy String "", this method + * returns the initial schema, unchanged. This allows users to disable output + * columns. + * @param dataType Output column DataType + */ protected def addOutputColumn( schema: StructType, colName: String, dataType: DataType): StructType = { if (colName.length == 0) return schema val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") + require(!fieldNames.contains(colName), s"Output column $colName already exists.") val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) StructType(outputFields) } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 0000000000..161100134c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List<FeatureData> points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 3fb6e2ec46..0dcfe5a200 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite { intercept[NoSuchElementException] { group("abc") } - assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField())) + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField)) } test("attribute group without attributes") { @@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite { assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) - assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index a18c335952..9d09f24709 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -private case class DataSet(features: Vector) class NormalizerSuite extends FunSuite with MLlibTestSparkContext { @@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { ) val sqlContext = new SQLContext(sc) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") @@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } } + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 0000000000..61c46c85a7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + assert(categoryMaps.keys.toSet === categoricalFeatures) // Chose correct categorical features + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. + TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 0000000000..c44cb61b34 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and check to ensure the + * extra metadata have not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} |