aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
authorKazuki Taniguchi <kazuki.t.1018@gmail.com>2015-01-30 00:39:44 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-30 00:39:44 -0800
commitbc1fc9b60dab69ae74419e35dc6bd263dc504f34 (patch)
tree99bb73f18a7cf2bb70ab31b99cfa72e71699bdf5 /python/pyspark/mllib/tests.py
parentdd4d84cf809e6e425958fe768c518679d1828779 (diff)
downloadspark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.gz
spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.tar.bz2
spark-bc1fc9b60dab69ae74419e35dc6bd263dc504f34.zip
[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
This PR is implementing the Gradient Boosted Trees for Python API. Author: Kazuki Taniguchi <kazuki.t.1018@gmail.com> Closes #3951 from kazk1018/gbt_for_py and squashes the following commits: 620d247 [Kazuki Taniguchi] [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py41
1 files changed, 34 insertions, 7 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f48e3d6dac..61e0cf5d90 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -169,7 +169,7 @@ class ListTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
- from pyspark.mllib.tree import DecisionTree
+ from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
@@ -198,18 +198,31 @@ class ListTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
- dt_model = \
- DecisionTree.trainClassifier(rdd, numClasses=2,
- categoricalFeaturesInfo=categoricalFeaturesInfo)
+ dt_model = DecisionTree.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
+ rf_model = RandomForest.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ gbt_model = GradientBoostedTrees.trainClassifier(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
- from pyspark.mllib.tree import DecisionTree
+ from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(-1.0, [0, -1]),
LabeledPoint(1.0, [0, 1]),
@@ -238,13 +251,27 @@ class ListTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
- dt_model = \
- DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ dt_model = DecisionTree.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
+ rf_model = RandomForest.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ gbt_model = GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
class StatTests(PySparkTestCase):
# SPARK-4023