author    Joseph K. Bradley <joseph@databricks.com>  2015-04-29 16:35:17 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2015-04-29 16:35:17 -0700
commit    b1ef6a60ff6ea2adb43c6544e5311c11f4364f64 (patch)
tree      2e20c40e2ae65b6c657f8361f74eab881c011253 /mllib
parent    f8cbb0a4b37b0d4ba49515d888cb52dea9eb01f1 (diff)
[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column
Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use.
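For context, a minimal sketch of the behavior this patch establishes (Spark 1.x-era API; the column names and the custom metadata key are illustrative, not from the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.MetadataBuilder

object NonMLMetadataSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(
      Tuple1(Vectors.dense(0.0, 1.0, 2.0)),
      Tuple1(Vectors.dense(1.0, 0.0, 2.0)))).toDF("features")

    // Attach unrelated (non-ML) metadata to the input column.
    val meta = new MetadataBuilder().putLong("myCustomKey", 12345L).build()
    val withMeta = df.select(df("features").as("features", meta))

    val model = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .fit(withMeta)
    val out = model.transform(withMeta)

    // After this patch, the output column carries only ML attribute
    // metadata; the custom key is no longer copied over.
    assert(!out.schema("indexed").metadata.contains("myCustomKey"))
    sc.stop()
  }
}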
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala     | 69
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala |  7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala          | 60
3 files changed, 37 insertions(+), 99 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 452faa06e2..1e5ffd15af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -233,6 +233,7 @@ private object VectorIndexer {
* - Continuous features (columns) are left unchanged.
* This also appends metadata to the output column, marking features as Numeric (continuous),
* Nominal (categorical), or Binary (either continuous or categorical).
+ * Non-ML metadata is not carried over from the input to the output column.
*
* This maintains vector sparsity.
*
@@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
// TODO: Check more carefully about whether this whole class will be included in a closure.
+ /** Per-vector transform function */
private val transformFunc: Vector => Vector = {
- val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
+ val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
val localVectorMap = categoryMaps
- val f: Vector => Vector = {
- case dv: DenseVector =>
- val tmpv = dv.copy
- localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
- tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
- }
- tmpv
- case sv: SparseVector =>
- // We use the fact that categorical value 0 is always mapped to index 0.
- val tmpv = sv.copy
- var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
- var k = 0 // index into non-zero elements of sparse vector
- while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
- val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
- if (featureIndex < tmpv.indices(k)) {
- catFeatureIdx += 1
- } else if (featureIndex > tmpv.indices(k)) {
- k += 1
- } else {
- tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
- catFeatureIdx += 1
- k += 1
+ val localNumFeatures = numFeatures
+ val f: Vector => Vector = { (v: Vector) =>
+ assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
+ s" $localNumFeatures but found length ${v.size}")
+ v match {
+ case dv: DenseVector =>
+ val tmpv = dv.copy
+ localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
+ tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
}
- }
- tmpv
+ tmpv
+ case sv: SparseVector =>
+ // We use the fact that categorical value 0 is always mapped to index 0.
+ val tmpv = sv.copy
+ var catFeatureIdx = 0 // index into sortedCatFeatureIndices
+ var k = 0 // index into non-zero elements of sparse vector
+ while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
+ val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
+ if (featureIndex < tmpv.indices(k)) {
+ catFeatureIdx += 1
+ } else if (featureIndex > tmpv.indices(k)) {
+ k += 1
+ } else {
+ tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+ catFeatureIdx += 1
+ k += 1
+ }
+ }
+ tmpv
+ }
}
f
}
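The rewritten sparse branch is a two-pointer walk over two sorted index arrays; as the comment notes, it relies on categorical value 0.0 always mapping to index 0, so zero entries never need rewriting. A standalone sketch of the same walk on plain arrays (names are illustrative, nothing here is from the patch):

// Intersect two sorted index arrays, remapping a value only where a
// non-zero entry falls on a categorical feature.
def remapSparse(
    indices: Array[Int],      // sorted positions of the vector's non-zeros
    values: Array[Double],    // corresponding values (remapped in place)
    catIndices: Array[Int],   // sorted positions of categorical features
    maps: Map[Int, Map[Double, Int]]): Unit = {
  var c = 0 // pointer into catIndices
  var k = 0 // pointer into indices
  while (c < catIndices.length && k < indices.length) {
    if (catIndices(c) < indices(k)) {
      c += 1 // this categorical feature is zero here; 0.0 maps to index 0
    } else if (catIndices(c) > indices(k)) {
      k += 1 // this non-zero entry belongs to a continuous feature
    } else {
      values(k) = maps(catIndices(c))(values(k)).toDouble
      c += 1
      k += 1
    }
  }
}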
@@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
val map = extractParamMap(paramMap)
val newField = prepOutputField(dataset.schema, map)
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
- // For now, just check the first row of inputCol for vector length.
- val firstRow = dataset.select(map(inputCol)).take(1)
- if (firstRow.length != 0) {
- val actualNumFeatures = firstRow(0).getAs[Vector](0).size
- require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
- s" $numFeatures but found length $actualNumFeatures")
- }
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
}
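Removing the first-row check trades eager validation for complete validation: take(1) inspected only one row (and ran an extra job per transform call), while the per-row assert inside the UDF covers every row. Because DataFrame transformations are lazy, a length mismatch now surfaces only when an action executes the UDF, wrapped in a SparkException; hence the test change further down. A fragment mirroring that updated test (assuming a model fitted on length-3 vectors and densePoints2 holding vectors of a different length):

intercept[SparkException] {
  model.transform(densePoints2).collect() // assert fires inside the Spark job
}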
@@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
s"VectorIndexerModel requires output column parameter: $outputCol")
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
+ // If the input metadata specifies numFeatures, compare with expected numFeatures.
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
@@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
* Prepare the output column field, including per-feature metadata.
* @param schema Input schema
* @param map Parameter map (with this class' embedded parameter map folded in)
- * @return Output column field
+ * @return Output column field. This field does not contain non-ML metadata.
*/
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
@@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
partialFeatureAttributes
}
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
- newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
+ newAttributeGroup.toStructField()
}
}
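The last hunk is the heart of the fix: toStructField(schema(map(inputCol)).metadata) merged the new ML attributes into the input column's existing metadata, so unrelated keys rode along into the output column, while the zero-argument overload builds the field from the attribute group alone. A minimal sketch of the difference, runnable in a Spark 1.x spark-shell (the "provenance" key is illustrative):

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.sql.types.MetadataBuilder

// Input column metadata carrying an unrelated (non-ML) key.
val inputMeta = new MetadataBuilder().putString("provenance", "sensor-42").build()

val attrs: Array[Attribute] = Array(NumericAttribute.defaultAttr.withIndex(0))
val group = new AttributeGroup("indexed", attrs)

// Old behavior: merge into the existing metadata -- "provenance" survives.
assert(group.toStructField(inputMeta).metadata.contains("provenance"))

// New behavior: build from the attribute group alone -- "provenance" is dropped.
assert(!group.toStructField().metadata.contains("provenance"))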
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 1b261b2643..38dc83b124 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -23,7 +23,6 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.util.TestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
- intercept[IllegalArgumentException] {
- model.transform(densePoints2)
+ intercept[SparkException] {
+ model.transform(densePoints2).collect()
println("Did not throw error when fit, transform were called on vectors of different lengths")
}
intercept[SparkException] {
@@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
// TODO: Once input features marked as categorical are handled correctly, check that here.
}
}
- // Check that non-ML metadata are preserved.
- TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
deleted file mode 100644
index c44cb61b34..0000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.util
-
-import org.apache.spark.ml.Transformer
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.MetadataBuilder
-import org.scalatest.FunSuite
-
-private[ml] object TestingUtils extends FunSuite {
-
- /**
- * Test whether unrelated metadata are preserved for this transformer.
- * This attaches extra metadata to a column, transforms the column, and checks to ensure the
- * extra metadata have not changed.
- * @param data Input dataset
- * @param transformer Transformer to test
- * @param inputCol Unique input column for Transformer. This must be the ONLY input column.
- * @param outputCol Output column to test for metadata presence.
- */
- def testPreserveMetadata(
- data: DataFrame,
- transformer: Transformer,
- inputCol: String,
- outputCol: String): Unit = {
- // Create some fake metadata
- val origMetadata = data.schema(inputCol).metadata
- val metaKey = "__testPreserveMetadata__fake_key"
- val metaValue = 12345
- assert(!origMetadata.contains(metaKey),
- s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
- val newMetadata =
- new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
- // Add metadata to the inputCol
- val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
- // Transform, and ensure extra metadata was not affected
- val transformed = transformer.transform(withMetadata)
- val transMetadata = transformed.schema(outputCol).metadata
- assert(transMetadata.contains(metaKey),
- "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
- assert(transMetadata.getLong(metaKey) === metaValue,
- "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
- s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
- }
-}
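If a project still wants regression coverage for the new contract, a hypothetical inverse of the deleted helper could assert that unrelated metadata is dropped rather than preserved (a sketch, not part of this patch):

import org.apache.spark.ml.Transformer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.MetadataBuilder

/** Check that non-ML metadata on inputCol is NOT copied to outputCol. */
def testDropsNonMLMetadata(
    data: DataFrame,
    transformer: Transformer,
    inputCol: String,
    outputCol: String): Unit = {
  val metaKey = "__testDropsNonMLMetadata__fake_key"
  val newMetadata = new MetadataBuilder()
    .withMetadata(data.schema(inputCol).metadata)
    .putLong(metaKey, 12345L)
    .build()
  val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
  val transformed = transformer.transform(withMetadata)
  assert(!transformed.schema(outputCol).metadata.contains(metaKey),
    "Non-ML metadata should not be carried over to the output column.")
}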