diff options
author | Feynman Liang <fliang@databricks.com> | 2015-06-22 14:15:35 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-06-22 14:15:35 -0700 |
commit | afe35f0519bc7dcb85010a7eedcff854d4fc313a (patch) | |
tree | d123b5a16e88c8a5a3762df08bdfd450c5802bcf /mllib/src/main | |
parent | 5ab9fcfb01a0ad2f6c103f67c1a785d3b49e33f0 (diff) | |
download | spark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.tar.gz spark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.tar.bz2 spark-afe35f0519bc7dcb85010a7eedcff854d4fc313a.zip |
[SPARK-8455] [ML] Implement n-gram feature transformer
Implementation of n-gram feature transformer for ML.
Author: Feynman Liang <fliang@databricks.com>
Closes #6887 from feynmanliang/ngram-featurizer and squashes the following commits:
d2c839f [Feynman Liang] Make n > input length yield empty output
9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces
fe93873 [Feynman Liang] Implement n-gram feature transformer
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala new file mode 100644 index 0000000000..8de10eb51f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +/** + * :: Experimental :: + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + */ +@Experimental +class NGram(override val uid: String) + extends UnaryTransformer[Seq[String], Seq[String], NGram] { + + def this() = this(Identifiable.randomUID("ngram")) + + /** + * Minimum n-gram length, >= 1. + * Default: 2, bigram features + * @group param + */ + val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", + ParamValidators.gtEq(1)) + + /** @group setParam */ + def setN(value: Int): this.type = set(n, value) + + /** @group getParam */ + def getN: Int = $(n) + + setDefault(n -> 2) + + override protected def createTransformFunc: Seq[String] => Seq[String] = { + _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} |