aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
authorseddonm1 <seddonm1@gmail.com>2016-02-15 20:15:27 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-15 20:15:27 -0800
commitcbeb006f23838b2f19e700e20b25003aeb3dfb01 (patch)
tree1977fd104bca7031f3fc86b5e2a14928c4ca0cde /mllib/src/test/scala
parentadb548365012552e991d51740bfd3c25abf0adec (diff)
downloadspark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.gz
spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.bz2
spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.zip
[SPARK-13097][ML] Binarizer allowing Double AND Vector input types
This enhancement extends the existing SparkML Binarizer [SPARK-5891] to allow Vector in addition to the existing Double input column type. A use case for this enhancement is for when a user wants to Binarize many similar feature columns at once using the same threshold value (for example a binary threshold applied to many pixels in an image). This contribution is my original work and I license the work to the project under the project's open source license. viirya mengxr Author: seddonm1 <seddonm1@gmail.com> Closes #10976 from seddonm1/master.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala36
1 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 6d2d8fe714..714b9db3aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
}
+ test("Binarize vector of continuous features with default parameter") {
+ val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(data), Vectors.dense(defaultBinarized))
+ )).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x == y, "The feature value is not correct after binarization.")
+ }
+ }
+
+ test("Binarize vector of continuous features with setter") {
+ val threshold: Double = 0.2
+ val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(data), Vectors.dense(defaultBinarized))
+ )).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(threshold)
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x == y, "The feature value is not correct after binarization.")
+ }
+ }
+
+
test("read/write") {
val t = new Binarizer()
.setInputCol("myInputCol")