about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authorFeynman Liang <fliang@databricks.com>2015-06-29 18:40:30 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-06-29 18:40:30 -0700
commit620605a4a1123afaab2674e38251f1231dea17ce (patch)
tree2fec235613a66fb012193e8fae90902c6657b63d /python
parent4c1808be4d3aaa37a5a878892e91ca73ea405ffa (diff)
downloadspark-620605a4a1123afaab2674e38251f1231dea17ce.tar.gz
spark-620605a4a1123afaab2674e38251f1231dea17ce.tar.bz2
spark-620605a4a1123afaab2674e38251f1231dea17ce.zip
[SPARK-8456] [ML] Ngram featurizer python
Python API for N-gram feature transformer Author: Feynman Liang <fliang@databricks.com> Closes #6960 from feynmanliang/ngram-featurizer-python and squashes the following commits: f9e37c9 [Feynman Liang] Remove debugging code 4dd81f4 [Feynman Liang] Fix typo and doctest 06c79ac [Feynman Liang] Style guide 26c1175 [Feynman Liang] Add python NGram API
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/feature.py71
-rw-r--r--python/pyspark/ml/tests.py11
2 files changed, 81 insertions, 1 deletion
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ddb33f427a..8804dace84 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -21,7 +21,7 @@ from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
-__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder',
+__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
'Word2Vec', 'Word2VecModel']
@@ -266,6 +266,75 @@ class IDFModel(JavaModel):
@inherit_doc
+@ignore_unicode_prefix
+class NGram(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ A feature transformer that converts the input array of strings into an array of n-grams. Null
+ values in the input array are ignored.
+ It returns an array of n-grams where each n-gram is represented by a space-separated string of
+ words.
+ When the input is empty, an empty array is returned.
+ When the input array length is less than n (number of elements per n-gram), no n-grams are
+ returned.
+
+ >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
+ >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
+ >>> ngram.transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
+ >>> # Change n-gram length
+ >>> ngram.setParams(n=4).transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+ >>> # Temporarily modify output column.
+ >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
+ >>> ngram.transform(df).head()
+ Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+ >>> # Must use keyword arguments to specify params.
+ >>> ngram.setParams("text")
+ Traceback (most recent call last):
+ ...
+ TypeError: Method setParams forces keyword arguments.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
+
+ @keyword_only
+ def __init__(self, n=2, inputCol=None, outputCol=None):
+ """
+ __init__(self, n=2, inputCol=None, outputCol=None)
+ """
+ super(NGram, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
+ self.n = Param(self, "n", "number of elements per n-gram (>=1)")
+ self._setDefault(n=2)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, n=2, inputCol=None, outputCol=None):
+ """
+ setParams(self, n=2, inputCol=None, outputCol=None)
+ Sets params for this NGram.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setN(self, value):
+ """
+ Sets the value of :py:attr:`n`.
+ """
+ self._paramMap[self.n] = value
+ return self
+
+ def getN(self):
+ """
+ Gets the value of n or its default value.
+ """
+ return self.getOrDefault(self.n)
+
+
+@inherit_doc
class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Normalize a vector to have unit norm using the given p-norm.
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6adbf166f3..c151d21fd6 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -252,6 +252,17 @@ class FeatureTests(PySparkTestCase):
output = idf0m.transform(dataset)
self.assertIsNotNone(output.head().idf)
+ def test_ngram(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ ([["a", "b", "c", "d", "e"]])], ["input"])
+ ngram0 = NGram(n=4, inputCol="input", outputCol="output")
+ self.assertEqual(ngram0.getN(), 4)
+ self.assertEqual(ngram0.getInputCol(), "input")
+ self.assertEqual(ngram0.getOutputCol(), "output")
+ transformedDF = ngram0.transform(dataset)
+ self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
+
if __name__ == "__main__":
unittest.main()