author     Joseph K. Bradley <joseph@databricks.com>  2015-04-12 22:38:27 -0700
committer  Xiangrui Meng <meng@databricks.com>        2015-04-12 22:38:27 -0700
commit     d3792f54974e16cbe8f10b3091d248e0bdd48986 (patch)
tree       89d679f7a9f76599841f169239021f190968654b /mllib
parent     fc17661475443d9f0a8d28e3439feeb7a7bca67b (diff)
[SPARK-4081] [mllib] VectorIndexer
**Ready for review!**

Since the original PR, I moved the code to the spark.ml API and renamed this to VectorIndexer.

This introduces a VectorIndexer class which does the following:
* VectorIndexer.fit(): collect statistics about how many values each feature in a dataset (RDD[Vector]) can take (limited by maxCategories)
  * Features which exceed maxCategories are declared continuous, and the Model will treat them as such.
* VectorIndexerModel.transform(): Convert categorical feature values to corresponding 0-based indices

Design notes:
* This maintains sparsity in vectors by ensuring that categorical feature value 0.0 gets index 0.
* This does not yet support transforming data with new (unknown) categorical feature values. That can be added later.
* This is necessary for DecisionTree and tree ensembles.

Reviewers: Please check my use of metadata and my unit tests for it; I'm not sure if I covered everything in the tests.

Other notes:
* This also adds a public toMetadata method to AttributeGroup (for simpler construction of metadata).

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3000 from jkbradley/indexer and squashes the following commits:

5956d91 [Joseph K. Bradley] minor cleanups
f5c57a8 [Joseph K. Bradley] added Java test suite
643b444 [Joseph K. Bradley] removed FeatureTests
02236c3 [Joseph K. Bradley] Updated VectorIndexer, ready for PR
286d221 [Joseph K. Bradley] Reworked DatasetIndexer for spark.ml API, and renamed it to VectorIndexer
12e6cf2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into indexer
6d8f3f1 [Joseph K. Bradley] Added partly done DatasetIndexer to spark.ml
6a2f553 [Joseph K. Bradley] Updated TODO for allowUnknownCategories
3f041f8 [Joseph K. Bradley] Final cleanups for DatasetIndexer
038b9e3 [Joseph K. Bradley] DatasetIndexer now maintains sparsity in SparseVector
3a4a0bd [Joseph K. Bradley] Added another test for DatasetIndexer
2006923 [Joseph K. Bradley] DatasetIndexer now passes tests
f409987 [Joseph K. Bradley] partly done with DatasetIndexerSuite
5e7c874 [Joseph K. Bradley] working on DatasetIndexer
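For illustration only (not part of the patch), a minimal usage sketch of the new API. It assumes a spark-shell style environment with a SparkContext `sc` and a SQLContext `sqlContext`, and uses an illustrative FeatureData case class like the one in the test suites:

    import org.apache.spark.ml.feature.VectorIndexer
    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    // Illustrative input: a DataFrame with a single vector column named "features".
    case class FeatureData(features: Vector)
    val data = sqlContext.createDataFrame(sc.parallelize(Seq(
      FeatureData(Vectors.dense(1.0, 2.0, 0.0)),
      FeatureData(Vectors.dense(0.0, 1.0, 2.0)),
      FeatureData(Vectors.dense(0.0, 0.0, -1.0)),
      FeatureData(Vectors.dense(1.0, 3.0, 2.0))), 2))

    // fit() collects the distinct values of each feature (capped at maxCategories);
    // features with more than maxCategories distinct values are declared continuous.
    val indexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .setMaxCategories(2)
    val model = indexer.fit(data)

    // Map: categorical feature index -> (original value -> 0-based category index).
    println(model.categoryMaps)

    // transform() replaces categorical values by their indices; continuous features
    // are left unchanged and vector sparsity is preserved.
    val indexed = model.transform(data)
    indexed.select("indexed").show()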
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 21
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 393
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 20
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 255
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala | 60
9 files changed, 818 insertions, 19 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index c4a3610330..a455341a1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -47,6 +47,9 @@ abstract class PipelineStage extends Serializable with Logging {
/**
* Derives the output schema from the input schema and parameters, optionally with logging.
+ *
+ * This should be optimistic. If it is unclear whether the schema will be valid, then it should
+ * be assumed valid until proven otherwise.
*/
protected def transformSchema(
schema: StructType,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index 970e6ad551..aa27a668f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -106,7 +106,7 @@ class AttributeGroup private (
def getAttr(attrIndex: Int): Attribute = this(attrIndex)
/** Converts to metadata without name. */
- private[attribute] def toMetadata: Metadata = {
+ private[attribute] def toMetadataImpl: Metadata = {
import AttributeKeys._
val bldr = new MetadataBuilder()
if (attributes.isDefined) {
@@ -142,17 +142,24 @@ class AttributeGroup private (
bldr.build()
}
- /** Converts to a StructField with some existing metadata. */
- def toStructField(existingMetadata: Metadata): StructField = {
- val newMetadata = new MetadataBuilder()
+ /** Converts to ML metadata with some existing metadata. */
+ def toMetadata(existingMetadata: Metadata): Metadata = {
+ new MetadataBuilder()
.withMetadata(existingMetadata)
- .putMetadata(AttributeKeys.ML_ATTR, toMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl)
.build()
- StructField(name, new VectorUDT, nullable = false, newMetadata)
+ }
+
+ /** Converts to ML metadata */
+ def toMetadata: Metadata = toMetadata(Metadata.empty)
+
+ /** Converts to a StructField with some existing metadata. */
+ def toStructField(existingMetadata: Metadata): StructField = {
+ StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata))
}
/** Converts to a StructField. */
- def toStructField(): StructField = toStructField(Metadata.empty)
+ def toStructField: StructField = toStructField(Metadata.empty)
override def equals(other: Any): Boolean = {
other match {
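The new public toMetadata on AttributeGroup makes it simpler to attach per-feature ML attributes to a vector column. A small sketch of the intended use, mirroring VectorIndexerSuite's "Preserve metadata" test below; `data` and the column name "features" are illustrative assumptions:

    import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}

    // Describe three continuous features with names and (illustrative) max values.
    val featureAttributes: Array[Attribute] = Array(1.0, 3.0, 2.0).zipWithIndex.map {
      case (maxVal, i) => NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal)
    }
    val attrGroup = new AttributeGroup("features", featureAttributes)

    // Attach the attributes to an existing vector column via the new toMetadata method;
    // `data` is assumed to be a DataFrame with a vector column named "features".
    val dataWithMeta = data.select(data("features").as("features", attrGroup.toMetadata))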
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
new file mode 100644
index 0000000000..8760960e19
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
+ Attribute, AttributeGroup}
+import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions.callUDF
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.collection.OpenHashSet
+
+
+/** Private trait for params for VectorIndexer and VectorIndexerModel */
+private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * Threshold for the number of values a categorical feature can take.
+ * If a feature is found to have > maxCategories values, then it is declared continuous.
+ *
+ * (default = 20)
+ */
+ val maxCategories = new IntParam(this, "maxCategories",
+ "Threshold for the number of values a categorical feature can take." +
+ " If a feature is found to have > maxCategories values, then it is declared continuous.",
+ Some(20))
+
+ /** @group getParam */
+ def getMaxCategories: Int = get(maxCategories)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Class for indexing categorical feature columns in a dataset of [[Vector]].
+ *
+ * This has 2 usage modes:
+ * - Automatically identify categorical features (default behavior)
+ * - This helps process a dataset of unknown vectors into a dataset with some continuous
+ * features and some categorical features. The choice between continuous and categorical
+ * is based upon a maxCategories parameter.
+ * - Set maxCategories to the maximum number of categories any categorical feature should have.
+ * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}.
+ * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1},
+ * and feature 1 will be declared continuous.
+ * - Index all features, if all features are categorical
+ * - If maxCategories is set to be very large, then this will build an index of unique
+ * values for all features.
+ * - Warning: This can cause problems if features are continuous since this will collect ALL
+ * unique values to the driver.
+ * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}.
+ * If maxCategories >= 3, then both features will be declared categorical.
+ *
+ * This returns a model which can transform categorical features to use 0-based indices.
+ *
+ * Index stability:
+ * - This is not guaranteed to choose the same category index across multiple runs.
+ * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0.
+ * This maintains vector sparsity.
+ * - More stability may be added in the future.
+ *
+ * TODO: Future extensions: The following functionality is planned for the future:
+ * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
+ * - Specify certain features to not index, either via a parameter or via existing metadata.
+ * - Add warning if a categorical feature has only 1 category.
+ * - Add option for allowing unknown categories.
+ */
+@AlphaComponent
+class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
+
+ /** @group setParam */
+ def setMaxCategories(value: Int): this.type = {
+ require(value > 1,
+ s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.")
+ set(maxCategories, value)
+ }
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val firstRow = dataset.select(map(inputCol)).take(1)
+ require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
+ val numFeatures = firstRow(0).getAs[Vector](0).size
+ val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
+ val maxCats = map(maxCategories)
+ val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter =>
+ val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats)
+ iter.foreach(localCatStats.addVector)
+ Iterator(localCatStats)
+ }.reduce((stats1, stats2) => stats1.merge(stats2))
+ val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ // We do not transfer feature metadata since we do not know what types of features we will
+ // produce in transform().
+ val map = this.paramMap ++ paramMap
+ val dataType = new VectorUDT
+ require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol")
+ require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol")
+ checkInputColumn(schema, map(inputCol), dataType)
+ addOutputColumn(schema, map(outputCol), dataType)
+ }
+}
+
+private object VectorIndexer {
+
+ /**
+ * Helper class for tracking unique values for each feature.
+ *
+ * TODO: Track which features are known to be continuous already; do not update counts for them.
+ *
+ * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures.
+ * @param maxCategories This class caps the number of unique values collected at maxCategories.
+ */
+ class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
+ extends Serializable {
+
+ /** featureValueSets[feature index] = set of unique values */
+ private val featureValueSets =
+ Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]())
+
+ /** Merge with another instance, modifying this instance. */
+ def merge(other: CategoryStats): CategoryStats = {
+ featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) =>
+ otherValSet.iterator.foreach { x =>
+ // Once we have found > maxCategories values, we know the feature is continuous
+ // and do not need to collect more values for it.
+ if (thisValSet.size <= maxCategories) thisValSet.add(x)
+ }
+ }
+ this
+ }
+
+ /** Add a new vector to this index, updating sets of unique feature values */
+ def addVector(v: Vector): Unit = {
+ require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" +
+ s" found vector of size ${v.size}.")
+ v match {
+ case dv: DenseVector => addDenseVector(dv)
+ case sv: SparseVector => addSparseVector(sv)
+ }
+ }
+
+ /**
+ * Based on stats collected, decide which features are categorical,
+ * and choose indices for categories.
+ *
+ * Sparsity: This tries to maintain sparsity by treating value 0.0 specially.
+ * If a categorical feature takes value 0.0, then value 0.0 is given index 0.
+ *
+ * @return Feature value index. Keys are categorical feature indices (column indices).
+ * Values are mappings from original features values to 0-based category indices.
+ */
+ def getCategoryMaps: Map[Int, Map[Double, Int]] = {
+ // Filter out features which are declared continuous.
+ featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map {
+ case (featureValues: OpenHashSet[Double], featureIndex: Int) =>
+ var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted
+ val zeroExists = sortedFeatureValues.length + 1 == featureValues.size
+ if (zeroExists) {
+ sortedFeatureValues = 0.0 +: sortedFeatureValues
+ }
+ val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap
+ (featureIndex, categoryMap)
+ }.toMap
+ }
+
+ private def addDenseVector(dv: DenseVector): Unit = {
+ var i = 0
+ while (i < dv.size) {
+ if (featureValueSets(i).size <= maxCategories) {
+ featureValueSets(i).add(dv(i))
+ }
+ i += 1
+ }
+ }
+
+ private def addSparseVector(sv: SparseVector): Unit = {
+ // TODO: This might be able to handle 0's more efficiently.
+ var vecIndex = 0 // index into vector
+ var k = 0 // index into non-zero elements
+ while (vecIndex < sv.size) {
+ val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) {
+ k += 1
+ sv.values(k - 1)
+ } else {
+ 0.0
+ }
+ if (featureValueSets(vecIndex).size <= maxCategories) {
+ featureValueSets(vecIndex).add(featureValue)
+ }
+ vecIndex += 1
+ }
+ }
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Transform categorical features to use 0-based indices instead of their original values.
+ * - Categorical features are mapped to indices.
+ * - Continuous features (columns) are left unchanged.
+ * This also appends metadata to the output column, marking features as Numeric (continuous),
+ * Nominal (categorical), or Binary (either continuous or categorical).
+ *
+ * This maintains vector sparsity.
+ *
+ * @param numFeatures Number of features, i.e., length of Vectors which this transforms
+ * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices).
+ * Values are maps from original feature values to 0-based category indices.
+ * If a feature is not in this map, it is treated as continuous.
+ */
+@AlphaComponent
+class VectorIndexerModel private[ml] (
+ override val parent: VectorIndexer,
+ override val fittingParamMap: ParamMap,
+ val numFeatures: Int,
+ val categoryMaps: Map[Int, Map[Double, Int]])
+ extends Model[VectorIndexerModel] with VectorIndexerParams {
+
+ /**
+ * Pre-computed feature attributes, with some missing info.
+ * In transform(), set attribute name and other info, if available.
+ */
+ private val partialFeatureAttributes: Array[Attribute] = {
+ val attrs = new Array[Attribute](numFeatures)
+ var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (categoryMaps.contains(featureIndex)) {
+ // categorical feature
+ val featureValues: Array[String] =
+ categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString)
+ if (featureValues.length == 2) {
+ attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex),
+ values = Some(featureValues))
+ } else {
+ attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex),
+ isOrdinal = Some(false), values = Some(featureValues))
+ }
+ categoricalFeatureCount += 1
+ } else {
+ // continuous feature
+ attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex))
+ }
+ featureIndex += 1
+ }
+ require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" +
+ s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures")
+ attrs
+ }
+
+ // TODO: Check more carefully about whether this whole class will be included in a closure.
+
+ private val transformFunc: Vector => Vector = {
+ val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
+ val localVectorMap = categoryMaps
+ val f: Vector => Vector = {
+ case dv: DenseVector =>
+ val tmpv = dv.copy
+ localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
+ tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+ }
+ tmpv
+ case sv: SparseVector =>
+ // We use the fact that categorical value 0 is always mapped to index 0.
+ val tmpv = sv.copy
+ var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
+ var k = 0 // index into non-zero elements of sparse vector
+ while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
+ val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
+ if (featureIndex < tmpv.indices(k)) {
+ catFeatureIdx += 1
+ } else if (featureIndex > tmpv.indices(k)) {
+ k += 1
+ } else {
+ tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+ catFeatureIdx += 1
+ k += 1
+ }
+ }
+ tmpv
+ }
+ f
+ }
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val newField = prepOutputField(dataset.schema, map)
+ val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
+ // For now, just check the first row of inputCol for vector length.
+ val firstRow = dataset.select(map(inputCol)).take(1)
+ if (firstRow.length != 0) {
+ val actualNumFeatures = firstRow(0).getAs[Vector](0).size
+ require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
+ s" $numFeatures but found length $actualNumFeatures")
+ }
+ dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val dataType = new VectorUDT
+ require(map.contains(inputCol),
+ s"VectorIndexerModel requires input column parameter: $inputCol")
+ require(map.contains(outputCol),
+ s"VectorIndexerModel requires output column parameter: $outputCol")
+ checkInputColumn(schema, map(inputCol), dataType)
+
+ val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
+ val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
+ Some(origAttrGroup.attributes.get.length)
+ } else {
+ origAttrGroup.numAttributes
+ }
+ require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" +
+ s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" +
+ s" ${origAttrGroup.numAttributes.get} features.")
+
+ val newField = prepOutputField(schema, map)
+ val outputFields = schema.fields :+ newField
+ StructType(outputFields)
+ }
+
+ /**
+ * Prepare the output column field, including per-feature metadata.
+ * @param schema Input schema
+ * @param map Parameter map (with this class' embedded parameter map folded in)
+ * @return Output column field
+ */
+ private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
+ val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
+ val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
+ // Convert original attributes to modified attributes
+ val origAttrs: Array[Attribute] = origAttrGroup.attributes.get
+ origAttrs.zip(partialFeatureAttributes).map {
+ case (origAttr: Attribute, featAttr: BinaryAttribute) =>
+ if (origAttr.name.nonEmpty) {
+ featAttr.withName(origAttr.name.get)
+ } else {
+ featAttr
+ }
+ case (origAttr: Attribute, featAttr: NominalAttribute) =>
+ if (origAttr.name.nonEmpty) {
+ featAttr.withName(origAttr.name.get)
+ } else {
+ featAttr
+ }
+ case (origAttr: Attribute, featAttr: NumericAttribute) =>
+ origAttr.withIndex(featAttr.index.get)
+ }
+ } else {
+ partialFeatureAttributes
+ }
+ val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
+ newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
+ }
+}
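To make the indexing and sparsity behavior described in the scaladoc above concrete, a hypothetical example (reusing the illustrative FeatureData case class, `sc`, and `sqlContext` from the sketch under the commit message):

    import org.apache.spark.ml.feature.VectorIndexer
    import org.apache.spark.mllib.linalg.Vectors

    // Feature 0 takes values {-1.0, 0.0}; feature 1 takes values {1.0, 3.0, 5.0}.
    val df = sqlContext.createDataFrame(sc.parallelize(Seq(
      FeatureData(Vectors.dense(-1.0, 1.0)),
      FeatureData(Vectors.dense(0.0, 3.0)),
      FeatureData(Vectors.dense(0.0, 5.0))), 2))

    val m = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .setMaxCategories(2)
      .fit(df)

    // With maxCategories = 2, feature 0 (2 distinct values) is categorical and
    // feature 1 (3 distinct values) is continuous. Value 0.0 always gets index 0,
    // which is what lets transform() preserve sparsity for SparseVectors.
    println(m.categoryMaps)  // expected: Map(0 -> Map(0.0 -> 0, -1.0 -> 1))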
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 17ece897a6..7d5178d0ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -198,23 +198,31 @@ trait Params extends Identifiable with Serializable {
/**
* Check whether the given schema contains an input column.
- * @param colName Parameter name for the input column.
- * @param dataType SQL DataType of the input column.
+ * @param colName Input column name
+ * @param dataType Input column DataType
*/
protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
val actualDataType = schema(colName).dataType
- require(actualDataType.equals(dataType),
- s"Input column $colName must be of type $dataType" +
- s" but was actually $actualDataType. Column param description: ${getParam(colName)}")
+ require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" +
+ s" but was actually $actualDataType. Column param description: ${getParam(colName)}")
}
+ /**
+ * Add an output column to the given schema.
+ * This fails if the given output column already exists.
+ * @param schema Initial schema (not modified)
+ * @param colName Output column name. If this column name is an empty String "", this method
+ * returns the initial schema, unchanged. This allows users to disable output
+ * columns.
+ * @param dataType Output column DataType
+ */
protected def addOutputColumn(
schema: StructType,
colName: String,
dataType: DataType): StructType = {
if (colName.length == 0) return schema
val fieldNames = schema.fieldNames
- require(!fieldNames.contains(colName), s"Prediction column $colName already exists.")
+ require(!fieldNames.contains(colName), s"Output column $colName already exists.")
val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false))
StructType(outputFields)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
new file mode 100644
index 0000000000..161100134c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+public class JavaVectorIndexerSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void vectorIndexerAPI() {
+ // The tests are to check Java compatibility.
+ List<FeatureData> points = Lists.newArrayList(
+ new FeatureData(Vectors.dense(0.0, -2.0)),
+ new FeatureData(Vectors.dense(1.0, 3.0)),
+ new FeatureData(Vectors.dense(1.0, 4.0))
+ );
+ SQLContext sqlContext = new SQLContext(sc);
+ DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ VectorIndexer indexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexed")
+ .setMaxCategories(2);
+ VectorIndexerModel model = indexer.fit(data);
+ Assert.assertEquals(model.numFeatures(), 2);
+ Assert.assertEquals(model.categoryMaps().size(), 1);
+ DataFrame indexedData = model.transform(data);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 3fb6e2ec46..0dcfe5a200 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite {
intercept[NoSuchElementException] {
group("abc")
}
- assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField()))
+ assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
+ assert(group === AttributeGroup.fromStructField(group.toStructField))
}
test("attribute group without attributes") {
@@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.numAttributes === Some(10))
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
- assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
+ assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index a18c335952..9d09f24709 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-private case class DataSet(features: Vector)
class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
@@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
)
val sqlContext = new SQLContext(sc)
- dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet))
+ dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normalized_features")
@@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
assertValues(result, l1Normalized)
}
}
+
+private object NormalizerSuite {
+ case class FeatureData(features: Vector)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
new file mode 100644
index 0000000000..61c46c85a7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.{BeanInfo, BeanProperty}
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.util.TestingUtils
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
+
+ import VectorIndexerSuite.FeatureData
+
+ @transient var sqlContext: SQLContext = _
+
+ // identical, of length 3
+ @transient var densePoints1: DataFrame = _
+ @transient var sparsePoints1: DataFrame = _
+ @transient var point1maxes: Array[Double] = _
+
+ // identical, of length 2
+ @transient var densePoints2: DataFrame = _
+ @transient var sparsePoints2: DataFrame = _
+
+ // different lengths
+ @transient var badPoints: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val densePoints1Seq = Seq(
+ Vectors.dense(1.0, 2.0, 0.0),
+ Vectors.dense(0.0, 1.0, 2.0),
+ Vectors.dense(0.0, 0.0, -1.0),
+ Vectors.dense(1.0, 3.0, 2.0))
+ val sparsePoints1Seq = Seq(
+ Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)),
+ Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)),
+ Vectors.sparse(3, Array(2), Array(-1.0)),
+ Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0)))
+ point1maxes = Array(1.0, 3.0, 2.0)
+
+ val densePoints2Seq = Seq(
+ Vectors.dense(1.0, 1.0, 0.0, 1.0),
+ Vectors.dense(0.0, 1.0, 1.0, 1.0),
+ Vectors.dense(-1.0, 1.0, 2.0, 0.0))
+ val sparsePoints2Seq = Seq(
+ Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)),
+ Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)),
+ Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0)))
+
+ val badPointsSeq = Seq(
+ Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)),
+ Vectors.sparse(3, Array(2), Array(-1.0)))
+
+ // Sanity checks for assumptions made in tests
+ assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size)
+ assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size)
+ assert(densePoints1Seq.head.size != densePoints2Seq.head.size)
+ def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = {
+ assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray },
+ "typo in unit test")
+ }
+ checkPair(densePoints1Seq, sparsePoints1Seq)
+ checkPair(densePoints2Seq, sparsePoints2Seq)
+
+ sqlContext = new SQLContext(sc)
+ densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
+ sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
+ densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
+ sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
+ badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
+ }
+
+ private def getIndexer: VectorIndexer =
+ new VectorIndexer().setInputCol("features").setOutputCol("indexed")
+
+ test("Cannot fit an empty DataFrame") {
+ val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
+ val vectorIndexer = getIndexer
+ intercept[IllegalArgumentException] {
+ vectorIndexer.fit(rdd)
+ }
+ }
+
+ test("Throws error when given RDDs with different size vectors") {
+ val vectorIndexer = getIndexer
+ val model = vectorIndexer.fit(densePoints1) // vectors of length 3
+ model.transform(densePoints1) // should work
+ model.transform(sparsePoints1) // should work
+ intercept[IllegalArgumentException] {
+ model.transform(densePoints2)
+ println("Did not throw error when fit, transform were called on vectors of different lengths")
+ }
+ intercept[SparkException] {
+ vectorIndexer.fit(badPoints)
+ println("Did not throw error when fitting vectors of different lengths in same RDD.")
+ }
+ }
+
+ test("Same result with dense and sparse vectors") {
+ def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = {
+ val denseVectorIndexer = getIndexer.setMaxCategories(2)
+ val sparseVectorIndexer = getIndexer.setMaxCategories(2)
+ val denseModel = denseVectorIndexer.fit(densePoints)
+ val sparseModel = sparseVectorIndexer.fit(sparsePoints)
+ val denseMap = denseModel.categoryMaps
+ val sparseMap = sparseModel.categoryMaps
+ assert(denseMap.keys.toSet == sparseMap.keys.toSet,
+ "Categorical features chosen from dense vs. sparse vectors did not match.")
+ assert(denseMap == sparseMap,
+ "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.")
+ }
+ testDenseSparse(densePoints1, sparsePoints1)
+ testDenseSparse(densePoints2, sparsePoints2)
+ }
+
+ test("Builds valid categorical feature value index, transform correctly, check metadata") {
+ def checkCategoryMaps(
+ data: DataFrame,
+ maxCategories: Int,
+ categoricalFeatures: Set[Int]): Unit = {
+ val collectedData = data.collect().map(_.getAs[Vector](0))
+ val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," +
+ s" categoricalFeatures=${categoricalFeatures.mkString(", ")}"
+ try {
+ val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
+ val model = vectorIndexer.fit(data)
+ val categoryMaps = model.categoryMaps
+ assert(categoryMaps.keys.toSet === categoricalFeatures) // Chose correct categorical features
+ val transformed = model.transform(data).select("indexed")
+ val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0))
+ val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed"))
+ assert(featureAttrs.name === "indexed")
+ assert(featureAttrs.attributes.get.length === model.numFeatures)
+ categoricalFeatures.foreach { feature: Int =>
+ val origValueSet = collectedData.map(_(feature)).toSet
+ val targetValueIndexSet = Range(0, origValueSet.size).toSet
+ val catMap = categoryMaps(feature)
+ assert(catMap.keys.toSet === origValueSet) // Correct categories
+ assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
+ if (origValueSet.contains(0.0)) {
+ assert(catMap(0.0) === 0) // value 0 gets index 0
+ }
+ // Check transformed data
+ assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
+ // Check metadata
+ val featureAttr = featureAttrs(feature)
+ assert(featureAttr.index.get === feature)
+ featureAttr match {
+ case attr: BinaryAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ case attr: NominalAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ assert(attr.isOrdinal.get === false)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
+ }
+ // Check numerical feature metadata.
+ Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
+ .foreach { feature: Int =>
+ val featureAttr = featureAttrs(feature)
+ featureAttr match {
+ case attr: NumericAttribute =>
+ assert(featureAttr.index.get === feature)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
+ }
+ } catch {
+ case e: org.scalatest.exceptions.TestFailedException =>
+ println(errMsg)
+ throw e
+ }
+ }
+ checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0))
+ checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2))
+ checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3))
+ }
+
+ test("Maintain sparsity for sparse vectors") {
+ def checkSparsity(data: DataFrame, maxCategories: Int): Unit = {
+ val points = data.collect().map(_.getAs[Vector](0))
+ val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
+ val model = vectorIndexer.fit(data)
+ val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect()
+ points.zip(indexedPoints).foreach {
+ case (orig: SparseVector, indexed: SparseVector) =>
+ assert(orig.indices.length == indexed.indices.length)
+ case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+ }
+ }
+ checkSparsity(sparsePoints1, maxCategories = 2)
+ checkSparsity(sparsePoints2, maxCategories = 2)
+ }
+
+ test("Preserve metadata") {
+ // For continuous features, preserve name and stats.
+ val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) =>
+ NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal)
+ }
+ val attrGroup = new AttributeGroup("features", featureAttributes)
+ val densePoints1WithMeta =
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ val vectorIndexer = getIndexer.setMaxCategories(2)
+ val model = vectorIndexer.fit(densePoints1WithMeta)
+ // Check that ML metadata are preserved.
+ val indexedPoints = model.transform(densePoints1WithMeta)
+ val transAttributes: Array[Attribute] =
+ AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get
+ featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
+ assert(orig.name === trans.name)
+ (orig, trans) match {
+ case (orig: NumericAttribute, trans: NumericAttribute) =>
+ assert(orig.max.nonEmpty && orig.max === trans.max)
+ case _ =>
+ // do nothing
+ // TODO: Once input features marked as categorical are handled correctly, check that here.
+ }
+ }
+ // Check that non-ML metadata are preserved.
+ TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
+ }
+}
+
+private[feature] object VectorIndexerSuite {
+ @BeanInfo
+ case class FeatureData(@BeanProperty features: Vector)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
new file mode 100644
index 0000000000..c44cb61b34
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Transformer
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.MetadataBuilder
+import org.scalatest.FunSuite
+
+private[ml] object TestingUtils extends FunSuite {
+
+ /**
+ * Test whether unrelated metadata are preserved for this transformer.
+ * This attaches extra metadata to a column, transforms the column, and checks to ensure the
+ * extra metadata have not changed.
+ * @param data Input dataset
+ * @param transformer Transformer to test
+ * @param inputCol Unique input column for Transformer. This must be the ONLY input column.
+ * @param outputCol Output column to test for metadata presence.
+ */
+ def testPreserveMetadata(
+ data: DataFrame,
+ transformer: Transformer,
+ inputCol: String,
+ outputCol: String): Unit = {
+ // Create some fake metadata
+ val origMetadata = data.schema(inputCol).metadata
+ val metaKey = "__testPreserveMetadata__fake_key"
+ val metaValue = 12345
+ assert(!origMetadata.contains(metaKey),
+ s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
+ val newMetadata =
+ new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
+ // Add metadata to the inputCol
+ val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
+ // Transform, and ensure extra metadata was not affected
+ val transformed = transformer.transform(withMetadata)
+ val transMetadata = transformed.schema(outputCol).metadata
+ assert(transMetadata.contains(metaKey),
+ "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
+ assert(transMetadata.getLong(metaKey) === metaValue,
+ "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
+ s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
+ }
+}