author     Nick Pentreath <nickp@za.ibm.com>   2017-02-28 16:17:35 +0200
committer  Nick Pentreath <nickp@za.ibm.com>   2017-02-28 16:17:35 +0200
commit     b405466513bcc02cadf1477b6b682ace95d81658 (patch)
tree       5f1d0b2e6ebe9b8c463010bca8bea4074ad5ef86 /mllib
parent     9b8eca65dcf68129470ead39362ce870ffb0bb1d (diff)
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in two scenarios: (a) production scoring, and (b) cross-validation and evaluation. The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit`, since some users/items may appear only in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.

The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is a string so that the set of strategies can be extended in the future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).

## How was this patch tested?

New unit tests, and manual "before and after" tests for Scala and Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12896 from MLnick/SPARK-14489-als-nan.
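For context (not part of this patch), a minimal Scala sketch of the model-selection scenario described above; the column names and the `ratings` DataFrame are illustrative:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// `ratings` is assumed to be a DataFrame of (userId, movieId, rating) rows,
// e.g. loaded from the MovieLens ml-latest-small data mentioned above.
val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setColdStartStrategy("drop") // drop rows with NaN predictions so RMSE stays defined

val evaluator = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("rating")
  .setPredictionCol("prediction")

val paramGrid = new ParamGridBuilder()
  .addGrid(als.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(als)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// With the default "nan" strategy, any fold whose test split contains unseen
// users/items would evaluate to NaN; with "drop" the metrics stay valid.
val cvModel = cv.fit(ratings)
```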
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala       | 44
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala  | 51
2 files changed, 91 insertions(+), 4 deletions(-)
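Before the diff itself, a brief sketch of the transform-time effect of the two strategies; the fitted `model` and the `test` DataFrame (containing ids not seen in training) are assumed here:

```scala
// "nan" (the default): unknown user/item ids produce prediction = NaN, as before.
model.setColdStartStrategy("nan").transform(test).show()

// "drop": rows whose prediction is NaN are filtered from the output.
model.setColdStartStrategy("drop").transform(test).show()
```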
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 97c8655298..af007625d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
n.toInt
}
}
+
+ /**
+ * Param for strategy for dealing with unknown or new users/items at prediction time.
+ * This may be useful in cross-validation or production scenarios, for handling user/item ids
+ * the model has not seen in the training data.
+ * Supported values:
+ * - "nan": predicted value for unknown ids will be NaN.
+ * - "drop": rows in the input DataFrame containing unknown ids will be dropped from
+ * the output DataFrame containing predictions.
+ * Default: "nan".
+ * @group expertParam
+ */
+ val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+ "strategy for dealing with unknown or new users/items at prediction time. This may be " +
+ "useful in cross-validation or production scenarios, for handling user/item ids the model " +
+ "has not seen in the training data. Supported values: " +
+ s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
+ (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))
+
+ /** @group expertGetParam */
+ def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
}
/**
@@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
- intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+ intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+ coldStartStrategy -> "nan")
/**
* Validates and transforms the input schema.
@@ -248,6 +270,10 @@ class ALSModel private[ml] (
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group expertSetParam */
+ @Since("2.2.0")
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
@@ -260,13 +286,19 @@ class ALSModel private[ml] (
Float.NaN
}
}
- dataset
+ val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
.join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+ getColdStartStrategy match {
+ case ALSModel.Drop =>
+ predictions.na.drop("all", Seq($(predictionCol)))
+ case ALSModel.NaN =>
+ predictions
+ }
}
@Since("1.3.0")
@@ -290,6 +322,10 @@ class ALSModel private[ml] (
@Since("1.6.0")
object ALSModel extends MLReadable[ALSModel] {
+ private val NaN = "nan"
+ private val Drop = "drop"
+ private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)
+
@Since("1.6.0")
override def read: MLReader[ALSModel] = new ALSModelReader
@@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
+ /** @group expertSetParam */
+ @Since("2.2.0")
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index b923bacce2..c9e7b505b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -498,8 +498,8 @@ class ALSSuite
(ex, act) =>
ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
} { (ex, act, _) =>
- ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
- act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+ ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
+ act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
}
}
// check user/item ids falling outside of Int range
@@ -547,6 +547,53 @@ class ALSSuite
ALS.train(ratings)
}
}
+
+ test("ALS cold start user/item prediction strategy") {
+ val spark = this.spark
+ import spark.implicits._
+ import org.apache.spark.sql.functions._
+
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ val data = ratings.toDF
+ val knownUser = data.select(max("user")).as[Int].first()
+ val unknownUser = knownUser + 10
+ val knownItem = data.select(max("item")).as[Int].first()
+ val unknownItem = knownItem + 20
+ val test = Seq(
+ (unknownUser, unknownItem),
+ (knownUser, unknownItem),
+ (unknownUser, knownItem),
+ (knownUser, knownItem)
+ ).toDF("user", "item")
+
+ val als = new ALS().setMaxIter(1).setRank(1)
+ // default is 'nan'
+ val defaultModel = als.fit(data)
+ val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
+ assert(defaultPredictions.length == 4)
+ assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
+ assert(!defaultPredictions.last.isNaN)
+
+ // check 'drop' strategy should filter out rows with unknown users/items
+ val dropPredictions = defaultModel
+ .setColdStartStrategy("drop")
+ .transform(test)
+ .select("prediction").as[Float].collect()
+ assert(dropPredictions.length == 1)
+ assert(!dropPredictions.head.isNaN)
+ assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+ }
+
+ test("case insensitive cold start param value") {
+ val spark = this.spark
+ import spark.implicits._
+ val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1)
+ val data = ratings.toDF
+ val model = new ALS().fit(data)
+ Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
+ model.setColdStartStrategy(s).transform(data)
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {