diff options
author | seddonm1 <seddonm1@gmail.com> | 2016-02-15 20:15:27 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-15 20:15:27 -0800 |
commit | cbeb006f23838b2f19e700e20b25003aeb3dfb01 (patch) | |
tree | 1977fd104bca7031f3fc86b5e2a14928c4ca0cde /mllib/src/test | |
parent | adb548365012552e991d51740bfd3c25abf0adec (diff) | |
download | spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.gz spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.bz2 spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.zip |
[SPARK-13097][ML] Binarizer allowing Double AND Vector input types
This enhancement extends the existing SparkML Binarizer [SPARK-5891] to allow Vector in addition to the existing Double input column type.
A use case for this enhancement is for when a user wants to Binarize many similar feature columns at once using the same threshold value (for example a binary threshold applied to many pixels in an image).
This contribution is my original work and I license the work to the project under the project's open source license.
viirya mengxr
Author: seddonm1 <seddonm1@gmail.com>
Closes #10976 from seddonm1/master.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 6d2d8fe714..714b9db3aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } + test("Binarize vector of continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("Binarize vector of continuous features with setter") { + val threshold: Double = 0.2 + val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("read/write") { val t = new Binarizer() .setInputCol("myInputCol") |