aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2016-04-18 11:52:29 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2016-04-18 11:52:29 -0700
commit b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (patch)
tree 18131b3a63a970be653d9350785dc0ab0bcbbfff /python/pyspark/ml/tests.py
parent 775cf17eaaae1a38efe47b282b1d6bbdb99bd759 (diff)
download spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.gz
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.tar.bz2
spark-b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a.zip
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14306 Add save/load support for PySpark OneVsRest. ## How was this patch tested? Tested with a Python unit test. Author: Xusen Yin <yinxusen@gmail.com> Closes #12439 from yinxusen/SPARK-14306-0415.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- python/pyspark/ml/tests.py 25
1 files changed, 24 insertions, 1 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7a9868bac..9d6ff47b54 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -43,7 +43,8 @@ import tempfile
import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest
+from pyspark.ml.classification import (
+ LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
@@ -881,6 +882,28 @@ class OneVsRestTests(PySparkTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])
+ def test_save_load(self):  # Round-trip a OneVsRest estimator and its fitted model through save()/load().
+ temp_path = tempfile.mkdtemp()  # NOTE(review): this directory is never removed in this test; consider cleaning it up.
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])  # three rows, one per label value 0.0 / 1.0 / 2.0
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr)
+ model = ovr.fit(df)
+ ovrPath = temp_path + "/ovr"  # the estimator and the model are saved under separate sub-paths
+ ovr.save(ovrPath)
+ loadedOvr = OneVsRest.load(ovrPath)
+ self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())  # column params must match after reload
+ self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
+ self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)  # the wrapped classifier is compared by uid only
+ modelPath = temp_path + "/ovrModel"
+ model.save(modelPath)
+ loadedModel = OneVsRestModel.load(modelPath)
+ for m, n in zip(model.models, loadedModel.models):  # NOTE(review): zip() stops at the shorter list, so a differing number of sub-models would go unnoticed.
+ self.assertEqual(m.uid, n.uid)  # sub-models are compared pairwise by uid
+
class HashingTFTest(PySparkTestCase):