aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-05-01 08:31:01 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-01 08:31:01 -0700
commit7630213cab1f653212828f045cf1d7d1870abea0 (patch)
tree27579517fc70dc243d54df060c29c9d64b9bb76a /mllib
parent3b514af8a0c2ca496315b99a2b09bc887ac6c5e1 (diff)
downloadspark-7630213cab1f653212828f045cf1d7d1870abea0.tar.gz
spark-7630213cab1f653212828f045cf1d7d1870abea0.tar.bz2
spark-7630213cab1f653212828f045cf1d7d1870abea0.zip
[SPARK-5891] [ML] Add Binarizer ML Transformer
JIRA: https://issues.apache.org/jira/browse/SPARK-5891 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #5699 from viirya/add_binarizer and squashes the following commits: 1a0b9a4 [Liang-Chi Hsieh] For comments. bc397f2 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer cc4f03c [Liang-Chi Hsieh] Implement threshold param and use merged params map. 7564c63 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_binarizer 1682f8c [Liang-Chi Hsieh] Add Binarizer ML Transformer.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala85
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala69
2 files changed, 154 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
new file mode 100644
index 0000000000..f3ce6dfca2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.BinaryAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * Binarize a column of continuous features given a threshold.
+ */
+@AlphaComponent
+final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
+
+ /**
+ * Param for threshold used to binarize continuous features.
+ * The features greater than the threshold, will be binarized to 1.0.
+ * The features equal to or less than the threshold, will be binarized to 0.0.
+ * @group param
+ */
+ val threshold: DoubleParam =
+ new DoubleParam(this, "threshold", "threshold used to binarize continuous features")
+
+ /** @group getParam */
+ def getThreshold: Double = getOrDefault(threshold)
+
+ /** @group setParam */
+ def setThreshold(value: Double): this.type = set(threshold, value)
+
+ setDefault(threshold -> 0.0)
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val td = map(threshold)
+ val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 }
+ val outputColName = map(outputCol)
+ val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata()
+ dataset.select(col("*"),
+ binarizer(col(map(inputCol))).as(outputColName, metadata))
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
+
+ val inputFields = schema.fields
+ val outputColName = map(outputCol)
+
+ require(inputFields.forall(_.name != outputColName),
+ s"Output column $outputColName already exists.")
+
+ val attr = BinaryAttribute.defaultAttr.withName(outputColName)
+ val outputFields = inputFields :+ attr.toStructField()
+ StructType(outputFields)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
new file mode 100644
index 0000000000..caf1b75959
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var data: Array[Double] = _
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
+ }
+
+ test("Binarize continuous features with default parameter") {
+ val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(
+ data.zip(defaultBinarized)).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Double, y: Double) =>
+ assert(x === y, "The feature value is not correct after binarization.")
+ }
+ }
+
+ test("Binarize continuous features with setter") {
+ val threshold: Double = 0.2
+ val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val dataFrame: DataFrame = sqlContext.createDataFrame(
+ data.zip(thresholdBinarized)).toDF("feature", "expected")
+
+ val binarizer: Binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(threshold)
+
+ binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Double, y: Double) =>
+ assert(x === y, "The feature value is not correct after binarization.")
+ }
+ }
+}