From 63b200e8d4a05d5b744d437fd10781c6b5429da9 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 28 Mar 2016 22:33:25 -0700 Subject: [SPARK-14071][PYSPARK][ML] Change MLWritable.write to be a property Add property to MLWritable.write method, so we can use .write instead of .write() Add a new test to ml/test.py to check whether the write is a property. ./python/run-tests --python-executables=python2.7 --modules=pyspark-ml Will test against the following Python executables: ['python2.7'] Will test the following Python modules: ['pyspark-ml'] Finished test(python2.7): pyspark.ml.evaluation (11s) Finished test(python2.7): pyspark.ml.clustering (16s) Finished test(python2.7): pyspark.ml.classification (24s) Finished test(python2.7): pyspark.ml.recommendation (24s) Finished test(python2.7): pyspark.ml.feature (39s) Finished test(python2.7): pyspark.ml.regression (26s) Finished test(python2.7): pyspark.ml.tuning (15s) Finished test(python2.7): pyspark.ml.tests (30s) Tests passed in 55 seconds Author: wm624@hotmail.com Closes #11945 from wangmiao1981/fix_property. --- python/pyspark/ml/tests.py | 5 +++++ python/pyspark/ml/util.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'python/pyspark/ml') diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 224232ed7f..f6159b2c95 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,6 +51,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.util import MLWritable, MLWriter from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row @@ -655,6 +656,10 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def test_write_property(self): + lr = LinearRegression(maxIter=1) + self.assertTrue(isinstance(lr.write, MLWriter)) + def test_decisiontree_classifier(self): dt = DecisionTreeClassifier(maxDepth=1) path = tempfile.mkdtemp() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6703851262..d4411fdfb9 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -134,13 +134,14 @@ class MLWritable(object): .. versionadded:: 2.0.0 """ + @property def write(self): """Returns an JavaMLWriter instance for this ML instance.""" raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" - self.write().save(path) + self.write.save(path) @inherit_doc @@ -149,6 +150,7 @@ class JavaMLWritable(MLWritable): (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. """ + @property def write(self): """Returns an JavaMLWriter instance for this ML instance.""" return JavaMLWriter(self) -- cgit v1.2.3