Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r--  python/pyspark/mllib/tree.py  55
1 file changed, 11 insertions(+), 44 deletions(-)
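
The hunks below drop the hand-rolled Py4J/SerDe plumbing in favor of the shared helpers in pyspark.mllib.common. The sketch that follows only illustrates the calling convention those helpers provide, inferred from how this diff uses them (callMLlibFunc, JavaModelWrapper, self.call, self._java_model); the real implementations live in pyspark.mllib.common and route everything through Py4J.

# Illustration only: a simplified stand-in for pyspark.mllib.common's
# JavaModelWrapper, showing the wrapper pattern the new code relies on.
# The real class forwards calls through Py4J to a wrapped JVM object and
# converts arguments and results between Python and Java.
class JavaModelWrapper(object):
    """Holds a model handle and forwards method calls to it via call()."""

    def __init__(self, java_model):
        self._java_model = java_model

    def call(self, name, *args):
        # The real helper also handles Python/Java conversion; here we
        # simply dispatch the named method on the wrapped object.
        return getattr(self._java_model, name)(*args)


if __name__ == "__main__":
    # Stand-in for a JVM-side model object, used only to exercise call().
    class FakeJavaModel(object):
        def numNodes(self):
            return 3

    wrapper = JavaModelWrapper(FakeJavaModel())
    print(wrapper.call("numNodes"))  # prints 3
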
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 64ee79d83e..5d1a3c0962 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -15,36 +15,22 @@
# limitations under the License.
#
-from py4j.java_collections import MapConverter
-
from pyspark import SparkContext, RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
-from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
__all__ = ['DecisionTreeModel', 'DecisionTree']
-class DecisionTreeModel(object):
+class DecisionTreeModel(JavaModelWrapper):
"""
A decision tree model for classification or regression.
EXPERIMENTAL: This is an experimental API.
- It will probably be modified for Spark v1.2.
+    It will probably be modified in the future.
"""
-
- def __init__(self, sc, java_model):
- """
- :param sc: Spark context
- :param java_model: Handle to Java model object
- """
- self._sc = sc
- self._java_model = java_model
-
- def __del__(self):
- self._sc._gateway.detach(self._java_model)
-
def predict(self, x):
"""
Predict the label of one or more examples.
@@ -52,24 +38,11 @@ class DecisionTreeModel(object):
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
- SerDe = self._sc._jvm.SerDe
- ser = PickleSerializer()
if isinstance(x, RDD):
- # Bulk prediction
- first = x.take(1)
- if not first:
- return self._sc.parallelize([])
- if not isinstance(first[0], Vector):
- x = x.map(_convert_to_vector)
- jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
- jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred)
- return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
+ return self.call("predict", x.map(_convert_to_vector))
else:
- # Assume x is a single data point.
- bytes = bytearray(ser.dumps(_convert_to_vector(x)))
- vec = self._sc._jvm.SerDe.loads(bytes)
- return self._java_model.predict(vec)
+ return self.call("predict", _convert_to_vector(x))
def numNodes(self):
return self._java_model.numNodes()
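
For reference, a minimal usage sketch of the refactored predict path above. The toy dataset, app name, and reliance on trainClassifier's defaults are illustrative assumptions; predicting on a single vector and on an RDD are the two branches this hunk rewrites to go through self.call.

from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

sc = SparkContext(appName="DecisionTreePredictSketch")

# Tiny toy dataset: binary label with a single numeric feature.
data = sc.parallelize([
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(0.0, [1.0]),
    LabeledPoint(1.0, [2.0]),
    LabeledPoint(1.0, [3.0]),
])

model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={})

# Single feature vector: forwarded to the JVM model via self.call("predict", ...).
print(model.predict([2.0]))

# RDD of feature vectors: returns an RDD of predictions.
print(model.predict(data.map(lambda p: p.features)).collect())

sc.stop()
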
@@ -98,19 +71,13 @@ class DecisionTree(object):
"""
@staticmethod
- def _train(data, type, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
- minInfoGain=0.0):
+ def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
+ minInstancesPerNode=1, minInfoGain=0.0):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
- sc = data.context
- jrdd = _to_java_object_rdd(data)
- cfiMap = MapConverter().convert(categoricalFeaturesInfo,
- sc._gateway._gateway_client)
- model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
- jrdd, type, numClasses, cfiMap,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
- return DecisionTreeModel(sc, model)
+ model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ return DecisionTreeModel(model)
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,