aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorseddonm1 <seddonm1@gmail.com>2016-02-15 20:15:27 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-15 20:15:27 -0800
commitcbeb006f23838b2f19e700e20b25003aeb3dfb01 (patch)
tree1977fd104bca7031f3fc86b5e2a14928c4ca0cde /mllib
parentadb548365012552e991d51740bfd3c25abf0adec (diff)
downloadspark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.gz
spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.tar.bz2
spark-cbeb006f23838b2f19e700e20b25003aeb3dfb01.zip
[SPARK-13097][ML] Binarizer allowing Double AND Vector input types
This enhancement extends the existing SparkML Binarizer [SPARK-5891] to allow Vector in addition to the existing Double input column type. A use case for this enhancement is for when a user wants to Binarize many similar feature columns at once using the same threshold value (for example a binary threshold applied to many pixels in an image). This contribution is my original work and I license the work to the project under the project's open source license. viirya mengxr Author: seddonm1 <seddonm1@gmail.com> Closes #10976 from seddonm1/master.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala62
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala36
2 files changed, 81 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 544cf05a30..2f8e3a0371 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,15 +17,18 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable.ArrayBuilder
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types._
/**
* :: Experimental ::
@@ -62,28 +65,53 @@ final class Binarizer(override val uid: String)
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+ val schema = dataset.schema
+ val inputType = schema($(inputCol)).dataType
val td = $(threshold)
- val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
- val outputColName = $(outputCol)
- val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
- dataset.select(col("*"),
- binarizer(col($(inputCol))).as(outputColName, metadata))
+
+ val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
+ val binarizerVector = udf { (data: Vector) =>
+ val indices = ArrayBuilder.make[Int]
+ val values = ArrayBuilder.make[Double]
+
+ data.foreachActive { (index, value) =>
+ if (value > td) {
+ indices += index
+ values += 1.0
+ }
+ }
+
+ Vectors.sparse(data.size, indices.result(), values.result()).compressed
+ }
+
+ val metadata = outputSchema($(outputCol)).metadata
+
+ inputType match {
+ case DoubleType =>
+ dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
+ case _: VectorUDT =>
+ dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
+ }
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
- SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
- val inputFields = schema.fields
+ val inputType = schema($(inputCol)).dataType
val outputColName = $(outputCol)
- require(inputFields.forall(_.name != outputColName),
- s"Output column $outputColName already exists.")
-
- val attr = BinaryAttribute.defaultAttr.withName(outputColName)
- val outputFields = inputFields :+ attr.toStructField()
- StructType(outputFields)
+ val outCol: StructField = inputType match {
+ case DoubleType =>
+ BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
+ case _: VectorUDT =>
+ new StructField(outputColName, new VectorUDT, true)
+ case other =>
+ throw new IllegalArgumentException(s"Data type $other is not supported.")
+ }
+
+ if (schema.fieldNames.contains(outputColName)) {
+ throw new IllegalArgumentException(s"Output column $outputColName already exists.")
+ }
+ StructType(schema.fields :+ outCol)
}
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 6d2d8fe714..714b9db3aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
}
+ test("Binarize vector of continuous features with default parameter") {
+ val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(data), Vectors.dense(defaultBinarized))
+ )).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x == y, "The feature value is not correct after binarization.")
+ }
+ }
+
+ test("Binarize vector of continuous features with setter") {
+ val threshold: Double = 0.2
+ val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(data), Vectors.dense(defaultBinarized))
+ )).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(threshold)
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x == y, "The feature value is not correct after binarization.")
+ }
+ }
+
+
test("read/write") {
val t = new Binarizer()
.setInputCol("myInputCol")