diff options
author | Davies Liu <davies@databricks.com> | 2014-10-30 22:25:18 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-10-30 22:25:18 -0700 |
commit | 872fc669b497fb255db3212568f2a14c2ba0d5db (patch) | |
tree | 6dcaa7e0b251fa5f233171e2878a4dc428db2348 /python/pyspark/mllib/tree.py | |
parent | 0734d09320fe37edd3a02718511cda0bda852478 (diff) | |
download | spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.gz spark-872fc669b497fb255db3212568f2a14c2ba0d5db.tar.bz2 spark-872fc669b497fb255db3212568f2a14c2ba0d5db.zip |
[SPARK-4124] [MLlib] [PySpark] simplify serialization in MLlib Python API
Create several helper functions to call the MLlib Java API, convert the arguments to Java types, and convert return values to Python objects automatically; this greatly simplifies serialization in the MLlib Python API.
After this, the MLlib Python API does not need to deal with serialization details anymore, making it easier to add new APIs.
cc mengxr
Author: Davies Liu <davies@databricks.com>
Closes #2995 from davies/cleanup and squashes the following commits:
8fa6ec6 [Davies Liu] address comments
16b85a0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into cleanup
43743e5 [Davies Liu] bugfix
731331f [Davies Liu] simplify serialization in MLlib Python API
Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r-- | python/pyspark/mllib/tree.py | 55 |
1 file changed, 11 insertions, 44 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 64ee79d83e..5d1a3c0962 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,36 +15,22 @@ # limitations under the License. # -from py4j.java_collections import MapConverter - from pyspark import SparkContext, RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] -class DecisionTreeModel(object): +class DecisionTreeModel(JavaModelWrapper): """ A decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. + It will probably be modified in future. """ - - def __init__(self, sc, java_model): - """ - :param sc: Spark context - :param java_model: Handle to Java model object - """ - self._sc = sc - self._java_model = java_model - - def __del__(self): - self._sc._gateway.detach(self._java_model) - def predict(self, x): """ Predict the label of one or more examples. @@ -52,24 +38,11 @@ class DecisionTreeModel(object): :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ - SerDe = self._sc._jvm.SerDe - ser = PickleSerializer() if isinstance(x, RDD): - # Bulk prediction - first = x.take(1) - if not first: - return self._sc.parallelize([]) - if not isinstance(first[0], Vector): - x = x.map(_convert_to_vector) - jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() - jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) - return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) + return self.call("predict", x.map(_convert_to_vector)) else: - # Assume x is a single data point. 
- bytes = bytearray(ser.dumps(_convert_to_vector(x))) - vec = self._sc._jvm.SerDe.loads(bytes) - return self._java_model.predict(vec) + return self.call("predict", _convert_to_vector(x)) def numNodes(self): return self._java_model.numNodes() @@ -98,19 +71,13 @@ class DecisionTree(object): """ @staticmethod - def _train(data, type, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, - minInfoGain=0.0): + def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, + minInstancesPerNode=1, minInfoGain=0.0): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" - sc = data.context - jrdd = _to_java_object_rdd(data) - cfiMap = MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - jrdd, type, numClasses, cfiMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - return DecisionTreeModel(sc, model) + model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return DecisionTreeModel(model) @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, |