path: root/mllib/src
author    Xusen Yin <yinxusen@gmail.com>    2015-04-24 08:29:49 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-04-24 08:29:49 -0700
commit 6e57d57b32ba2aa0514692074897b5edd34e0dd6 (patch)
tree   36a83dbc0a757d09b1f529997320830a29516832 /mllib/src
parent 78b39c7e0de8c9dc748cfbf8f78578a9524b6a94 (diff)
[SPARK-6528] [ML] Add IDF transformer
See [SPARK-6528](https://issues.apache.org/jira/browse/SPARK-6528). Add the IDF transformer to the ML package.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5266 from yinxusen/SPARK-6528 and squashes the following commits:

741db31 [Xusen Yin] get param from new paramMap
d169967 [Xusen Yin] add final to param and IDF class
c9c3759 [Xusen Yin] simplify test suite
5867c09 [Xusen Yin] refine IDF transformer with new interfaces
7727cae [Xusen Yin] Merge branch 'master' into SPARK-6528
4338a37 [Xusen Yin] Merge branch 'master' into SPARK-6528
aef2cdf [Xusen Yin] add doc and group for param
5760b49 [Xusen Yin] fix code style
2add691 [Xusen Yin] fix code style and test
03fbecb [Xusen Yin] remove duplicated code
2aa4be0 [Xusen Yin] clean test suite
4802c67 [Xusen Yin] add IDF transformer and test suite
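For context, a minimal usage sketch of the estimator this patch adds (not part of the commit; the sqlContext value and column names "features"/"tfidf" are illustrative, and the builder calls mirror the test suite below):

    import org.apache.spark.ml.feature.IDF
    import org.apache.spark.mllib.linalg.Vectors

    // Term-frequency vectors over a 4-term vocabulary; names are illustrative.
    val df = sqlContext.createDataFrame(Seq(
      (0, Vectors.dense(0.0, 1.0, 2.0, 3.0)),
      (1, Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)))
    )).toDF("id", "features")

    // Fit IDF weights, ignoring terms seen in fewer than minDocFreq documents.
    val idfModel = new IDF()
      .setInputCol("features")
      .setOutputCol("tfidf")
      .setMinDocFreq(1)
      .fit(df)

    idfModel.transform(df).select("tfidf").show()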
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala       116
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala  101
2 files changed, 217 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
new file mode 100644
index 0000000000..e6a62d998b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for [[IDF]] and [[IDFModel]].
+ */
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * The minimum number of documents in which a term should appear.
+ * @group param
+ */
+ final val minDocFreq = new IntParam(
+ this, "minDocFreq", "minimum of documents in which a term should appear for filtering")
+
+ setDefault(minDocFreq -> 0)
+
+ /** @group getParam */
+ def getMinDocFreq: Int = getOrDefault(minDocFreq)
+
+ /** @group setParam */
+ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
+ /**
+ * Validate and transform the input schema.
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Compute the Inverse Document Frequency (IDF) given a collection of documents.
+ */
+@AlphaComponent
+final class IDF extends Estimator[IDFModel] with IDFBase {
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
+ val idf = new feature.IDF(map(minDocFreq)).fit(input)
+ val model = new IDFModel(this, map, idf)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[IDF]].
+ */
+@AlphaComponent
+class IDFModel private[ml] (
+ override val parent: IDF,
+ override val fittingParamMap: ParamMap,
+ idfModel: feature.IDFModel)
+ extends Model[IDFModel] with IDFBase {
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val idf = udf { vec: Vector => idfModel.transform(vec) }
+ dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
new file mode 100644
index 0000000000..eaee3443c1
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+class IDFSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
+ dataSet.map {
+ case data: DenseVector =>
+ val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
+ Vectors.dense(res)
+ case data: SparseVector =>
+ val res = data.indices.zip(data.values).map { case (id, value) =>
+ (id, value * model(id))
+ }
+ Vectors.sparse(data.size, res)
+ }
+ }
+
+ test("compute IDF with default parameter") {
+ val numOfFeatures = 4
+ val data = Array(
+ Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+ Vectors.dense(0.0, 1.0, 2.0, 3.0),
+ Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
+ )
+ val numOfData = data.size
+ val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+ math.log((numOfData + 1.0) / (x + 1.0))
+ })
+ val expected = scaleDataWithIDF(data, idf)
+
+ val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+ val idfModel = new IDF()
+ .setInputCol("features")
+ .setOutputCol("idfValue")
+ .fit(df)
+
+ idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
+ }
+ }
+
+ test("compute IDF with setter") {
+ val numOfFeatures = 4
+ val data = Array(
+ Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+ Vectors.dense(0.0, 1.0, 2.0, 3.0),
+ Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
+ )
+ val numOfData = data.size
+ val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+ if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
+ })
+ val expected = scaleDataWithIDF(data, idf)
+
+ val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+ val idfModel = new IDF()
+ .setInputCol("features")
+ .setOutputCol("idfValue")
+ .setMinDocFreq(1)
+ .fit(df)
+
+ idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
+ }
+ }
+}
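For reference, a note on the expected values checked above (not part of the patch): both tests use MLlib's smoothed IDF, idf(t) = log((m + 1) / (df(t) + 1)) over m documents, and with minDocFreq set, terms whose document frequency falls below the threshold receive an IDF of 0. A quick sketch of the arithmetic for the 3-document fixture in the suite:

    // m = 3 documents; per-term document frequencies from the fixture above.
    val m = 3
    val docFreq = Array(0, 3, 1, 2)
    // Default (minDocFreq = 0): every term is weighted, including unseen ones.
    val idfDefault = docFreq.map(df => math.log((m + 1.0) / (df + 1.0)))
    // With minDocFreq = 1, terms appearing in no document are zeroed out.
    val idfFiltered = docFreq.map(df => if (df >= 1) math.log((m + 1.0) / (df + 1.0)) else 0.0)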