author     Joseph K. Bradley <joseph@databricks.com>  2015-04-12 22:38:27 -0700
committer  Xiangrui Meng <meng@databricks.com>        2015-04-12 22:38:27 -0700
commit     d3792f54974e16cbe8f10b3091d248e0bdd48986 (patch)
tree       89d679f7a9f76599841f169239021f190968654b /mllib/src/test
parent     fc17661475443d9f0a8d28e3439feeb7a7bca67b (diff)
[SPARK-4081] [mllib] VectorIndexer
**Ready for review!**

Since the original PR, I moved the code to the spark.ml API and renamed this to VectorIndexer.

This introduces a VectorIndexer class which does the following:
* VectorIndexer.fit(): collect statistics about how many values each feature in a dataset (RDD[Vector]) can take (limited by maxCategories).
  * Features which exceed maxCategories are declared continuous, and the Model will treat them as such.
* VectorIndexerModel.transform(): convert categorical feature values to the corresponding 0-based indices.

Design notes:
* This maintains sparsity in vectors by ensuring that categorical feature value 0.0 gets index 0.
* This does not yet support transforming data with new (unknown) categorical feature values. That can be added later.
* This is necessary for DecisionTree and tree ensembles.

Reviewers: Please check my use of metadata and my unit tests for it; I'm not sure if I covered everything in the tests.

Other notes:
* This also adds a public toMetadata method to AttributeGroup (for simpler construction of metadata).

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3000 from jkbradley/indexer and squashes the following commits:

5956d91 [Joseph K. Bradley] minor cleanups
f5c57a8 [Joseph K. Bradley] added Java test suite
643b444 [Joseph K. Bradley] removed FeatureTests
02236c3 [Joseph K. Bradley] Updated VectorIndexer, ready for PR
286d221 [Joseph K. Bradley] Reworked DatasetIndexer for spark.ml API, and renamed it to VectorIndexer
12e6cf2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into indexer
6d8f3f1 [Joseph K. Bradley] Added partly done DatasetIndexer to spark.ml
6a2f553 [Joseph K. Bradley] Updated TODO for allowUnknownCategories
3f041f8 [Joseph K. Bradley] Final cleanups for DatasetIndexer
038b9e3 [Joseph K. Bradley] DatasetIndexer now maintains sparsity in SparseVector
3a4a0bd [Joseph K. Bradley] Added another test for DatasetIndexer
2006923 [Joseph K. Bradley] DatasetIndexer now passes tests
f409987 [Joseph K. Bradley] partly done with DatasetIndexerSuite
5e7c874 [Joseph K. Bradley] working on DatasetIndexer
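For reviewers skimming the diff, here is a minimal usage sketch of the API this patch introduces, pieced together from the calls exercised in JavaVectorIndexerSuite and VectorIndexerSuite below. The local SparkContext, the sample rows, and the FeatureData case class are illustrative assumptions of the sketch, not part of the patch:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{VectorIndexer, VectorIndexerModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.SQLContext

// Illustrative case class; the column name "features" matches the test suites
// below but is otherwise an arbitrary choice for this sketch.
case class FeatureData(features: Vector)

object VectorIndexerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "VectorIndexerExample")
    val sqlContext = new SQLContext(sc)
    val data = sqlContext.createDataFrame(sc.parallelize(Seq(
      FeatureData(Vectors.dense(0.0, -2.0)),
      FeatureData(Vectors.dense(1.0, 3.0)),
      FeatureData(Vectors.dense(1.0, 4.0))), 2))

    // fit() scans the vectors and decides, per feature, whether it is
    // categorical (at most maxCategories distinct values) or continuous.
    val indexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexed")
      .setMaxCategories(2)
    val model: VectorIndexerModel = indexer.fit(data)

    // Here feature 0 has 2 distinct values -> categorical; feature 1 has 3 -> continuous.
    // categoryMaps: feature index -> (original value -> 0-based category index),
    // with value 0.0 mapped to index 0 so that sparsity is preserved.
    println(model.numFeatures)   // 2
    println(model.categoryMaps)  // e.g. Map(0 -> Map(0.0 -> 0, 1.0 -> 1))

    // transform() rewrites categorical values to their indices; continuous
    // features pass through unchanged.
    val indexed = model.transform(data)
    indexed.select("indexed").show()

    sc.stop()
  }
}
```

The Java suite below checks essentially the same flow for API compatibility: build the indexer, fit, inspect numFeatures/categoryMaps, then transform.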
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java    70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala    8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala          7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala      255
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala                60
5 files changed, 394 insertions(+), 6 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
new file mode 100644
index 0000000000..161100134c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+public class JavaVectorIndexerSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void vectorIndexerAPI() {
+ // The tests are to check Java compatibility.
+ List<FeatureData> points = Lists.newArrayList(
+ new FeatureData(Vectors.dense(0.0, -2.0)),
+ new FeatureData(Vectors.dense(1.0, 3.0)),
+ new FeatureData(Vectors.dense(1.0, 4.0))
+ );
+ SQLContext sqlContext = new SQLContext(sc);
+ DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ VectorIndexer indexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexed")
+ .setMaxCategories(2);
+ VectorIndexerModel model = indexer.fit(data);
+ Assert.assertEquals(model.numFeatures(), 2);
+ Assert.assertEquals(model.categoryMaps().size(), 1);
+ DataFrame indexedData = model.transform(data);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 3fb6e2ec46..0dcfe5a200 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite {
intercept[NoSuchElementException] {
group("abc")
}
- assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField()))
+ assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
+ assert(group === AttributeGroup.fromStructField(group.toStructField))
}
test("attribute group without attributes") {
@@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.numAttributes === Some(10))
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
- assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
+ assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index a18c335952..9d09f24709 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-private case class DataSet(features: Vector)
class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
@@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
)
val sqlContext = new SQLContext(sc)
- dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet))
+ dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normalized_features")
@@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
assertValues(result, l1Normalized)
}
}
+
+private object NormalizerSuite {
+ case class FeatureData(features: Vector)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
new file mode 100644
index 0000000000..61c46c85a7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.{BeanInfo, BeanProperty}
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.util.TestingUtils
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
+
+ import VectorIndexerSuite.FeatureData
+
+ @transient var sqlContext: SQLContext = _
+
+ // identical, of length 3
+ @transient var densePoints1: DataFrame = _
+ @transient var sparsePoints1: DataFrame = _
+ @transient var point1maxes: Array[Double] = _
+
+ // identical, of length 2
+ @transient var densePoints2: DataFrame = _
+ @transient var sparsePoints2: DataFrame = _
+
+ // different lengths
+ @transient var badPoints: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val densePoints1Seq = Seq(
+ Vectors.dense(1.0, 2.0, 0.0),
+ Vectors.dense(0.0, 1.0, 2.0),
+ Vectors.dense(0.0, 0.0, -1.0),
+ Vectors.dense(1.0, 3.0, 2.0))
+ val sparsePoints1Seq = Seq(
+ Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)),
+ Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)),
+ Vectors.sparse(3, Array(2), Array(-1.0)),
+ Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0)))
+ point1maxes = Array(1.0, 3.0, 2.0)
+
+ val densePoints2Seq = Seq(
+ Vectors.dense(1.0, 1.0, 0.0, 1.0),
+ Vectors.dense(0.0, 1.0, 1.0, 1.0),
+ Vectors.dense(-1.0, 1.0, 2.0, 0.0))
+ val sparsePoints2Seq = Seq(
+ Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)),
+ Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)),
+ Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0)))
+
+ val badPointsSeq = Seq(
+ Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)),
+ Vectors.sparse(3, Array(2), Array(-1.0)))
+
+ // Sanity checks for assumptions made in tests
+ assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size)
+ assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size)
+ assert(densePoints1Seq.head.size != densePoints2Seq.head.size)
+ def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = {
+ assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray },
+ "typo in unit test")
+ }
+ checkPair(densePoints1Seq, sparsePoints1Seq)
+ checkPair(densePoints2Seq, sparsePoints2Seq)
+
+ sqlContext = new SQLContext(sc)
+ densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
+ sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
+ densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
+ sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
+ badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
+ }
+
+ private def getIndexer: VectorIndexer =
+ new VectorIndexer().setInputCol("features").setOutputCol("indexed")
+
+ test("Cannot fit an empty DataFrame") {
+ val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
+ val vectorIndexer = getIndexer
+ intercept[IllegalArgumentException] {
+ vectorIndexer.fit(rdd)
+ }
+ }
+
+ test("Throws error when given RDDs with different size vectors") {
+ val vectorIndexer = getIndexer
+ val model = vectorIndexer.fit(densePoints1) // vectors of length 3
+ model.transform(densePoints1) // should work
+ model.transform(sparsePoints1) // should work
+ intercept[IllegalArgumentException] {
+ model.transform(densePoints2)
+ println("Did not throw error when fit, transform were called on vectors of different lengths")
+ }
+ intercept[SparkException] {
+ vectorIndexer.fit(badPoints)
+ println("Did not throw error when fitting vectors of different lengths in same RDD.")
+ }
+ }
+
+ test("Same result with dense and sparse vectors") {
+ def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = {
+ val denseVectorIndexer = getIndexer.setMaxCategories(2)
+ val sparseVectorIndexer = getIndexer.setMaxCategories(2)
+ val denseModel = denseVectorIndexer.fit(densePoints)
+ val sparseModel = sparseVectorIndexer.fit(sparsePoints)
+ val denseMap = denseModel.categoryMaps
+ val sparseMap = sparseModel.categoryMaps
+ assert(denseMap.keys.toSet == sparseMap.keys.toSet,
+ "Categorical features chosen from dense vs. sparse vectors did not match.")
+ assert(denseMap == sparseMap,
+ "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.")
+ }
+ testDenseSparse(densePoints1, sparsePoints1)
+ testDenseSparse(densePoints2, sparsePoints2)
+ }
+
+ test("Builds valid categorical feature value index, transform correctly, check metadata") {
+ def checkCategoryMaps(
+ data: DataFrame,
+ maxCategories: Int,
+ categoricalFeatures: Set[Int]): Unit = {
+ val collectedData = data.collect().map(_.getAs[Vector](0))
+ val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," +
+ s" categoricalFeatures=${categoricalFeatures.mkString(", ")}"
+ try {
+ val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
+ val model = vectorIndexer.fit(data)
+ val categoryMaps = model.categoryMaps
+ assert(categoryMaps.keys.toSet === categoricalFeatures) // Chose correct categorical features
+ val transformed = model.transform(data).select("indexed")
+ val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0))
+ val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed"))
+ assert(featureAttrs.name === "indexed")
+ assert(featureAttrs.attributes.get.length === model.numFeatures)
+ categoricalFeatures.foreach { feature: Int =>
+ val origValueSet = collectedData.map(_(feature)).toSet
+ val targetValueIndexSet = Range(0, origValueSet.size).toSet
+ val catMap = categoryMaps(feature)
+ assert(catMap.keys.toSet === origValueSet) // Correct categories
+ assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
+ if (origValueSet.contains(0.0)) {
+ assert(catMap(0.0) === 0) // value 0 gets index 0
+ }
+ // Check transformed data
+ assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
+ // Check metadata
+ val featureAttr = featureAttrs(feature)
+ assert(featureAttr.index.get === feature)
+ featureAttr match {
+ case attr: BinaryAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ case attr: NominalAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ assert(attr.isOrdinal.get === false)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
+ }
+ // Check numerical feature metadata.
+ Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
+ .foreach { feature: Int =>
+ val featureAttr = featureAttrs(feature)
+ featureAttr match {
+ case attr: NumericAttribute =>
+ assert(featureAttr.index.get === feature)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
+ }
+ } catch {
+ case e: org.scalatest.exceptions.TestFailedException =>
+ println(errMsg)
+ throw e
+ }
+ }
+ checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0))
+ checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2))
+ checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3))
+ }
+
+ test("Maintain sparsity for sparse vectors") {
+ def checkSparsity(data: DataFrame, maxCategories: Int): Unit = {
+ val points = data.collect().map(_.getAs[Vector](0))
+ val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
+ val model = vectorIndexer.fit(data)
+ val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect()
+ points.zip(indexedPoints).foreach {
+ case (orig: SparseVector, indexed: SparseVector) =>
+ assert(orig.indices.length == indexed.indices.length)
+ case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+ }
+ }
+ checkSparsity(sparsePoints1, maxCategories = 2)
+ checkSparsity(sparsePoints2, maxCategories = 2)
+ }
+
+ test("Preserve metadata") {
+ // For continuous features, preserve name and stats.
+ val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) =>
+ NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal)
+ }
+ val attrGroup = new AttributeGroup("features", featureAttributes)
+ val densePoints1WithMeta =
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ val vectorIndexer = getIndexer.setMaxCategories(2)
+ val model = vectorIndexer.fit(densePoints1WithMeta)
+ // Check that ML metadata are preserved.
+ val indexedPoints = model.transform(densePoints1WithMeta)
+ val transAttributes: Array[Attribute] =
+ AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get
+ featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
+ assert(orig.name === trans.name)
+ (orig, trans) match {
+ case (orig: NumericAttribute, trans: NumericAttribute) =>
+ assert(orig.max.nonEmpty && orig.max === trans.max)
+ case _ =>
+ // do nothing
+ // TODO: Once input features marked as categorical are handled correctly, check that here.
+ }
+ }
+ // Check that non-ML metadata are preserved.
+ TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
+ }
+}
+
+private[feature] object VectorIndexerSuite {
+ @BeanInfo
+ case class FeatureData(@BeanProperty features: Vector)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
new file mode 100644
index 0000000000..c44cb61b34
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Transformer
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.MetadataBuilder
+import org.scalatest.FunSuite
+
+private[ml] object TestingUtils extends FunSuite {
+
+ /**
+ * Test whether unrelated metadata are preserved for this transformer.
+ * This attaches extra metadata to a column, transforms the column, and checks to ensure the
+ * extra metadata has not changed.
+ * @param data Input dataset
+ * @param transformer Transformer to test
+ * @param inputCol Unique input column for Transformer. This must be the ONLY input column.
+ * @param outputCol Output column to test for metadata presence.
+ */
+ def testPreserveMetadata(
+ data: DataFrame,
+ transformer: Transformer,
+ inputCol: String,
+ outputCol: String): Unit = {
+ // Create some fake metadata
+ val origMetadata = data.schema(inputCol).metadata
+ val metaKey = "__testPreserveMetadata__fake_key"
+ val metaValue = 12345
+ assert(!origMetadata.contains(metaKey),
+ s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
+ val newMetadata =
+ new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
+ // Add metadata to the inputCol
+ val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
+ // Transform, and ensure extra metadata was not affected
+ val transformed = transformer.transform(withMetadata)
+ val transMetadata = transformed.schema(outputCol).metadata
+ assert(transMetadata.contains(metaKey),
+ "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
+ assert(transMetadata.getLong(metaKey) === metaValue,
+ "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
+ s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
+ }
+}