author     Nick Pentreath <nickp@za.ibm.com>    2016-04-29 22:01:41 -0700
committer  Xiangrui Meng <meng@databricks.com>  2016-04-29 22:01:41 -0700
commit     90fa2c6e7f4893af51e0cfb6dc162b828ea55995 (patch)
tree       cad8481dab7030e32f8001caab2f67fd99b3c49e /python
parent     d7755cfd07554c132b7271730102b8b68eb56b28 (diff)
[SPARK-14412][ML][PYSPARK] Add StorageLevel params to ALS
`mllib` `ALS` supports `setIntermediateRDDStorageLevel` and `setFinalRDDStorageLevel`. This PR adds these as Params in `ml` `ALS`. They are put in group **expertParam** since few users will need them.

## How was this patch tested?

New test cases in `ALSSuite` and `tests.py`.

cc yanboliang jkbradley sethah rishabhbhardwaj

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12660 from MLnick/SPARK-14412-als-storage-params.
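For context, a minimal usage sketch of the two new expert params. The param names, defaults, and setter/getter behavior come from the diff below; the session setup, the existing SparkContext `sc`, and the toy ratings data are illustrative and not part of this patch:

```python
# Sketch of the new expert params added by this patch.
# Assumes an existing SparkContext `sc` (e.g. from the pyspark shell);
# the toy ratings data is illustrative only.
from pyspark.sql import SQLContext
from pyspark.ml.recommendation import ALS

sqlContext = SQLContext(sc)
ratings = sqlContext.createDataFrame(
    [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0)],
    ["user", "item", "rating"])

# Both params default to "MEMORY_AND_DISK" and take StorageLevel names as
# strings (e.g. "MEMORY_ONLY", "DISK_ONLY"); intermediateRDDStorageLevel
# cannot be "NONE". The setters return self, so they chain.
als = (ALS(rank=10, maxIter=5)
       .setIntermediateRDDStorageLevel("MEMORY_ONLY")
       .setFinalRDDStorageLevel("DISK_ONLY"))
model = als.fit(ratings)

assert als.getIntermediateRDDStorageLevel() == "MEMORY_ONLY"
assert als.getFinalRDDStorageLevel() == "DISK_ONLY"
```

The new test case at the bottom of this diff exercises the same setters and additionally checks that the values propagate to the underlying Java object.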
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/recommendation.py  58
-rw-r--r--  python/pyspark/ml/tests.py           27
2 files changed, 80 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 4e42c468cc..97ac6ea83d 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -119,21 +119,35 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
nonnegative = Param(Params._dummy(), "nonnegative",
"whether to use nonnegative constraint for least squares",
typeConverter=TypeConverters.toBoolean)
+ intermediateRDDStorageLevel = Param(Params._dummy(), "intermediateRDDStorageLevel",
+ "StorageLevel for intermediate RDDs. Cannot be 'NONE'. " +
+ "Default: 'MEMORY_AND_DISK'.",
+ typeConverter=TypeConverters.toString)
+ finalRDDStorageLevel = Param(Params._dummy(), "finalRDDStorageLevel",
+ "StorageLevel for ALS model factor RDDs. " +
+ "Default: 'MEMORY_AND_DISK'.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
- ratingCol="rating", nonnegative=False, checkpointInterval=10):
+ ratingCol="rating", nonnegative=False, checkpointInterval=10,
+ intermediateRDDStorageLevel="MEMORY_AND_DISK",
+ finalRDDStorageLevel="MEMORY_AND_DISK"):
"""
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
- ratingCol="rating", nonnegative=False, checkpointInterval=10)
+ ratingCol="rating", nonnegative=False, checkpointInterval=10, \
+ intermediateRDDStorageLevel="MEMORY_AND_DISK", \
+ finalRDDStorageLevel="MEMORY_AND_DISK")
"""
super(ALS, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
- ratingCol="rating", nonnegative=False, checkpointInterval=10)
+ ratingCol="rating", nonnegative=False, checkpointInterval=10,
+ intermediateRDDStorageLevel="MEMORY_AND_DISK",
+ finalRDDStorageLevel="MEMORY_AND_DISK")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -141,11 +155,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
@since("1.4.0")
def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
- ratingCol="rating", nonnegative=False, checkpointInterval=10):
+ ratingCol="rating", nonnegative=False, checkpointInterval=10,
+ intermediateRDDStorageLevel="MEMORY_AND_DISK",
+ finalRDDStorageLevel="MEMORY_AND_DISK"):
"""
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
- ratingCol="rating", nonnegative=False, checkpointInterval=10)
+ ratingCol="rating", nonnegative=False, checkpointInterval=10, \
+ intermediateRDDStorageLevel="MEMORY_AND_DISK", \
+ finalRDDStorageLevel="MEMORY_AND_DISK")
Sets params for ALS.
"""
kwargs = self.setParams._input_kwargs
@@ -297,6 +315,36 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
return self.getOrDefault(self.nonnegative)
+ @since("2.0.0")
+ def setIntermediateRDDStorageLevel(self, value):
+ """
+ Sets the value of :py:attr:`intermediateRDDStorageLevel`.
+ """
+ self._set(intermediateRDDStorageLevel=value)
+ return self
+
+ @since("2.0.0")
+ def getIntermediateRDDStorageLevel(self):
+ """
+ Gets the value of intermediateRDDStorageLevel or its default value.
+ """
+ return self.getOrDefault(self.intermediateRDDStorageLevel)
+
+ @since("2.0.0")
+ def setFinalRDDStorageLevel(self, value):
+ """
+ Sets the value of :py:attr:`finalRDDStorageLevel`.
+ """
+ self._set(finalRDDStorageLevel=value)
+ return self
+
+ @since("2.0.0")
+ def getFinalRDDStorageLevel(self):
+ """
+ Gets the value of finalRDDStorageLevel or its default value.
+ """
+ return self.getOrDefault(self.finalRDDStorageLevel)
+
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index faca148218..7722d57e9e 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -50,12 +50,15 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvalu
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.recommendation import ALS
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
+from pyspark.sql.utils import IllegalArgumentException
+from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -999,6 +1002,30 @@ class HashingTFTest(PySparkTestCase):
": expected " + str(expected[i]) + ", got " + str(features[i]))
+class ALSTest(PySparkTestCase):
+
+ def test_storage_levels(self):
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame(
+ [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+ ["user", "item", "rating"])
+ als = ALS().setMaxIter(1).setRank(1)
+ # test default params
+ als.fit(df)
+ self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
+ # test non-default params
+ als.setIntermediateRDDStorageLevel("MEMORY_ONLY_2")
+ als.setFinalRDDStorageLevel("DISK_ONLY")
+ als.fit(df)
+ self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
+ self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
+ self.assertEqual(als.getFinalRDDStorageLevel(), "DISK_ONLY")
+ self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "DISK_ONLY")
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner: