aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r--python/pyspark/ml/tests.py5
-rw-r--r--python/pyspark/ml/util.py4
2 files changed, 8 insertions, 1 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 224232ed7f..f6159b2c95 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -51,6 +51,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
+from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.linalg import DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
@@ -655,6 +656,10 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def test_write_property(self):
+ lr = LinearRegression(maxIter=1)
+ self.assertTrue(isinstance(lr.write, MLWriter))
+
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6703851262..d4411fdfb9 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -134,13 +134,14 @@ class MLWritable(object):
.. versionadded:: 2.0.0
"""
+ @property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
+ self.write.save(path)
@inherit_doc
@@ -149,6 +150,7 @@ class JavaMLWritable(MLWritable):
(Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
"""
+ @property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
return JavaMLWriter(self)