aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2015-10-02 10:25:58 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-10-02 10:25:58 -0700
commit 633aaae0a1e31e9ba634423840e350b22342c6b5 (patch)
tree 923bd9cea84eb987f187d98153d93196418de311 /mllib
parent 23a9448c04da7130d6c41c37f9fdf03184422dc8 (diff)
download spark-633aaae0a1e31e9ba634423840e350b22342c6b5.tar.gz
spark-633aaae0a1e31e9ba634423840e350b22342c6b5.tar.bz2
spark-633aaae0a1e31e9ba634423840e350b22342c6b5.zip
[SPARK-6530] [ML] Add chi-square selector for ml package
See JIRA [here](https://issues.apache.org/jira/browse/SPARK-6530). Author: Xusen Yin <yinxusen@gmail.com> Closes #5742 from yinxusen/SPARK-6530.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala 150
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala 61
3 files changed, 213 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
new file mode 100644
index 0000000000..5e4061fba5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute.{AttributeGroup, _}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+/**
+ * Parameters shared by the [[ChiSqSelector]] estimator and the [[ChiSqSelectorModel]] it fits.
+ */
+private[feature] trait ChiSqSelectorParams extends Params
+  with HasFeaturesCol with HasOutputCol with HasLabelCol {
+
+  /**
+   * How many features the selector keeps, ranked by descending statistic value. When the input
+   * has fewer than numTopFeatures features, all of them are selected. Defaults to 50 and must
+   * be at least 1.
+   * @group param
+   */
+  final val numTopFeatures: IntParam = new IntParam(
+    this,
+    "numTopFeatures",
+    "Number of features that selector will select, ordered by statistics value descending. If the" +
+      " number of features is < numTopFeatures, then this will select all features.",
+    ParamValidators.gtEq(1))
+
+  setDefault(numTopFeatures, 50)
+
+  /** @group getParam */
+  def getNumTopFeatures: Int = getOrDefault(numTopFeatures)
+}
+
+/**
+ * :: Experimental ::
+ * Estimator for Chi-Squared feature selection: chooses the categorical features to use when
+ * predicting a categorical label.
+ */
+@Experimental
+final class ChiSqSelector(override val uid: String)
+  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  /** Creates a selector with a randomly generated uid. */
+  def this() = this(Identifiable.randomUID("chiSqSelector"))
+
+  /** @group setParam */
+  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Fits a [[ChiSqSelectorModel]] on the label and features columns of `dataset` by delegating
+   * to the `mllib` ChiSqSelector.
+   */
+  override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+    // Called only for its schema validation; the returned schema is not needed here.
+    transformSchema(dataset.schema, logging = true)
+    val labeledPoints = dataset.select(getLabelCol, getFeaturesCol).map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+    }
+    val fitted = new feature.ChiSqSelector(getNumTopFeatures).fit(labeledPoints)
+    val model = new ChiSqSelectorModel(uid, fitted).setParent(this)
+    copyValues(model)
+  }
+
+  /**
+   * Checks the features column against VectorUDT and the label column against DoubleType, then
+   * appends the output column (typed VectorUDT) to `schema`.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, getFeaturesCol, new VectorUDT)
+    SchemaUtils.checkColumnType(schema, getLabelCol, DoubleType)
+    SchemaUtils.appendColumn(schema, getOutputCol, new VectorUDT)
+  }
+
+  override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by fitting a [[ChiSqSelector]]; wraps the underlying `mllib` selector model.
+ */
+@Experimental
+final class ChiSqSelectorModel private[ml] (
+    override val uid: String,
+    private val chiSqSelector: feature.ChiSqSelectorModel)
+  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Adds the output column to `dataset` by applying the wrapped selector to the features column.
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    // transformSchema puts the output field last; its metadata is attached to the new column.
+    val outputField = transformSchema(dataset.schema, logging = true).last
+    val selectFeatures = udf { chiSqSelector.transform _ }
+    val selectedColumn = selectFeatures(col(getFeaturesCol))
+    dataset.withColumn(getOutputCol, selectedColumn, outputField.metadata)
+  }
+
+  /**
+   * Checks the features column against VectorUDT and returns `schema` with the output field
+   * appended as its last field.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, getFeaturesCol, new VectorUDT)
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+
+  /**
+   * Builds the output column field together with its per-feature metadata: the attributes of
+   * the selected input features when the features column declares attributes, otherwise one
+   * default nominal attribute per selected feature.
+   */
+  private def prepOutputField(schema: StructType): StructField = {
+    val kept = chiSqSelector.selectedFeatures.toSet
+    val inputAttributes = AttributeGroup.fromStructField(schema(getFeaturesCol)).attributes
+    val outputAttributes: Array[Attribute] = inputAttributes match {
+      case Some(attrs) =>
+        attrs.zipWithIndex.collect { case (attr, idx) if kept.contains(idx) => attr }
+      case None =>
+        Array.fill[Attribute](kept.size)(NominalAttribute.defaultAttr)
+    }
+    new AttributeGroup(getOutputCol, outputAttributes).toStructField()
+  }
+
+  override def copy(extra: ParamMap): ChiSqSelectorModel =
+    copyValues(new ChiSqSelectorModel(uid, chiSqSelector), extra).setParent(parent)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 4743cfd1a2..b1524cf377 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -109,6 +109,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Creates a ChiSquared feature selector.
* @param numTopFeatures number of features that selector will select
* (ordered by statistic value descending)
+ * Note that if the number of features is < numTopFeatures, then this will
+ * select all features.
*/
@Since("1.3.0")
@Experimental
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
new file mode 100644
index 0000000000..e5a42967bd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * Checks that the `ml` ChiSqSelector, fitted on a small labeled data set, outputs the expected
+ * single-feature vectors.
+ */
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("Test Chi-Square selector") {
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+
+    val data = Seq(
+      LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
+      LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
+      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
+      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
+    )
+
+    // Expected output per input row: the value at feature index 2 of the matching `data` row.
+    val preFilteredData = Seq(
+      Vectors.dense(0.0),
+      Vectors.dense(6.0),
+      Vectors.dense(8.0),
+      Vectors.dense(5.0)
+    )
+
+    val df = sc.parallelize(data.zip(preFilteredData))
+      .map { case (point, expected) => (point.label, point.features, expected) }
+      .toDF("label", "data", "preFilteredData")
+
+    // This is the estimator; the fitted model is only created by `fit` below.
+    val selector = new ChiSqSelector()
+      .setNumTopFeatures(1)
+      .setFeaturesCol("data")
+      .setLabelCol("label")
+      .setOutputCol("filtered")
+
+    val rows = selector.fit(df).transform(df).select("filtered", "preFilteredData").collect()
+
+    // Guard against a vacuous pass: the per-row check below never runs on an empty result.
+    assert(rows.length === data.length)
+    rows.foreach {
+      case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 ~== vec2 absTol 1e-1)
+    }
+  }
+}