Diffstat (limited to 'python')
-rw-r--r--  python/docs/conf.py                            |    4
-rw-r--r--  python/docs/index.rst                          |    1
-rw-r--r--  python/docs/pyspark.ml.rst                     |   29
-rw-r--r--  python/docs/pyspark.rst                        |    1
-rw-r--r--  python/pyspark/ml/__init__.py                  |   21
-rw-r--r--  python/pyspark/ml/classification.py            |   76
-rw-r--r--  python/pyspark/ml/feature.py                   |   82
-rw-r--r--  python/pyspark/ml/param/__init__.py            |   82
-rw-r--r--  python/pyspark/ml/param/_gen_shared_params.py  |   98
-rw-r--r--  python/pyspark/ml/param/shared.py              |  260
-rw-r--r--  python/pyspark/ml/pipeline.py                  |  154
-rw-r--r--  python/pyspark/ml/tests.py                     |  115
-rw-r--r--  python/pyspark/ml/util.py                      |   46
-rw-r--r--  python/pyspark/ml/wrapper.py                   |  149
-rw-r--r--  python/pyspark/sql.py                          |   14
-rwxr-xr-x  python/run-tests                               |    8
16 files changed, 1124 insertions(+), 16 deletions(-)
diff --git a/python/docs/conf.py b/python/docs/conf.py
index e58d97ae6a..b00dce95d6 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -55,9 +55,9 @@ copyright = u'2014, Author'
# built documents.
#
# The short X.Y version.
-version = '1.2-SNAPSHOT'
+version = '1.3-SNAPSHOT'
# The full version, including alpha/beta/rc tags.
-release = '1.2-SNAPSHOT'
+release = '1.3-SNAPSHOT'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
diff --git a/python/docs/index.rst b/python/docs/index.rst
index 703bef644d..d150de9d5c 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -14,6 +14,7 @@ Contents:
pyspark
pyspark.sql
pyspark.streaming
+ pyspark.ml
pyspark.mllib
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
new file mode 100644
index 0000000000..f10d1339a9
--- /dev/null
+++ b/python/docs/pyspark.ml.rst
@@ -0,0 +1,29 @@
+pyspark.ml package
+=====================
+
+Submodules
+----------
+
+pyspark.ml module
+-----------------
+
+.. automodule:: pyspark.ml
+ :members:
+ :undoc-members:
+ :inherited-members:
+
+pyspark.ml.feature module
+-------------------------
+
+.. automodule:: pyspark.ml.feature
+ :members:
+ :undoc-members:
+ :inherited-members:
+
+pyspark.ml.classification module
+--------------------------------
+
+.. automodule:: pyspark.ml.classification
+ :members:
+ :undoc-members:
+ :inherited-members:
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index e81be3b6cb..0df12c49ad 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -9,6 +9,7 @@ Subpackages
pyspark.sql
pyspark.streaming
+ pyspark.ml
pyspark.mllib
Contents
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
new file mode 100644
index 0000000000..47fed80f42
--- /dev/null
+++ b/python/pyspark/ml/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param import *
+from pyspark.ml.pipeline import *
+
+__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
new file mode 100644
index 0000000000..6bd2aa8e47
--- /dev/null
+++ b/python/pyspark/ml/classification.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
+ HasRegParam
+
+
+__all__ = ['LogisticRegression', 'LogisticRegressionModel']
+
+
+@inherit_doc
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+ HasRegParam):
+ """
+ Logistic regression.
+
+ >>> from pyspark.sql import Row
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
+ Row(label=1.0, features=Vectors.dense(1.0)), \
+ Row(label=0.0, features=Vectors.sparse(1, [], []))]))
+ >>> lr = LogisticRegression() \
+ .setMaxIter(5) \
+ .setRegParam(0.01)
+ >>> model = lr.fit(dataset)
+ >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+ >>> print model.transform(test0).head().prediction
+ 0.0
+ >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+ >>> print model.transform(test1).head().prediction
+ 1.0
+ """
+ _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
+ def _create_model(self, java_model):
+ return LogisticRegressionModel(java_model)
+
+
+class LogisticRegressionModel(JavaModel):
+ """
+ Model fitted by LogisticRegression.
+ """
+
+
+if __name__ == "__main__":
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ globs = globals().copy()
+ # Run the doctests above against a local SparkContext
+ # and SQLContext, exiting non-zero on any failure:
+ sc = SparkContext("local[2]", "ml.classification tests")
+ sqlCtx = SQLContext(sc)
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ if failure_count:
+ exit(-1)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
new file mode 100644
index 0000000000..e088acd0ca
--- /dev/null
+++ b/python/pyspark/ml/feature.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaTransformer
+
+__all__ = ['Tokenizer', 'HashingTF']
+
+
+@inherit_doc
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ A tokenizer that converts the input string to lowercase and then
+ splits it on whitespace.
+
+ >>> from pyspark.sql import Row
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
+ >>> tokenizer = Tokenizer() \
+ .setInputCol("text") \
+ .setOutputCol("words")
+ >>> print tokenizer.transform(dataset).head()
+ Row(text=u'a b c', words=[u'a', u'b', u'c'])
+ >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
+ Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+ """
+
+ _java_class = "org.apache.spark.ml.feature.Tokenizer"
+
+
+@inherit_doc
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
+ """
+ Maps a sequence of terms to their term frequencies using the
+ hashing trick.
+
+ >>> from pyspark.sql import Row
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
+ >>> hashingTF = HashingTF() \
+ .setNumFeatures(10) \
+ .setInputCol("words") \
+ .setOutputCol("features")
+ >>> print hashingTF.transform(dataset).head().features
+ (10,[7,8,9],[1.0,1.0,1.0])
+ >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
+ >>> print hashingTF.transform(dataset, params).head().vector
+ (5,[2,3,4],[1.0,1.0,1.0])
+ """
+
+ _java_class = "org.apache.spark.ml.feature.HashingTF"
+
+
+if __name__ == "__main__":
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ globs = globals().copy()
+ # Run the doctests above against a local SparkContext
+ # and SQLContext, exiting non-zero on any failure:
+ sc = SparkContext("local[2]", "ml.feature tests")
+ sqlCtx = SQLContext(sc)
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ if failure_count:
+ exit(-1)
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
new file mode 100644
index 0000000000..5566792cea
--- /dev/null
+++ b/python/pyspark/ml/param/__init__.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ['Param', 'Params']
+
+
+class Param(object):
+ """
+ A param with self-contained documentation and, optionally, a default value.
+ """
+
+ def __init__(self, parent, name, doc, defaultValue=None):
+ if not isinstance(parent, Identifiable):
+ raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
+ self.parent = parent
+ self.name = str(name)
+ self.doc = str(doc)
+ self.defaultValue = defaultValue
+
+ def __str__(self):
+ return str(self.parent) + "-" + self.name
+
+ def __repr__(self):
+ return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
+ (self.parent, self.name, self.doc, self.defaultValue)
+
+
+class Params(Identifiable):
+ """
+ Components that take parameters. This also provides an internal
+ param map to store parameter values attached to the instance.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(Params, self).__init__()
+ #: embedded param map
+ self.paramMap = {}
+
+ @property
+ def params(self):
+ """
+ Returns all params. The default implementation uses
+ :py:func:`dir` to get all attributes of type
+ :py:class:`Param`.
+ """
+ return filter(lambda attr: isinstance(attr, Param),
+ [getattr(self, x) for x in dir(self) if x != "params"])
+
+ def _merge_params(self, params):
+ paramMap = self.paramMap.copy()
+ paramMap.update(params)
+ return paramMap
+
+ @staticmethod
+ def _dummy():
+ """
+ Returns a dummy Params instance used as a placeholder to generate docs.
+ """
+ dummy = Params()
+ dummy.uid = "undefined"
+ return dummy
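
The Param/Params pair above is the foundation the shared param mixins and pipeline stages build on: a Param is essentially a documented key owned by an Identifiable parent, and Params keeps the per-instance paramMap those keys index into. A minimal sketch of a component written against this API (the HasThreshold class and its threshold param are illustrative, not part of this change):

    from pyspark.ml.param import Param, Params

    class HasThreshold(Params):
        """Illustrative mixin with a threshold param."""

        def __init__(self):
            super(HasThreshold, self).__init__()
            #: param for decision threshold
            self.threshold = Param(self, "threshold", "decision threshold", 0.5)

        def setThreshold(self, value):
            # values live in the embedded param map, keyed by the Param itself
            self.paramMap[self.threshold] = value
            return self

        def getThreshold(self):
            # fall back to the Param's default when nothing has been set
            if self.threshold in self.paramMap:
                return self.paramMap[self.threshold]
            return self.threshold.defaultValue

This is exactly the shape that _gen_shared_params.py below stamps out for each of the shared params in shared.py.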
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
new file mode 100644
index 0000000000..5eb81106f1
--- /dev/null
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+header = """#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#"""
+
+
+def _gen_param_code(name, doc, defaultValue):
+ """
+ Generates Python code for a shared param class.
+
+ :param name: param name
+ :param doc: param doc
+ :param defaultValue: string representation of the param
+ :return: code string
+ """
+ # TODO: How to correctly inherit instance attributes?
+ template = '''class Has$Name(Params):
+ """
+ Params with $name.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ $name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
+
+ def __init__(self):
+ super(Has$Name, self).__init__()
+ #: param for $doc
+ self.$name = Param(self, "$name", "$doc", $defaultValue)
+
+ def set$Name(self, value):
+ """
+ Sets the value of :py:attr:`$name`.
+ """
+ self.paramMap[self.$name] = value
+ return self
+
+ def get$Name(self):
+ """
+ Gets the value of $name or its default value.
+ """
+ if self.$name in self.paramMap:
+ return self.paramMap[self.$name]
+ else:
+ return self.$name.defaultValue'''
+
+ upperCamelName = name[0].upper() + name[1:]
+ return template \
+ .replace("$name", name) \
+ .replace("$Name", upperCamelName) \
+ .replace("$doc", doc) \
+ .replace("$defaultValue", defaultValue)
+
+if __name__ == "__main__":
+ print header
+ print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
+ print "from pyspark.ml.param import Param, Params\n\n"
+ shared = [
+ ("maxIter", "max number of iterations", "100"),
+ ("regParam", "regularization constant", "0.1"),
+ ("featuresCol", "features column name", "'features'"),
+ ("labelCol", "label column name", "'label'"),
+ ("predictionCol", "prediction column name", "'prediction'"),
+ ("inputCol", "input column name", "'input'"),
+ ("outputCol", "output column name", "'output'"),
+ ("numFeatures", "number of features", "1 << 18")]
+ code = []
+ for name, doc, defaultValue in shared:
+ code.append(_gen_param_code(name, doc, defaultValue))
+ print "\n\n\n".join(code)
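
Since the generator writes to stdout, shared.py below is presumably produced by running this module as a script and redirecting its output. A quick way to see what the template expands to for a single entry, using only the function defined in this file:

    from pyspark.ml.param._gen_shared_params import _gen_param_code

    # prints a HasMaxIter class identical to the one in shared.py below
    print(_gen_param_code("maxIter", "max number of iterations", "100"))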
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
new file mode 100644
index 0000000000..586822f2de
--- /dev/null
+++ b/python/pyspark/ml/param/shared.py
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+
+from pyspark.ml.param import Param, Params
+
+
+class HasMaxIter(Params):
+ """
+ Params with maxIter.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100)
+
+ def __init__(self):
+ super(HasMaxIter, self).__init__()
+ #: param for max number of iterations
+ self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ self.paramMap[self.maxIter] = value
+ return self
+
+ def getMaxIter(self):
+ """
+ Gets the value of maxIter or its default value.
+ """
+ if self.maxIter in self.paramMap:
+ return self.paramMap[self.maxIter]
+ else:
+ return self.maxIter.defaultValue
+
+
+class HasRegParam(Params):
+ """
+ Params with regParam.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1)
+
+ def __init__(self):
+ super(HasRegParam, self).__init__()
+ #: param for regularization constant
+ self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ self.paramMap[self.regParam] = value
+ return self
+
+ def getRegParam(self):
+ """
+ Gets the value of regParam or its default value.
+ """
+ if self.regParam in self.paramMap:
+ return self.paramMap[self.regParam]
+ else:
+ return self.regParam.defaultValue
+
+
+class HasFeaturesCol(Params):
+ """
+ Params with featuresCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features')
+
+ def __init__(self):
+ super(HasFeaturesCol, self).__init__()
+ #: param for features column name
+ self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ self.paramMap[self.featuresCol] = value
+ return self
+
+ def getFeaturesCol(self):
+ """
+ Gets the value of featuresCol or its default value.
+ """
+ if self.featuresCol in self.paramMap:
+ return self.paramMap[self.featuresCol]
+ else:
+ return self.featuresCol.defaultValue
+
+
+class HasLabelCol(Params):
+ """
+ Params with labelCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label')
+
+ def __init__(self):
+ super(HasLabelCol, self).__init__()
+ #: param for label column name
+ self.labelCol = Param(self, "labelCol", "label column name", 'label')
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ self.paramMap[self.labelCol] = value
+ return self
+
+ def getLabelCol(self):
+ """
+ Gets the value of labelCol or its default value.
+ """
+ if self.labelCol in self.paramMap:
+ return self.paramMap[self.labelCol]
+ else:
+ return self.labelCol.defaultValue
+
+
+class HasPredictionCol(Params):
+ """
+ Params with predictionCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction')
+
+ def __init__(self):
+ super(HasPredictionCol, self).__init__()
+ #: param for prediction column name
+ self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ self.paramMap[self.predictionCol] = value
+ return self
+
+ def getPredictionCol(self):
+ """
+ Gets the value of predictionCol or its default value.
+ """
+ if self.predictionCol in self.paramMap:
+ return self.paramMap[self.predictionCol]
+ else:
+ return self.predictionCol.defaultValue
+
+
+class HasInputCol(Params):
+ """
+ Params with inputCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input')
+
+ def __init__(self):
+ super(HasInputCol, self).__init__()
+ #: param for input column name
+ self.inputCol = Param(self, "inputCol", "input column name", 'input')
+
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ self.paramMap[self.inputCol] = value
+ return self
+
+ def getInputCol(self):
+ """
+ Gets the value of inputCol or its default value.
+ """
+ if self.inputCol in self.paramMap:
+ return self.paramMap[self.inputCol]
+ else:
+ return self.inputCol.defaultValue
+
+
+class HasOutputCol(Params):
+ """
+ Params with outputCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output')
+
+ def __init__(self):
+ super(HasOutputCol, self).__init__()
+ #: param for output column name
+ self.outputCol = Param(self, "outputCol", "output column name", 'output')
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ self.paramMap[self.outputCol] = value
+ return self
+
+ def getOutputCol(self):
+ """
+ Gets the value of outputCol or its default value.
+ """
+ if self.outputCol in self.paramMap:
+ return self.paramMap[self.outputCol]
+ else:
+ return self.outputCol.defaultValue
+
+
+class HasNumFeatures(Params):
+ """
+ Params with numFeatures.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+
+ def __init__(self):
+ super(HasNumFeatures, self).__init__()
+ #: param for number of features
+ self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+
+ def setNumFeatures(self, value):
+ """
+ Sets the value of :py:attr:`numFeatures`.
+ """
+ self.paramMap[self.numFeatures] = value
+ return self
+
+ def getNumFeatures(self):
+ """
+ Gets the value of numFeatures or its default value.
+ """
+ if self.numFeatures in self.paramMap:
+ return self.paramMap[self.numFeatures]
+ else:
+ return self.numFeatures.defaultValue
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
new file mode 100644
index 0000000000..2d239f8c80
--- /dev/null
+++ b/python/pyspark/ml/pipeline.py
@@ -0,0 +1,154 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import inherit_doc
+
+
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+
+
+@inherit_doc
+class Estimator(Params):
+ """
+ Abstract class for estimators that fit models to data.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def fit(self, dataset, params={}):
+ """
+ Fits a model to the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: an optional param map that overwrites embedded
+ params
+ :returns: fitted model
+ """
+ raise NotImplementedError()
+
+
+@inherit_doc
+class Transformer(Params):
+ """
+ Abstract class for transformers that transform one dataset into
+ another.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def transform(self, dataset, params={}):
+ """
+ Transforms the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: an optional param map that overwrites embedded
+ params
+ :returns: transformed dataset
+ """
+ raise NotImplementedError()
+
+
+@inherit_doc
+class Pipeline(Estimator):
+ """
+ A simple pipeline, which acts as an estimator. A Pipeline consists
+ of a sequence of stages, each of which is either an
+ :py:class:`Estimator` or a :py:class:`Transformer`. When
+ :py:meth:`Pipeline.fit` is called, the stages are executed in
+ order. If a stage is an :py:class:`Estimator`, its
+ :py:meth:`Estimator.fit` method will be called on the input
+ dataset to fit a model. Then the model, which is a transformer,
+ will be used to transform the dataset as the input to the next
+ stage. If a stage is a :py:class:`Transformer`, its
+ :py:meth:`Transformer.transform` method will be called to produce
+ the dataset for the next stage. The fitted model from a
+ :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+ consists of fitted models and transformers, corresponding to the
+ pipeline stages. If there are no stages, the pipeline acts as an
+ identity transformer.
+ """
+
+ def __init__(self):
+ super(Pipeline, self).__init__()
+ #: Param for pipeline stages.
+ self.stages = Param(self, "stages", "pipeline stages")
+
+ def setStages(self, value):
+ """
+ Set pipeline stages.
+ :param value: a list of transformers or estimators
+ :return: the pipeline instance
+ """
+ self.paramMap[self.stages] = value
+ return self
+
+ def getStages(self):
+ """
+ Get pipeline stages.
+ """
+ if self.stages in self.paramMap:
+ return self.paramMap[self.stages]
+
+ def fit(self, dataset, params={}):
+ paramMap = self._merge_params(params)
+ stages = paramMap[self.stages]
+ for stage in stages:
+ if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
+ raise ValueError(
+ "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
+ indexOfLastEstimator = -1
+ for i, stage in enumerate(stages):
+ if isinstance(stage, Estimator):
+ indexOfLastEstimator = i
+ transformers = []
+ for i, stage in enumerate(stages):
+ if i <= indexOfLastEstimator:
+ if isinstance(stage, Transformer):
+ transformers.append(stage)
+ dataset = stage.transform(dataset, paramMap)
+ else: # must be an Estimator
+ model = stage.fit(dataset, paramMap)
+ transformers.append(model)
+ if i < indexOfLastEstimator:
+ dataset = model.transform(dataset, paramMap)
+ else:
+ transformers.append(stage)
+ return PipelineModel(transformers)
+
+
+@inherit_doc
+class PipelineModel(Transformer):
+ """
+ Represents a compiled pipeline with transformers and fitted models.
+ """
+
+ def __init__(self, transformers):
+ super(PipelineModel, self).__init__()
+ self.transformers = transformers
+
+ def transform(self, dataset, params={}):
+ paramMap = self._merge_params(params)
+ for t in self.transformers:
+ dataset = t.transform(dataset, paramMap)
+ return dataset
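
Putting the new pieces together, a small end-to-end pipeline in the spirit of the Scala spark.ml examples would look roughly like the sketch below. It assumes an active sc and sqlCtx as in the doctests; the toy data and column names are illustrative only.

    from pyspark.sql import Row
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Tokenizer, HashingTF
    from pyspark.ml.classification import LogisticRegression

    training = sqlCtx.inferSchema(sc.parallelize([
        Row(label=1.0, text="spark is great"),
        Row(label=0.0, text="hadoop map reduce")]))

    tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
    hashingTF = HashingTF().setInputCol("words").setOutputCol("features")
    lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)

    # fit() runs the stages in order: the two transformers feed the estimator,
    # and the result is a PipelineModel holding the transformers plus the
    # fitted LogisticRegressionModel.
    model = Pipeline().setStages([tokenizer, hashingTF, lr]).fit(training)

    test = sqlCtx.inferSchema(sc.parallelize([Row(text="spark rocks")]))
    prediction = model.transform(test).head().prediction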
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
new file mode 100644
index 0000000000..b627c2b4e9
--- /dev/null
+++ b/python/pyspark/ml/tests.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for Spark ML Python APIs.
+"""
+
+import sys
+
+if sys.version_info[:2] <= (2, 6):
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
+else:
+ import unittest
+
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Param
+from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
+
+
+class MockDataset(DataFrame):
+
+ def __init__(self):
+ self.index = 0
+
+
+class MockTransformer(Transformer):
+
+ def __init__(self):
+ super(MockTransformer, self).__init__()
+ self.fake = Param(self, "fake", "fake", None)
+ self.dataset_index = None
+ self.fake_param_value = None
+
+ def transform(self, dataset, params={}):
+ self.dataset_index = dataset.index
+ if self.fake in params:
+ self.fake_param_value = params[self.fake]
+ dataset.index += 1
+ return dataset
+
+
+class MockEstimator(Estimator):
+
+ def __init__(self):
+ super(MockEstimator, self).__init__()
+ self.fake = Param(self, "fake", "fake", None)
+ self.dataset_index = None
+ self.fake_param_value = None
+ self.model = None
+
+ def fit(self, dataset, params={}):
+ self.dataset_index = dataset.index
+ if self.fake in params:
+ self.fake_param_value = params[self.fake]
+ model = MockModel()
+ self.model = model
+ return model
+
+
+class MockModel(MockTransformer, Transformer):
+
+ def __init__(self):
+ super(MockModel, self).__init__()
+
+
+class PipelineTests(PySparkTestCase):
+
+ def test_pipeline(self):
+ dataset = MockDataset()
+ estimator0 = MockEstimator()
+ transformer1 = MockTransformer()
+ estimator2 = MockEstimator()
+ transformer3 = MockTransformer()
+ pipeline = Pipeline() \
+ .setStages([estimator0, transformer1, estimator2, transformer3])
+ pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
+ self.assertEqual(0, estimator0.dataset_index)
+ self.assertEqual(0, estimator0.fake_param_value)
+ model0 = estimator0.model
+ self.assertEqual(0, model0.dataset_index)
+ self.assertEqual(1, transformer1.dataset_index)
+ self.assertEqual(1, transformer1.fake_param_value)
+ self.assertEqual(2, estimator2.dataset_index)
+ model2 = estimator2.model
+ self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
+ "not be called during fit.")
+ dataset = pipeline_model.transform(dataset)
+ self.assertEqual(2, model0.dataset_index)
+ self.assertEqual(3, transformer1.dataset_index)
+ self.assertEqual(4, model2.dataset_index)
+ self.assertEqual(5, transformer3.dataset_index)
+ self.assertEqual(6, dataset.index)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
new file mode 100644
index 0000000000..b1caa84b63
--- /dev/null
+++ b/python/pyspark/ml/util.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+
+
+def inherit_doc(cls):
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
+
+
+class Identifiable(object):
+ """
+ Object with a unique ID.
+ """
+
+ def __init__(self):
+ #: A unique id for the object. The default implementation
+ #: concatenates the class name, "-", and 8 random hex chars.
+ self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+
+ def __repr__(self):
+ return self.uid
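
inherit_doc (moved here from sql.py; see the hunk near the end of this diff) copies a parent class's docstring onto an overriding public method that has none, which is how the Java wrapper classes below keep the documentation written on Estimator.fit and Transformer.transform. A toy illustration; Base and Child are made-up names:

    from pyspark.ml.util import inherit_doc

    class Base(object):
        def fit(self):
            """Fits a model (documented only on the base class)."""

    @inherit_doc
    class Child(Base):
        def fit(self):  # no docstring of its own
            return "fitted"

    # after decoration, Child.fit carries Base.fit's docstring
    assert Child.fit.__doc__ == Base.fit.__doc__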
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
new file mode 100644
index 0000000000..9e12ddc3d9
--- /dev/null
+++ b/python/pyspark/ml/wrapper.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark import SparkContext
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Params
+from pyspark.ml.pipeline import Estimator, Transformer
+from pyspark.ml.util import inherit_doc
+
+
+def _jvm():
+ """
+ Returns the JVM view associated with SparkContext. Must be called
+ after SparkContext is initialized.
+ """
+ jvm = SparkContext._jvm
+ if jvm:
+ return jvm
+ else:
+ raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
+
+
+@inherit_doc
+class JavaWrapper(Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+
+ __metaclass__ = ABCMeta
+
+ #: Fully-qualified class name of the wrapped Java component.
+ _java_class = None
+
+ def _java_obj(self):
+ """
+ Returns or creates a Java object.
+ """
+ java_obj = _jvm()
+ for name in self._java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return java_obj()
+
+ def _transfer_params_to_java(self, params, java_obj):
+ """
+ Transfers the embedded params and additional params to the
+ input Java object.
+ :param params: additional params (overwriting embedded values)
+ :param java_obj: Java object to receive the params
+ """
+ paramMap = self._merge_params(params)
+ for param in self.params:
+ if param in paramMap:
+ java_obj.set(param.name, paramMap[param])
+
+ def _empty_java_param_map(self):
+ """
+ Returns an empty Java ParamMap reference.
+ """
+ return _jvm().org.apache.spark.ml.param.ParamMap()
+
+ def _create_java_param_map(self, params, java_obj):
+ paramMap = self._empty_java_param_map()
+ for param, value in params.items():
+ if param.parent is self:
+ paramMap.put(java_obj.getParam(param.name), value)
+ return paramMap
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+ """
+ Base class for :py:class:`Estimator` subclasses that wrap
+ Java/Scala implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def _create_model(self, java_model):
+ """
+ Creates a model from the input Java model reference.
+ """
+ return JavaModel(java_model)
+
+ def _fit_java(self, dataset, params={}):
+ """
+ Fits a Java model to the input dataset.
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: additional params (overwriting embedded values)
+ :return: fitted Java model
+ """
+ java_obj = self._java_obj()
+ self._transfer_params_to_java(params, java_obj)
+ return java_obj.fit(dataset._jdf, self._empty_java_param_map())
+
+ def fit(self, dataset, params={}):
+ java_model = self._fit_java(dataset, params)
+ return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+ """
+ Base class for :py:class:`Transformer` subclasses that wrap
+ Java/Scala implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def transform(self, dataset, params={}):
+ java_obj = self._java_obj()
+ self._transfer_params_to_java({}, java_obj)
+ java_param_map = self._create_java_param_map(params, java_obj)
+ return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
+ dataset.sql_ctx)
+
+
+@inherit_doc
+class JavaModel(JavaTransformer):
+ """
+ Base class for fitted models that wrap Java/Scala
+ implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self, java_model):
+ super(JavaModel, self).__init__()
+ self._java_model = java_model
+
+ def _java_obj(self):
+ return self._java_model
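
With the wrapper machinery above, exposing a JVM pipeline component to Python reduces to naming its class and mixing in the shared params it supports, exactly as Tokenizer, HashingTF and LogisticRegression do in this change. A sketch for some hypothetical transformer (both the Python class and the Java class name are invented for illustration):

    from pyspark.ml.util import inherit_doc
    from pyspark.ml.wrapper import JavaTransformer
    from pyspark.ml.param.shared import HasInputCol, HasOutputCol

    @inherit_doc
    class MyNormalizer(JavaTransformer, HasInputCol, HasOutputCol):
        """
        Hypothetical wrapper for a Scala Transformer; only the fully
        qualified Java class name and the shared param mixins are needed.
        """
        _java_class = "com.example.ml.feature.MyNormalizer"

At transform() time JavaTransformer instantiates the named class over Py4J, pushes the embedded params onto it by name via set(), passes any extra params as a Java ParamMap, and wraps the returned DataFrame, so no per-class glue code is required.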
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 7d7550c854..c3a6938f56 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1794,20 +1794,6 @@ class Row(tuple):
return "<Row(%s)>" % ", ".join(self)
-def inherit_doc(cls):
- for name, func in vars(cls).items():
- # only inherit docstring for public functions
- if name.startswith("_"):
- continue
- if not func.__doc__:
- for parent in cls.__bases__:
- parent_func = getattr(parent, name, None)
- if parent_func and getattr(parent_func, "__doc__", None):
- func.__doc__ = parent_func.__doc__
- break
- return cls
-
-
class DataFrame(object):
"""A collection of rows that have the same columns.
diff --git a/python/run-tests b/python/run-tests
index 53c34557d9..84cb89b1a9 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -82,6 +82,13 @@ function run_mllib_tests() {
run_test "pyspark/mllib/tests.py"
}
+function run_ml_tests() {
+ echo "Run ml tests ..."
+ run_test "pyspark/ml/feature.py"
+ run_test "pyspark/ml/classification.py"
+ run_test "pyspark/ml/tests.py"
+}
+
function run_streaming_tests() {
echo "Run streaming tests ..."
run_test "pyspark/streaming/util.py"
@@ -103,6 +110,7 @@ $PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
+run_ml_tests
run_streaming_tests
# Try to test with PyPy