aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-21 18:04:45 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-21 18:04:45 -0700
commit85b96372cf0fd055f89fc639f45c1f2cb02a378f (patch)
treeefdc362523217e9c8e3da9e4c2ba1743ad44d094 /mllib
parentf5db4b416c922db7a8f1b0c098b4f08647106231 (diff)
downloadspark-85b96372cf0fd055f89fc639f45c1f2cb02a378f.tar.gz
spark-85b96372cf0fd055f89fc639f45c1f2cb02a378f.tar.bz2
spark-85b96372cf0fd055f89fc639f45c1f2cb02a378f.zip
[SPARK-7219] [MLLIB] Output feature attributes in HashingTF
This PR updates `HashingTF` to output ML attributes that tell the number of features in the output column. We need to expand `UnaryTransformer` to support output metadata. A `df outputMetadata: Metadata` is not sufficient because the metadata may also depends on the input data. Though this is not true for `HashingTF`, I think it is reasonable to update `UnaryTransformer` in a separate PR. `checkParams` is added to verify common requirements for params. I will send a separate PR to use it in other test suites. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #6308 from mengxr/SPARK-7219 and squashes the following commits: 9bd2922 [Xiangrui Meng] address comments e82a68a [Xiangrui Meng] remove sqlContext from test suite 995535b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7219 2194703 [Xiangrui Meng] add test for attributes 178ae23 [Xiangrui Meng] update HashingTF with tests 91a6106 [Xiangrui Meng] WIP
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala34
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala55
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala20
3 files changed, 101 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 30033ced68..8942d45219 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -18,22 +18,31 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{udf, col}
+import org.apache.spark.sql.types.{ArrayType, StructType}
/**
* :: AlphaComponent ::
* Maps a sequence of terms to their term frequencies using the hashing trick.
*/
@AlphaComponent
-class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("hashingTF"))
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
/**
* Number of features. Should be > 0.
* (default = 2^18^)
@@ -50,10 +59,19 @@ class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_],
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- override protected def createTransformFunc: Iterable[_] => Vector = {
+ override def transform(dataset: DataFrame): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema)
val hashingTF = new feature.HashingTF($(numFeatures))
- hashingTF.transform
+ val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
+ val metadata = outputSchema($(outputCol)).metadata
+ dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
}
- override protected def outputDataType: DataType = new VectorUDT()
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
+ require(inputType.isInstanceOf[ArrayType],
+ s"The input column must be ArrayType, but got $inputType.")
+ val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
+ SchemaUtils.appendColumn(schema, attrGroup.toStructField())
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
new file mode 100644
index 0000000000..2e4beb0bff
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
+class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+
+ test("params") {
+ val hashingTF = new HashingTF
+ ParamsSuite.checkParams(hashingTF, 3)
+ }
+
+ test("hashingTF") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a a b b c d".split(" ").toSeq)
+ )).toDF("id", "words")
+ val n = 100
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ val output = hashingTF.transform(df)
+ val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
+ require(attrGroup.numAttributes === Some(n))
+ val features = output.select("features").first().getAs[Vector](0)
+ // Assume perfect hash on "a", "b", "c", and "d".
+ def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index b96874f3a8..d270ad7613 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -201,3 +201,23 @@ class ParamsSuite extends FunSuite {
assert(inArray(1) && inArray(2) && !inArray(0))
}
}
+
+object ParamsSuite extends FunSuite {
+
+ /**
+ * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
+ * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
+ * the param method name.
+ */
+ def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+ val params = obj.params
+ require(params.length === expectedNumParams,
+ s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
+ val paramNames = params.map(_.name)
+ require(paramNames === paramNames.sorted)
+ params.foreach { p =>
+ assert(p.parent === obj.uid)
+ assert(obj.getParam(p.name) === p)
+ }
+ }
+}