From cbeb006f23838b2f19e700e20b25003aeb3dfb01 Mon Sep 17 00:00:00 2001 From: seddonm1 Date: Mon, 15 Feb 2016 20:15:27 -0800 Subject: [SPARK-13097][ML] Binarizer allowing Double AND Vector input types This enhancement extends the existing SparkML Binarizer [SPARK-5891] to allow Vector in addition to the existing Double input column type. A use case for this enhancement is when a user wants to binarize many similar feature columns at once using the same threshold value (for example, a binary threshold applied to many pixels in an image). This contribution is my original work and I license the work to the project under the project's open source license. viirya mengxr Author: seddonm1 Closes #10976 from seddonm1/master. --- .../org/apache/spark/ml/feature/Binarizer.scala | 62 ++++++++++++++++------ .../apache/spark/ml/feature/BinarizerSuite.scala | 36 +++++++++++++ 2 files changed, 81 insertions(+), 17 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 544cf05a30..2f8e3a0371 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.feature +import scala.collection.mutable.ArrayBuilder + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types._ /** * :: Experimental :: @@ -62,28 +65,53 @@ final class Binarizer(override val uid: String) def setOutputCol(value: String): 
this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { - transformSchema(dataset.schema, logging = true) + val outputSchema = transformSchema(dataset.schema, logging = true) + val schema = dataset.schema + val inputType = schema($(inputCol)).dataType val td = $(threshold) - val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } - val outputColName = $(outputCol) - val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() - dataset.select(col("*"), - binarizer(col($(inputCol))).as(outputColName, metadata)) + + val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } + val binarizerVector = udf { (data: Vector) => + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + + data.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } + } + + Vectors.sparse(data.size, indices.result(), values.result()).compressed + } + + val metadata = outputSchema($(outputCol)).metadata + + inputType match { + case DoubleType => + dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) + case _: VectorUDT => + dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) + } } override def transformSchema(schema: StructType): StructType = { - validateParams() - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - - val inputFields = schema.fields + val inputType = schema($(inputCol)).dataType val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), - s"Output column $outputColName already exists.") - - val attr = BinaryAttribute.defaultAttr.withName(outputColName) - val outputFields = inputFields :+ attr.toStructField() - StructType(outputFields) + val outCol: StructField = inputType match { + case DoubleType => + BinaryAttribute.defaultAttr.withName(outputColName).toStructField() + case _: VectorUDT => + new StructField(outputColName, new VectorUDT, 
true) + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ outCol) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 6d2d8fe714..714b9db3aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -68,6 +69,41 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } + test("Binarize vector of continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("Binarize vector of continuous features with setter") { + val threshold: Double = 0.2 + val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = 
sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("read/write") { val t = new Binarizer() .setInputCol("myInputCol") -- cgit v1.2.3