From 633aaae0a1e31e9ba634423840e350b22342c6b5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 2 Oct 2015 10:25:58 -0700 Subject: [SPARK-6530] [ML] Add chi-square selector for ml package See JIRA [here](https://issues.apache.org/jira/browse/SPARK-6530). Author: Xusen Yin Closes #5742 from yinxusen/SPARK-6530. --- .../apache/spark/ml/feature/ChiSqSelector.scala | 150 +++++++++++++++++++++ .../apache/spark/mllib/feature/ChiSqSelector.scala | 2 + .../spark/ml/feature/ChiSqSelectorSuite.scala | 61 +++++++++ 3 files changed, 213 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala new file mode 100644 index 0000000000..5e4061fba5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute.{AttributeGroup, _}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+/**
+ * Parameters shared by [[ChiSqSelector]] and [[ChiSqSelectorModel]]: the features, output and
+ * label column names, plus the number of features to keep.
+ */
+private[feature] trait ChiSqSelectorParams extends Params
+  with HasFeaturesCol with HasOutputCol with HasLabelCol {
+
+  /**
+   * How many features the selector keeps, taken in descending order of statistic value. When
+   * the input has fewer features than numTopFeatures, every feature is kept. Must be at
+   * least 1; defaults to 50.
+   * @group param
+   */
+  final val numTopFeatures = new IntParam(this, "numTopFeatures",
+    "Number of features that selector will select, ordered by statistics value descending. If the" +
+      " number of features is < numTopFeatures, then this will select all features.",
+    ParamValidators.gtEq(1))
+
+  // Register the default so the getter works before the param is explicitly set.
+  setDefault(numTopFeatures, 50)
+
+  /** @group getParam */
+  def getNumTopFeatures: Int = getOrDefault(numTopFeatures)
+}
+
+/**
+ * :: Experimental ::
+ * Chi-Squared feature selection, which selects categorical features to use for predicting a
+ * categorical label.
+ */
+@Experimental
+final class ChiSqSelector(override val uid: String)
+  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  /** Creates a selector with a randomly generated uid prefixed by "chiSqSelector". */
+  def this() = this(Identifiable.randomUID("chiSqSelector"))
+
+  /** @group setParam */
+  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Fits the selector by delegating to the `mllib` ChiSqSelector.
+   *
+   * The label and features columns are read as (Double, Vector) pairs and turned into
+   * LabeledPoints; the fitted `mllib` model is wrapped in a [[ChiSqSelectorModel]] that shares
+   * this estimator's uid, has this estimator as parent and receives its param values.
+   */
+  override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+    // Called only for its validation of the input schema; the resulting schema is not needed.
+    transformSchema(dataset.schema, logging = true)
+
+    val labeledPoints = dataset.select($(labelCol), $(featuresCol)).map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+    }
+    val mllibSelector = new feature.ChiSqSelector($(numTopFeatures))
+    val mllibModel = mllibSelector.fit(labeledPoints)
+    val fitted = new ChiSqSelectorModel(uid, mllibModel)
+    copyValues(fitted.setParent(this))
+  }
+
+  /**
+   * Checks that the features column is a vector column and the label column is a double column,
+   * then returns the input schema with the vector-typed output column appended.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    val vectorType = new VectorUDT
+    SchemaUtils.checkColumnType(schema, $(featuresCol), vectorType)
+    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.appendColumn(schema, $(outputCol), vectorType)
+  }
+
+  override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[ChiSqSelector]].
+ */
+@Experimental
+final class ChiSqSelectorModel private[ml] (
+    override val uid: String,
+    private val chiSqSelector: feature.ChiSqSelectorModel)
+  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Adds the output column, computed by applying the wrapped `mllib` model's transform to the
+   * features column, and attaches the metadata prepared by `transformSchema` to it.
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    // The output field is the last field of the transformed schema; only its metadata is used.
+    val outputMetadata = transformSchema(dataset.schema, logging = true).last.metadata
+    val selectFeatures = udf { chiSqSelector.transform _ }
+    dataset.withColumn($(outputCol), selectFeatures(col($(featuresCol))), outputMetadata)
+  }
+
+  /**
+   * Checks that the features column is a vector column and returns the input schema with the
+   * prepared output field appended as the last field.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+
+  /**
+   * Builds the output column field together with its per-feature metadata.
+   *
+   * If the features column carries per-feature attributes, the attributes at the selected
+   * feature indices are kept (in their original order). Otherwise one default nominal attribute
+   * is emitted per selected feature.
+   */
+  private def prepOutputField(schema: StructType): StructField = {
+    val selectedIndices = chiSqSelector.selectedFeatures.toSet
+    val inputGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
+    val selectedAttributes: Array[Attribute] = inputGroup.attributes match {
+      case Some(attrs) =>
+        attrs.zipWithIndex.collect {
+          case (attr, index) if selectedIndices.contains(index) => attr
+        }
+      case None =>
+        Array.fill[Attribute](selectedIndices.size)(NominalAttribute.defaultAttr)
+    }
+    new AttributeGroup($(outputCol), selectedAttributes).toStructField()
+  }
+
+  /** Copies this model, reusing the same uid, wrapped `mllib` model and parent. */
+  override def copy(extra: ParamMap): ChiSqSelectorModel =
+    copyValues(new ChiSqSelectorModel(uid, chiSqSelector), extra).setParent(parent)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 4743cfd1a2..b1524cf377 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -109,6 +109,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
  * Creates a ChiSquared feature selector.
  * @param numTopFeatures number of features that selector will select
  *                       (ordered by statistic value descending)
+ *                       Note that if the number of features is < numTopFeatures, then this will
+ *                       select all features.
  */
 @Since("1.3.0")
 @Experimental
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
new file mode 100644
index 0000000000..e5a42967bd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+/** Checks the ml-package ChiSqSelector against hand-written expected output. */
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("Test Chi-Square selector") {
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+
+    // Four labeled points with three features each, mixing sparse and dense vectors.
+    val labeledPoints = Seq(
+      LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
+      LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
+      LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
+      LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
+    )
+
+    // Expected single-feature output for each point; the values equal each point's feature 2.
+    val expectedVectors = Seq(
+      Vectors.dense(0.0),
+      Vectors.dense(6.0),
+      Vectors.dense(8.0),
+      Vectors.dense(5.0)
+    )
+
+    val df = sc.parallelize(labeledPoints.zip(expectedVectors))
+      .map { case (point, expected) => (point.label, point.features, expected) }
+      .toDF("label", "data", "preFilteredData")
+
+    val selector = new ChiSqSelector()
+      .setNumTopFeatures(1)
+      .setFeaturesCol("data")
+      .setLabelCol("label")
+      .setOutputCol("filtered")
+
+    val resultRows = selector.fit(df).transform(df)
+      .select("filtered", "preFilteredData")
+      .collect()
+
+    resultRows.foreach {
+      case Row(actual: Vector, expected: Vector) =>
+        assert(actual ~== expected absTol 1e-1)
+    }
+  }
+}
-- 
cgit 
v1.2.3