From 85b96372cf0fd055f89fc639f45c1f2cb02a378f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 21 May 2015 18:04:45 -0700 Subject: [SPARK-7219] [MLLIB] Output feature attributes in HashingTF This PR updates `HashingTF` to output ML attributes that tell the number of features in the output column. We need to expand `UnaryTransformer` to support output metadata. A `df outputMetadata: Metadata` is not sufficient because the metadata may also depends on the input data. Though this is not true for `HashingTF`, I think it is reasonable to update `UnaryTransformer` in a separate PR. `checkParams` is added to verify common requirements for params. I will send a separate PR to use it in other test suites. jkbradley Author: Xiangrui Meng Closes #6308 from mengxr/SPARK-7219 and squashes the following commits: 9bd2922 [Xiangrui Meng] address comments e82a68a [Xiangrui Meng] remove sqlContext from test suite 995535b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7219 2194703 [Xiangrui Meng] add test for attributes 178ae23 [Xiangrui Meng] update HashingTF with tests 91a6106 [Xiangrui Meng] WIP --- .../org/apache/spark/ml/feature/HashingTF.scala | 34 +++++++++---- .../apache/spark/ml/feature/HashingTFSuite.scala | 55 ++++++++++++++++++++++ .../org/apache/spark/ml/param/ParamsSuite.scala | 20 ++++++++ 3 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 30033ced68..8942d45219 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -18,22 +18,31 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{udf, col} +import org.apache.spark.sql.types.{ArrayType, StructType} /** * :: AlphaComponent :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ @AlphaComponent -class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] { +class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("hashingTF")) + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * Number of features. Should be > 0. * (default = 2^18^) @@ -50,10 +59,19 @@ class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - override protected def createTransformFunc: Iterable[_] => Vector = { + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)) - hashingTF.transform + val t = udf { terms: Seq[_] => hashingTF.transform(terms) } + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) } - override protected def outputDataType: DataType = new VectorUDT() + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) + SchemaUtils.appendColumn(schema, attrGroup.toStructField()) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala new file mode 100644 index 0000000000..2e4beb0bff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class HashingTFSuite extends FunSuite with MLlibTestSparkContext { + + test("params") { + val hashingTF = new HashingTF + ParamsSuite.checkParams(hashingTF, 3) + } + + test("hashingTF") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b b c d".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + val output = hashingTF.transform(df) + val attrGroup = AttributeGroup.fromStructField(output.schema("features")) + require(attrGroup.numAttributes === Some(n)) + val features = output.select("features").first().getAs[Vector](0) + // Assume perfect hash on "a", "b", "c", and "d". + def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + val expected = Vectors.sparse(n, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) + assert(features ~== expected absTol 1e-14) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index b96874f3a8..d270ad7613 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -201,3 +201,23 @@ class ParamsSuite extends FunSuite { assert(inArray(1) && inArray(2) && !inArray(0)) } } + +object ParamsSuite extends FunSuite { + + /** + * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered + * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as + * the param method name. + */ + def checkParams(obj: Params, expectedNumParams: Int): Unit = { + val params = obj.params + require(params.length === expectedNumParams, + s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.") + val paramNames = params.map(_.name) + require(paramNames === paramNames.sorted) + params.foreach { p => + assert(p.parent === obj.uid) + assert(obj.getParam(p.name) === p) + } + } +} -- cgit v1.2.3