diff options
author | Nick Pentreath <nickp@za.ibm.com> | 2017-02-28 16:17:35 +0200 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2017-02-28 16:17:35 +0200 |
commit | b405466513bcc02cadf1477b6b682ace95d81658 (patch) | |
tree | 5f1d0b2e6ebe9b8c463010bca8bea4074ad5ef86 /python/pyspark/ml | |
parent | 9b8eca65dcf68129470ead39362ce870ffb0bb1d (diff) | |
download | spark-b405466513bcc02cadf1477b6b682ace95d81658.tar.gz spark-b405466513bcc02cadf1477b6b682ace95d81658.tar.bz2 spark-b405466513bcc02cadf1477b6b682ace95d81658.zip |
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation.
The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.
The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).
## How was this patch tested?
New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.
Author: Nick Pentreath <nickp@za.ibm.com>
Closes #12896 from MLnick/SPARK-14489-als-nan.
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/recommendation.py | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e28d38bd19..43f82daa9f 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString) + coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " + + "unknown or new users/items at prediction time. This may be useful " + + "in cross-validation or production scenarios, for handling " + + "user/item ids the model has not seen in the training data. " + + "Supported values: 'nan', 'drop'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) @@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", - finalStorageLevel="MEMORY_AND_DISK"): + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"): """ setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=False, checkpointInterval=10, \ intermediateStorageLevel="MEMORY_AND_DISK", \ - finalStorageLevel="MEMORY_AND_DISK") + finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") Sets params for ALS. """ kwargs = self.setParams._input_kwargs @@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ return self.getOrDefault(self.finalStorageLevel) + @since("2.2.0") + def setColdStartStrategy(self, value): + """ + Sets the value of :py:attr:`coldStartStrategy`. + """ + return self._set(coldStartStrategy=value) + + @since("2.2.0") + def getColdStartStrategy(self): + """ + Gets the value of coldStartStrategy or its default value. + """ + return self.getOrDefault(self.coldStartStrategy) + class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ |