aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
author Nick Pentreath <nickp@za.ibm.com> 2017-02-28 16:17:35 +0200
committer Nick Pentreath <nickp@za.ibm.com> 2017-02-28 16:17:35 +0200
commit b405466513bcc02cadf1477b6b682ace95d81658 (patch)
tree 5f1d0b2e6ebe9b8c463010bca8bea4074ad5ef86 /python/pyspark
parent 9b8eca65dcf68129470ead39362ce870ffb0bb1d (diff)
download spark-b405466513bcc02cadf1477b6b682ace95d81658.tar.gz
spark-b405466513bcc02cadf1477b6b682ace95d81658.tar.bz2
spark-b405466513bcc02cadf1477b6b682ace95d81658.zip
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in two scenarios: (a) production scoring, and (b) cross-validation & evaluation.

The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit`, since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.

The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in the future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).

## How was this patch tested?

New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12896 from MLnick/SPARK-14489-als-nan.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/ml/recommendation.py 30
1 files changed, 25 insertions, 5 deletions
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e28d38bd19..43f82daa9f 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
"StorageLevel for ALS model factors.",
typeConverter=TypeConverters.toString)
+ coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
+ "unknown or new users/items at prediction time. This may be useful " +
+ "in cross-validation or production scenarios, for handling " +
+ "user/item ids the model has not seen in the training data. " +
+ "Supported values: 'nan', 'drop'.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK"):
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=false, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
"""
super(ALS, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
@@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK"):
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=False, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
Sets params for ALS.
"""
kwargs = self.setParams._input_kwargs
@@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
return self.getOrDefault(self.finalStorageLevel)
+ @since("2.2.0")
+ def setColdStartStrategy(self, value):
+ """
+ Sets the value of :py:attr:`coldStartStrategy`.
+ """
+ return self._set(coldStartStrategy=value)
+
+ @since("2.2.0")
+ def getColdStartStrategy(self):
+ """
+ Gets the value of coldStartStrategy or its default value.
+ """
+ return self.getOrDefault(self.coldStartStrategy)
+
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""