aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
commite8b79afa024123f9d4ceaf0a1043a7e37d913a8d (patch)
tree112148429c15f0eed3ef53e6253898521352a011 /mllib
parent3d1e67f903ab3512fcad82b94b1825578f8117c9 (diff)
downloadspark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.gz
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.bz2
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.zip
[SPARK-14891][ML] Add schema validation for ALS
This PR adds schema validation to `ml`'s ALS and ALSModel. Currently, no schema validation was performed as `transformSchema` was never called in `ALS.fit` or `ALSModel.transform`. Furthermore, due to no schema validation, if users passed in Long (or Float etc) ids, they would be silently cast to Int with no warning or error thrown. With this PR, ALS now supports all numeric types for `user`, `item`, and `rating` columns. The rating column is cast to `Float` and the user and item cols are cast to `Int` (as is the case currently) - however for user/item, the cast throws an error if the value is outside integer range. Behavior for rating col is unchanged (as it is not an issue). ## How was this patch tested? New test cases in `ALSSuite`. Author: Nick Pentreath <nickp@za.ibm.com> Closes #12762 from MLnick/SPARK-14891-als-validate-schema.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala55
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala61
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala45
3 files changed, 143 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944fed..f257382d22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
*/
private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
/**
- * Param for the column name for user ids.
+ * Param for the column name for user ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "user"
* @group param
*/
- val userCol = new Param[String](this, "userCol", "column name for user ids")
+ val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getUserCol: String = $(userCol)
/**
- * Param for the column name for item ids.
+ * Param for the column name for item ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "item"
* @group param
*/
- val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getItemCol: String = $(itemCol)
+
+ /**
+ * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+ * out of integer range.
+ */
+ protected val checkedCast = udf { (n: Double) =>
+ if (n > Int.MaxValue || n < Int.MinValue) {
+ throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+ s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+ } else {
+ n.toInt
+ }
+ }
}
/**
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
- val ratingType = schema($(ratingCol)).dataType
- require(ratingType == FloatType || ratingType == DoubleType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
+ // rating will be cast to Float
+ SchemaUtils.checkNumericType(schema, $(ratingCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
@@ -232,6 +252,7 @@ class ALSModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@ class ALSModel private[ml] (
}
}
dataset
- .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
- .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+ .join(userFactors,
+ checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+ .join(itemFactors,
+ checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
+ transformSchema(dataset.schema)
import dataset.sparkSession.implicits._
+
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
- .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+ .select(checkedCast(col($(userCol)).cast(DoubleType)),
+ checkedCast(col($(itemCol)).cast(DoubleType)), r)
.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bbfc415cbb..59b5edc401 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.{FloatType, IntegerType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -205,7 +206,6 @@ class ALSSuite
/**
* Generates an explicit feedback dataset for testing ALS.
- *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -246,7 +246,6 @@ class ALSSuite
/**
* Generates an implicit feedback dataset for testing ALS.
- *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -265,7 +264,6 @@ class ALSSuite
/**
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
- *
* @param size number of users/items
* @param rank number of features
* @param random random number generator
@@ -284,7 +282,6 @@ class ALSSuite
/**
* Test ALS using the given training/test splits and parameters.
- *
* @param training training dataset
* @param test test dataset
* @param rank rank of the matrix factorization
@@ -486,6 +483,62 @@ class ALSSuite
assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
}
+
+ test("input type validation") {
+ val spark = this.spark
+ import spark.implicits._
+
+ // check that ALS can handle all numeric types for rating column
+ // and user/item columns (when the user/item ids are within Int range)
+ val als = new ALS().setMaxIter(1).setRank(1)
+ Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
+ case (colName, sqlType) =>
+ MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
+ (ex, act) =>
+ ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
+ } { (ex, act, _) =>
+ ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
+ act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+ }
+ }
+ // check user/item ids falling outside of Int range
+ val big = Int.MaxValue.toLong + 1
+ val small = Int.MinValue.toDouble - 1
+ val df = Seq(
+ (0, 0L, 0d, 1, 1L, 1d, 3.0),
+ (0, big, small, 0, big, small, 2.0),
+ (1, 1L, 1d, 0, 0L, 0d, 5.0)
+ ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+ withClue("fit should fail when ids exceed integer range. ") {
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ }
+ withClue("transform should fail when ids exceed integer range. ") {
+ val model = als.fit(df)
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("user_big").as("user"), df("item"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("user_small").as("user"), df("item"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("item_big").as("item"), df("user"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("item_small").as("item"), df("user"))).first
+ }.getMessage.contains("was out of Integer range"))
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 6aae625fc8..80b976914c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
@@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite {
"Column label must be of type NumericType but was actually of type StringType"))
}
+ def checkNumericTypesALS(
+ estimator: ALS,
+ spark: SparkSession,
+ column: String,
+ baseType: NumericType)
+ (check: (ALSModel, ALSModel) => Unit)
+ (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
+ val dfs = genRatingsDFWithNumericCols(spark, column)
+ val expected = estimator.fit(dfs(baseType))
+ val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
+ actuals.foreach { case (_, actual) => check(expected, actual) }
+ actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
+
+ val baseDF = dfs(baseType)
+ val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+ val cols = Seq(col(column).cast(StringType)) ++ others
+ val strDF = baseDF.select(cols: _*)
+ val thrown = intercept[IllegalArgumentException] {
+ estimator.fit(strDF)
+ }
+ assert(thrown.getMessage.contains(
+ s"$column must be of type NumericType but was actually of type StringType"))
+ }
+
def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite {
}.toMap
}
+ def genRatingsDFWithNumericCols(
+ spark: SparkSession,
+ column: String): Map[NumericType, DataFrame] = {
+ val df = spark.createDataFrame(Seq(
+ (0, 10, 1.0),
+ (1, 20, 2.0),
+ (2, 30, 3.0),
+ (3, 40, 4.0),
+ (4, 50, 5.0)
+ )).toDF("user", "item", "rating")
+
+ val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+ val types: Seq[NumericType] =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+ types.map { t =>
+ val cols = Seq(col(column).cast(t)) ++ others
+ t -> df.select(cols: _*)
+ }.toMap
+ }
+
def genEvaluatorDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",