diff options
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r-- | python/pyspark/ml/feature.py | 72 |
1 files changed, 58 insertions, 14 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index e088acd0ca..f1ddbb478d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -16,7 +16,7 @@ # from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import inherit_doc, keyword_only from pyspark.ml.wrapper import JavaTransformer __all__ = ['Tokenizer', 'HashingTF'] @@ -29,18 +29,45 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): splits it by white spaces. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) - >>> tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - >>> print tokenizer.transform(dataset).head() + >>> df = sc.parallelize([Row(text="a b c")]).toDF() + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> print tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + >>> # Change a parameter. + >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.feature.Tokenizer" + @keyword_only + def __init__(self, inputCol="input", outputCol="output"): + """ + __init__(self, inputCol="input", outputCol="output") + """ + super(Tokenizer, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol="input", outputCol="output"): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): @@ -49,20 +76,37 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): hashing trick. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) - >>> hashingTF = HashingTF() \ - .setNumFeatures(10) \ - .setInputCol("words") \ - .setOutputCol("features") - >>> print hashingTF.transform(dataset).head().features + >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> print hashingTF.transform(df).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs (10,[7,8,9],[1.0,1.0,1.0]) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(dataset, params).head().vector + >>> print hashingTF.transform(df, params).head().vector (5,[2,3,4],[1.0,1.0,1.0]) """ _java_class = "org.apache.spark.ml.feature.HashingTF" + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + """ + super(HashingTF, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + Sets params for this HashingTF. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + if __name__ == "__main__": import doctest |