aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
commite8b79afa024123f9d4ceaf0a1043a7e37d913a8d (patch)
tree112148429c15f0eed3ef53e6253898521352a011 /mllib/src/test
parent3d1e67f903ab3512fcad82b94b1825578f8117c9 (diff)
downloadspark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.gz
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.bz2
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.zip
[SPARK-14891][ML] Add schema validation for ALS
This PR adds schema validation to `ml`'s ALS and ALSModel. Currently, no schema validation was performed as `transformSchema` was never called in `ALS.fit` or `ALSModel.transform`. Furthermore, due to no schema validation, if users passed in Long (or Float etc) ids, they would be silently cast to Int with no warning or error thrown. With this PR, ALS now supports all numeric types for `user`, `item`, and `rating` columns. The rating column is cast to `Float` and the user and item cols are cast to `Int` (as is the case currently) - however for user/item, the cast throws an error if the value is outside integer range. Behavior for rating col is unchanged (as it is not an issue). ## How was this patch tested? New test cases in `ALSSuite`. Author: Nick Pentreath <nickp@za.ibm.com> Closes #12762 from MLnick/SPARK-14891-als-validate-schema.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala61
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala45
2 files changed, 102 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bbfc415cbb..59b5edc401 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.{FloatType, IntegerType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -205,7 +206,6 @@ class ALSSuite
/**
* Generates an explicit feedback dataset for testing ALS.
- *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -246,7 +246,6 @@ class ALSSuite
/**
* Generates an implicit feedback dataset for testing ALS.
- *
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
@@ -265,7 +264,6 @@ class ALSSuite
/**
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
- *
* @param size number of users/items
* @param rank number of features
* @param random random number generator
@@ -284,7 +282,6 @@ class ALSSuite
/**
* Test ALS using the given training/test splits and parameters.
- *
* @param training training dataset
* @param test test dataset
* @param rank rank of the matrix factorization
@@ -486,6 +483,62 @@ class ALSSuite
assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
}
+
+ test("input type validation") {
+ val spark = this.spark
+ import spark.implicits._
+
+ // check that ALS can handle all numeric types for rating column
+ // and user/item columns (when the user/item ids are within Int range)
+ val als = new ALS().setMaxIter(1).setRank(1)
+ Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
+ case (colName, sqlType) =>
+ MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
+ (ex, act) =>
+ ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
+ } { (ex, act, _) =>
+ ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
+ act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+ }
+ }
+ // check user/item ids falling outside of Int range
+ val big = Int.MaxValue.toLong + 1
+ val small = Int.MinValue.toDouble - 1
+ val df = Seq(
+ (0, 0L, 0d, 1, 1L, 1d, 3.0),
+ (0, big, small, 0, big, small, 2.0),
+ (1, 1L, 1d, 0, 0L, 0d, 5.0)
+ ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+ withClue("fit should fail when ids exceed integer range. ") {
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[IllegalArgumentException] {
+ als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+ }.getMessage.contains("was out of Integer range"))
+ }
+ withClue("transform should fail when ids exceed integer range. ") {
+ val model = als.fit(df)
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("user_big").as("user"), df("item"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("user_small").as("user"), df("item"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("item_big").as("item"), df("user"))).first
+ }.getMessage.contains("was out of Integer range"))
+ assert(intercept[SparkException] {
+ model.transform(df.select(df("item_small").as("item"), df("user"))).first
+ }.getMessage.contains("was out of Integer range"))
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 6aae625fc8..80b976914c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
@@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite {
"Column label must be of type NumericType but was actually of type StringType"))
}
+ def checkNumericTypesALS(
+ estimator: ALS,
+ spark: SparkSession,
+ column: String,
+ baseType: NumericType)
+ (check: (ALSModel, ALSModel) => Unit)
+ (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
+ val dfs = genRatingsDFWithNumericCols(spark, column)
+ val expected = estimator.fit(dfs(baseType))
+ val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
+ actuals.foreach { case (_, actual) => check(expected, actual) }
+ actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
+
+ val baseDF = dfs(baseType)
+ val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+ val cols = Seq(col(column).cast(StringType)) ++ others
+ val strDF = baseDF.select(cols: _*)
+ val thrown = intercept[IllegalArgumentException] {
+ estimator.fit(strDF)
+ }
+ assert(thrown.getMessage.contains(
+ s"$column must be of type NumericType but was actually of type StringType"))
+ }
+
def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite {
}.toMap
}
+ def genRatingsDFWithNumericCols(
+ spark: SparkSession,
+ column: String): Map[NumericType, DataFrame] = {
+ val df = spark.createDataFrame(Seq(
+ (0, 10, 1.0),
+ (1, 20, 2.0),
+ (2, 30, 3.0),
+ (3, 40, 4.0),
+ (4, 50, 5.0)
+ )).toDF("user", "item", "rating")
+
+ val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+ val types: Seq[NumericType] =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+ types.map { t =>
+ val cols = Seq(col(column).cast(t)) ++ others
+ t -> df.select(cols: _*)
+ }.toMap
+ }
+
def genEvaluatorDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",