aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-18 21:13:12 +0200
commite8b79afa024123f9d4ceaf0a1043a7e37d913a8d (patch)
tree112148429c15f0eed3ef53e6253898521352a011 /mllib/src/main/scala
parent3d1e67f903ab3512fcad82b94b1825578f8117c9 (diff)
downloadspark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.gz
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.tar.bz2
spark-e8b79afa024123f9d4ceaf0a1043a7e37d913a8d.zip
[SPARK-14891][ML] Add schema validation for ALS
This PR adds schema validation to `ml`'s ALS and ALSModel. Currently, no schema validation was performed as `transformSchema` was never called in `ALS.fit` or `ALSModel.transform`. Furthermore, due to no schema validation, if users passed in Long (or Float etc) ids, they would be silently cast to Int with no warning or error thrown. With this PR, ALS now supports all numeric types for `user`, `item`, and `rating` columns. The rating column is cast to `Float` and the user and item cols are cast to `Int` (as is the case currently) - however for user/item, the cast throws an error if the value is outside integer range. Behavior for rating col is unchanged (as it is not an issue). ## How was this patch tested? New test cases in `ALSSuite`. Author: Nick Pentreath <nickp@za.ibm.com> Closes #12762 from MLnick/SPARK-14891-als-validate-schema.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala55
1 files changed, 41 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944fed..f257382d22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
*/
private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
/**
- * Param for the column name for user ids.
+ * Param for the column name for user ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "user"
* @group param
*/
- val userCol = new Param[String](this, "userCol", "column name for user ids")
+ val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getUserCol: String = $(userCol)
/**
- * Param for the column name for item ids.
+ * Param for the column name for item ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "item"
* @group param
*/
- val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getItemCol: String = $(itemCol)
+
+ /**
+ * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+ * out of integer range.
+ */
+ protected val checkedCast = udf { (n: Double) =>
+ if (n > Int.MaxValue || n < Int.MinValue) {
+ throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+ s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+ } else {
+ n.toInt
+ }
+ }
}
/**
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
- val ratingType = schema($(ratingCol)).dataType
- require(ratingType == FloatType || ratingType == DoubleType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
+ // rating will be cast to Float
+ SchemaUtils.checkNumericType(schema, $(ratingCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
@@ -232,6 +252,7 @@ class ALSModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@ class ALSModel private[ml] (
}
}
dataset
- .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
- .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+ .join(userFactors,
+ checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+ .join(itemFactors,
+ checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
+ transformSchema(dataset.schema)
import dataset.sparkSession.implicits._
+
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
- .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+ .select(checkedCast(col($(userCol)).cast(DoubleType)),
+ checkedCast(col($(itemCol)).cast(DoubleType)), r)
.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))