about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorFeynman Liang <fliang@databricks.com>2015-06-22 14:15:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-06-22 14:15:35 -0700
commitafe35f0519bc7dcb85010a7eedcff854d4fc313a (patch)
treed123b5a16e88c8a5a3762df08bdfd450c5802bcf /mllib
parent5ab9fcfb01a0ad2f6c103f67c1a785d3b49e33f0 (diff)
downloadspark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.tar.gz
spark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.tar.bz2
spark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.zip
[SPARK-8455] [ML] Implement n-gram feature transformer
Implementation of n-gram feature transformer for ML. Author: Feynman Liang <fliang@databricks.com> Closes #6887 from feynmanliang/ngram-featurizer and squashes the following commits: d2c839f [Feynman Liang] Make n > input length yield empty output 9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces fe93873 [Feynman Liang] Implement n-gram feature transformer
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala      | 69
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala | 94
2 files changed, 163 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
new file mode 100644
index 0000000000..8de10eb51f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
+
+/**
+ * :: Experimental ::
+ * A feature transformer that converts the input array of strings into an array of n-grams.
+ * It returns an array of n-grams where each n-gram is represented by a space-separated string of
+ * words.
+ *
+ * When the input is empty, an empty array is returned.
+ * When the input array length is less than n (number of elements per n-gram), no n-grams are
+ * returned.
+ *
+ * NOTE(review): the original doc claimed null values in the input array are ignored, but the
+ * implementation below does not visibly filter nulls before joining — confirm intended behavior.
+ */
+@Experimental
+class NGram(override val uid: String)
+  extends UnaryTransformer[Seq[String], Seq[String], NGram] {
+
+  def this() = this(Identifiable.randomUID("ngram"))
+
+  /**
+   * Number of elements per n-gram, >= 1 (enforced by the validator below).
+   * Default: 2, bigram features
+   * @group param
+   */
+  val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
+    ParamValidators.gtEq(1))
+
+  /** @group setParam */
+  def setN(value: Int): this.type = set(n, value)
+
+  /** @group getParam */
+  def getN: Int = $(n)
+
+  setDefault(n -> 2)
+
+  // Slides a window of exactly $(n) tokens across the input and space-joins each window.
+  // withPartial(false) drops the trailing shorter-than-n window, so inputs with fewer than
+  // n tokens (including the empty input) yield an empty output rather than partial n-grams.
+  override protected def createTransformFunc: Seq[String] => Seq[String] = {
+    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+  }
+
+  // Fail fast with a clear message unless the input column is an array of strings.
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.sameType(ArrayType(StringType)),
+      s"Input type must be ArrayType(StringType) but got $inputType.")
+  }
+
+  // containsNull = false: every produced element is a non-null space-joined string.
+  override protected def outputDataType: DataType = new ArrayType(StringType, false)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
new file mode 100644
index 0000000000..ab97e3dbc6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+// Test-row schema: an input token sequence and the n-grams expected from transforming it.
+// @BeanInfo adds JavaBean accessors — presumably for DataFrame schema reflection; confirm.
+@BeanInfo
+case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
+
+// Each test builds a one-row DataFrame of (inputTokens, wantedNGrams) and delegates the
+// actual transform-and-compare to NGramSuite.testNGram in the companion object.
+class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
+  import org.apache.spark.ml.feature.NGramSuite._
+
+  // With no explicit setN, the transformer's default n = 2 yields bigrams.
+  test("default behavior yields bigram features") {
+    val nGram = new NGram()
+      .setInputCol("inputTokens")
+      .setOutputCol("nGrams")
+    val dataset = sqlContext.createDataFrame(Seq(
+      NGramTestData(
+        Array("Test", "for", "ngram", "."),
+        Array("Test for", "for ngram", "ngram .")
+      )))
+    testNGram(nGram, dataset)
+  }
+
+  // n = 4 produces space-joined 4-grams over a sliding window of the tokens.
+  test("NGramLength=4 yields length 4 n-grams") {
+    val nGram = new NGram()
+      .setInputCol("inputTokens")
+      .setOutputCol("nGrams")
+      .setN(4)
+    val dataset = sqlContext.createDataFrame(Seq(
+      NGramTestData(
+        Array("a", "b", "c", "d", "e"),
+        Array("a b c d", "b c d e")
+      )))
+    testNGram(nGram, dataset)
+  }
+
+  // An empty token array must map to an empty n-gram array, not an error.
+  test("empty input yields empty output") {
+    val nGram = new NGram()
+      .setInputCol("inputTokens")
+      .setOutputCol("nGrams")
+      .setN(4)
+    val dataset = sqlContext.createDataFrame(Seq(
+      NGramTestData(
+        Array(),
+        Array()
+      )))
+    testNGram(nGram, dataset)
+  }
+
+  // Fewer tokens (5) than n (6): no partial n-grams are emitted, output is empty.
+  test("input array < n yields empty output") {
+    val nGram = new NGram()
+      .setInputCol("inputTokens")
+      .setOutputCol("nGrams")
+      .setN(6)
+    val dataset = sqlContext.createDataFrame(Seq(
+      NGramTestData(
+        Array("a", "b", "c", "d", "e"),
+        Array()
+      )))
+    testNGram(nGram, dataset)
+  }
+}
+
+object NGramSuite extends SparkFunSuite {
+
+  // Applies transformer `t` to `dataset` and asserts, row by row, that the produced
+  // "nGrams" column equals the pre-computed "wantedNGrams" expectation column.
+  def testNGram(t: NGram, dataset: DataFrame): Unit = {
+    t.transform(dataset)
+      .select("nGrams", "wantedNGrams")
+      .collect()
+      .foreach { case Row(actualNGrams, wantedNGrams) =>
+        assert(actualNGrams === wantedNGrams)
+      }
+  }
+}