author     Joseph K. Bradley <joseph.kurata.bradley@gmail.com>    2014-08-02 13:07:17 -0700
committer  Xiangrui Meng <meng@databricks.com>                    2014-08-02 13:07:17 -0700
commit     3f67382e7c9c3f6a8f6ce124ab3fcb1a9c1a264f (patch)
tree       1a39b613599d552f2fbdd1679f78f205887d1698 /python
parent     e09e18b3123c20e9b9497cf606473da500349d4d (diff)
[SPARK-2478] [mllib] DecisionTree Python API
Added experimental Python API for Decision Trees.

API:

* class DecisionTreeModel
** predict() for single examples and for RDDs, taking both feature vectors and LabeledPoints
** numNodes()
** depth()
** __str__()
* class DecisionTree
** trainClassifier()
** trainRegressor()
** train()

Examples and testing:

* Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py
* Have also tested example usage in the doc of python/pyspark/mllib/tree.py, which tests single-example prediction with dense and sparse vectors.

Also: Small bug fix in python/pyspark/mllib/_common.py: in _linear_predictor_typecheck, changed the check for RDD to use isinstance() instead of type() in order to catch RDD subclasses.

CC mengxr manishamde

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits:

3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py. Small updates based on github review.
6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
affceb9 [Joseph K. Bradley] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: if a data file had no features, then loadLibSVMFile would create a single all-zero feature.
67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel: * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint]. * predict() does not cache the serialized RDD any more.
aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
bf21be4 [Joseph K. Bradley] Removed old run() func from DecisionTree
fa10ea7 [Joseph K. Bradley] Small style update
7968692 [Joseph K. Bradley] Small braces typo fix
e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite
db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new
6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
225822f [Joseph K. Bradley] Bug: in DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
93953f1 [Joseph K. Bradley] Likely done with Python API.
6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API
188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more.
2b20c61 [Joseph K. Bradley] Small doc and style updates
1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals
8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
6eed482 [Joseph K. Bradley] In DecisionTree: changed from using procedural syntax for functions returning Unit to explicitly writing the Unit return type.
376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1. * In params, replaced settings of maxDepth <-- maxDepth - 1.
e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo = Classification. * Updates based on comments: ** DecisionTreeRunner: *** made dataFormat arg default to libsvm. ** Small cleanups. ** tree.Node: made recursive helper methods private, and renamed them.
52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in the past, but fixed now.
8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification
cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
8ea8750 [Joseph K. Bradley] Bug fix: off-by-1 when finding thresholds for splits for continuous features.
8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new
5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new
2283df8 [Joseph K. Bradley] 2 bug fixes.
73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit.
f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel.
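For readers skimming the patch, a minimal sketch of the new API (a toy example assuming a running SparkContext `sc`; the data and variable names are illustrative, not taken from the patch):

    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]),
                           LabeledPoint(1.0, [1.0, 0.0])])

    # trainClassifier()/trainRegressor() are the new static entry points.
    model = DecisionTree.trainClassifier(data, numClasses=2)

    print(model)                         # __str__() renders the learned tree
    print(model.numNodes(), model.depth())
    print(model.predict(array([1.0, 0.0])))                 # single feature vector
    preds = model.predict(data.map(lambda p: p.features))   # bulk RDD prediction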
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/mllib/_common.py  |  33
-rw-r--r--  python/pyspark/mllib/tests.py    |  36
-rw-r--r--  python/pyspark/mllib/tree.py     | 225
-rw-r--r--  python/pyspark/mllib/util.py     |  14
-rwxr-xr-x  python/run-tests                 |   1
5 files changed, 291 insertions(+), 18 deletions(-)
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index c6ca6a75df..9c1565affb 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype):
temp_array[...] = array
-def _get_unmangled_rdd(data, serializer):
+def _get_unmangled_rdd(data, serializer, cache=True):
+ """
+ :param cache: If True, the serialized RDD is cached. (default = True)
+ WARNING: Users should unpersist() this later!
+ """
dataBytes = data.map(serializer)
dataBytes._bypass_serializer = True
- dataBytes.cache() # TODO: users should unpersist() this later!
+ if cache:
+ dataBytes.cache()
return dataBytes
-# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
-# _serialized_double_vectors
-def _get_unmangled_double_vector_rdd(data):
- return _get_unmangled_rdd(data, _serialize_double_vector)
+def _get_unmangled_double_vector_rdd(data, cache=True):
+ """
+ Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
+ _serialized_double_vectors.
+ :param cache: If True, the serialized RDD is cached. (default = True)
+ WARNING: Users should unpersist() this later!
+ """
+ return _get_unmangled_rdd(data, _serialize_double_vector, cache)
-# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points
-def _get_unmangled_labeled_point_rdd(data):
- return _get_unmangled_rdd(data, _serialize_labeled_point)
+def _get_unmangled_labeled_point_rdd(data, cache=True):
+ """
+ Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points.
+ :param cache: If True, the serialized RDD is cached. (default = True)
+ WARNING: Users should unpersist() this later!
+ """
+ return _get_unmangled_rdd(data, _serialize_labeled_point, cache)
# Common functions for dealing with and training linear models
@@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs):
if x.size != coeffs.shape[0]:
raise RuntimeError("Got sparse vector of size %d; wanted %d" % (
x.size, coeffs.shape[0]))
- elif (type(x) == RDD):
+ elif isinstance(x, RDD):
raise RuntimeError("Bulk predict not yet supported.")
else:
raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 37ccf1d590..9d1e5be637 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -100,6 +100,7 @@ class ListTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
+ from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
@@ -127,9 +128,19 @@ class ListTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)
+ categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
+ dt_model = \
+ DecisionTree.trainClassifier(rdd, numClasses=2,
+ categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
+ from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(-1.0, [0, -1]),
LabeledPoint(1.0, [0, 1]),
@@ -157,6 +168,14 @@ class ListTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)
+ categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
+ dt_model = \
+ DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
@@ -229,6 +248,7 @@ class SciPyTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
+ from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})),
LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
@@ -256,9 +276,18 @@ class SciPyTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[2]) <= 0)
self.assertTrue(nb_model.predict(features[3]) > 0)
+ categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
+ dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
+ categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
+ from pyspark.mllib.tree import DecisionTree
data = [
LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})),
LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
@@ -286,6 +315,13 @@ class SciPyTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[2]) <= 0)
self.assertTrue(rr_model.predict(features[3]) > 0)
+ categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
+ dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
if __name__ == "__main__":
if not _have_scipy:
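The new assertions above hinge on categoricalFeaturesInfo; a short sketch of its semantics (toy data assumed, not taken from the tests):

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    # Keys are feature indices, values are arities: feature 0 below is
    # categorical with values in {0, 1, 2}; features absent from the dict
    # (here, feature 1) are treated as continuous.
    rdd = sc.parallelize([LabeledPoint(0.0, [0.0, 4.5]),
                          LabeledPoint(1.0, [2.0, 1.5])])
    model = DecisionTree.trainClassifier(rdd, numClasses=2,
                                         categoricalFeaturesInfo={0: 3})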
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
new file mode 100644
index 0000000000..1e0006df75
--- /dev/null
+++ b/python/pyspark/mllib/tree.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from py4j.java_collections import MapConverter
+
+from pyspark import SparkContext, RDD
+from pyspark.mllib._common import \
+ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \
+ _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \
+ _deserialize_double
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.serializers import NoOpSerializer
+
+class DecisionTreeModel(object):
+ """
+ A decision tree model for classification or regression.
+
+ EXPERIMENTAL: This is an experimental API.
+ It will probably be modified for Spark v1.2.
+ """
+
+ def __init__(self, sc, java_model):
+ """
+ :param sc: Spark context
+ :param java_model: Handle to Java model object
+ """
+ self._sc = sc
+ self._java_model = java_model
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_model)
+
+ def predict(self, x):
+ """
+ Predict the label of one or more examples.
+ :param x: Data point (feature vector),
+ or an RDD of data points (feature vectors).
+ """
+ pythonAPI = self._sc._jvm.PythonMLLibAPI()
+ if isinstance(x, RDD):
+ # Bulk prediction
+ if x.count() == 0:
+ return self._sc.parallelize([])
+ dataBytes = _get_unmangled_double_vector_rdd(x, cache=False)
+ jSerializedPreds = \
+ pythonAPI.predictDecisionTreeModel(self._java_model,
+ dataBytes._jrdd)
+ serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer())
+ return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes)))
+ else:
+ # Assume x is a single data point.
+ x_ = _serialize_double_vector(x)
+ return pythonAPI.predictDecisionTreeModel(self._java_model, x_)
+
+ def numNodes(self):
+ return self._java_model.numNodes()
+
+ def depth(self):
+ return self._java_model.depth()
+
+ def __str__(self):
+ return self._java_model.toString()
+
+
+class DecisionTree(object):
+ """
+ Learning algorithm for a decision tree model
+ for classification or regression.
+
+ EXPERIMENTAL: This is an experimental API.
+ It will probably be modified for Spark v1.2.
+
+ Example usage:
+ >>> from numpy import array, ndarray
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import DecisionTree
+ >>> from pyspark.mllib.linalg import SparseVector
+ >>>
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0]),
+ ... LabeledPoint(1.0, [1.0]),
+ ... LabeledPoint(1.0, [2.0]),
+ ... LabeledPoint(1.0, [3.0])
+ ... ]
+ >>>
+ >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
+ >>> print(model)
+ DecisionTreeModel classifier
+ If (feature 0 <= 0.5)
+ Predict: 0.0
+ Else (feature 0 > 0.5)
+ Predict: 1.0
+
+ >>> model.predict(array([1.0])) > 0
+ True
+ >>> model.predict(array([0.0])) == 0
+ True
+ >>> sparse_data = [
+ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
+ ... ]
+ >>>
+ >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
+ >>> model.predict(array([0.0, 1.0])) == 1
+ True
+ >>> model.predict(array([0.0, 0.0])) == 0
+ True
+ >>> model.predict(SparseVector(2, {1: 1.0})) == 1
+ True
+ >>> model.predict(SparseVector(2, {1: 0.0})) == 0
+ True
+ """
+
+ @staticmethod
+ def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+ impurity="gini", maxDepth=4, maxBins=100):
+ """
+ Train a DecisionTreeModel for classification.
+
+ :param data: Training data: RDD of LabeledPoint.
+ Labels are integers {0,1,...,numClasses-1}.
+ :param numClasses: Number of classes for classification.
+ :param categoricalFeaturesInfo: Map from categorical feature index
+ to number of categories.
+ Any feature not in this map
+ is treated as continuous.
+ :param impurity: Supported values: "entropy" or "gini"
+ :param maxDepth: Max depth of tree.
+ E.g., depth 0 means 1 leaf node.
+ Depth 1 means 1 internal node + 2 leaf nodes.
+ :param maxBins: Number of bins used for finding splits at each node.
+ :return: DecisionTreeModel
+ """
+ return DecisionTree.train(data, "classification", numClasses,
+ categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins)
+
+ @staticmethod
+ def trainRegressor(data, categoricalFeaturesInfo={},
+ impurity="variance", maxDepth=4, maxBins=100):
+ """
+ Train a DecisionTreeModel for regression.
+
+ :param data: Training data: RDD of LabeledPoint.
+ Labels are real numbers.
+ :param categoricalFeaturesInfo: Map from categorical feature index
+ to number of categories.
+ Any feature not in this map
+ is treated as continuous.
+ :param impurity: Supported values: "variance"
+ :param maxDepth: Max depth of tree.
+ E.g., depth 0 means 1 leaf node.
+ Depth 1 means 1 internal node + 2 leaf nodes.
+ :param maxBins: Number of bins used for finding splits at each node.
+ :return: DecisionTreeModel
+ """
+ return DecisionTree.train(data, "regression", 0,
+ categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins)
+
+
+ @staticmethod
+ def train(data, algo, numClasses, categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins=100):
+ """
+ Train a DecisionTreeModel for classification or regression.
+
+ :param data: Training data: RDD of LabeledPoint.
+ For classification, labels are integers
+ {0,1,...,numClasses-1}.
+ For regression, labels are real numbers.
+ :param algo: "classification" or "regression"
+ :param numClasses: Number of classes for classification.
+ :param categoricalFeaturesInfo: Map from categorical feature index
+ to number of categories.
+ Any feature not in this map
+ is treated as continuous.
+ :param impurity: For classification: "entropy" or "gini".
+ For regression: "variance".
+ :param maxDepth: Max depth of tree.
+ E.g., depth 0 means 1 leaf node.
+ Depth 1 means 1 internal node + 2 leaf nodes.
+ :param maxBins: Number of bins used for finding splits at each node.
+ :return: DecisionTreeModel
+ """
+ sc = data.context
+ dataBytes = _get_unmangled_labeled_point_rdd(data)
+ categoricalFeaturesInfoJMap = \
+ MapConverter().convert(categoricalFeaturesInfo,
+ sc._gateway._gateway_client)
+ model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+ dataBytes._jrdd, algo,
+ numClasses, categoricalFeaturesInfoJMap,
+ impurity, maxDepth, maxBins)
+ dataBytes.unpersist()
+ return DecisionTreeModel(sc, model)
+
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
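Beyond the doctest in the file above, two behaviors the commit message calls out are bulk prediction and the empty-RDD case; a hedged sketch (assumes a running SparkContext `sc`):

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    points = [LabeledPoint(float(i > 1), [float(i)]) for i in range(4)]
    model = DecisionTree.trainRegressor(sc.parallelize(points))

    # predict() accepts an RDD of feature vectors for bulk prediction...
    features = sc.parallelize([p.features for p in points])
    print(model.predict(features).collect())

    # ...and an empty input RDD now yields an empty RDD, not an error.
    print(model.predict(sc.parallelize([])).collect())   # []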
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index d94900cefd..639cda6350 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -16,6 +16,7 @@
#
import numpy as np
+import warnings
from pyspark.mllib.linalg import Vectors, SparseVector
from pyspark.mllib.regression import LabeledPoint
@@ -29,9 +30,9 @@ class MLUtils:
Helper methods to load, save and pre-process data used in MLlib.
"""
- @deprecated
@staticmethod
def _parse_libsvm_line(line, multiclass):
+ warnings.warn("deprecated", DeprecationWarning)
return _parse_libsvm_line(line)
@staticmethod
@@ -67,9 +68,9 @@ class MLUtils:
" but got " % type(v))
return " ".join(items)
- @deprecated
@staticmethod
def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None):
+ warnings.warn("deprecated", DeprecationWarning)
return loadLibSVMFile(sc, path, numFeatures, minPartitions)
@staticmethod
@@ -106,7 +107,6 @@ class MLUtils:
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
- >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
>>> tempFile.close()
>>> type(examples[0]) == LabeledPoint
True
@@ -115,20 +115,18 @@ class MLUtils:
>>> type(examples[1]) == LabeledPoint
True
>>> print examples[1]
- (0.0,(6,[],[]))
+ (-1.0,(6,[],[]))
>>> type(examples[2]) == LabeledPoint
True
>>> print examples[2]
- (0.0,(6,[1,3,5],[4.0,5.0,6.0]))
- >>> multiclass_examples[1].label
- -1.0
+ (-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
"""
lines = sc.textFile(path, minPartitions)
parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l))
if numFeatures <= 0:
parsed.cache()
- numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
+ numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1
return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2])))
@staticmethod
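The doctest changes above reflect two behavioral fixes in loadLibSVMFile: labels are no longer thresholded at 0 (a "-1" label stays -1.0), and a line with no features parses to an empty sparse vector rather than a phantom all-zero feature (hence the -1 sentinel when inferring numFeatures). A short sketch (assumes `sc` and a hypothetical file /tmp/sample.libsvm holding the doctest's three lines):

    from pyspark.mllib.util import MLUtils

    examples = MLUtils.loadLibSVMFile(sc, "/tmp/sample.libsvm").collect()
    print(examples[1])   # (-1.0,(6,[],[])) -- label preserved, no phantom feature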
diff --git a/python/run-tests b/python/run-tests
index 5049e15ce5..48feba2f5b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/tests.py"
+run_test "pyspark/mllib/util.py"
if [[ $FAILED == 0 ]]; then
echo -en "\033[32m" # Green