aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
author Vasilis Vryniotis <bbriniotis@datumbox.com> 2017-03-02 12:37:42 +0200
committer Nick Pentreath <nickp@za.ibm.com> 2017-03-02 12:37:42 +0200
commit 625cfe09e673bfcb95e361ce19b534cf0a3c782c (patch)
tree 7602307619a40627ffff5f838befaf4e473a9671 /mllib/src
parent 8d6ef895ee492b8febbaac7ab2ef2c907b48fa4a (diff)
download spark-625cfe09e673bfcb95e361ce19b534cf0a3c782c.tar.gz
spark-625cfe09e673bfcb95e361ce19b534cf0a3c782c.tar.bz2
spark-625cfe09e673bfcb95e361ce19b534cf0a3c782c.zip
[SPARK-19733][ML] Removed unnecessary castings and refactored checked casts in ALS.
## What changes were proposed in this pull request? The original ALS performed unnecessary casts on the user and item ids because the protected checkedCast() method required a Double. I removed the casts and refactored the method to accept Any and efficiently handle all permitted numeric types. ## How was this patch tested? I tested it by running the unit tests and by manually validating the result of checkedCast for various legal and illegal values. Author: Vasilis Vryniotis <bbriniotis@datumbox.com> Closes #17059 from datumbox/als_casting_fix.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 31
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala 84
2 files changed, 95 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 04273a40d9..799e881fad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -80,14 +80,24 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
/**
* Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
- * out of integer range.
+ * out of integer range or contains a fractional part.
*/
- protected val checkedCast = udf { (n: Double) =>
- if (n > Int.MaxValue || n < Int.MinValue) {
- throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
- s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
- } else {
- n.toInt
+ protected[recommendation] val checkedCast = udf { (n: Any) =>
+ n match {
+ case v: Int => v // Avoid unnecessary casting
+ case v: Number =>
+ val intV = v.intValue
+ // Checks if number within Int range and has no fractional part.
+ if (v.doubleValue == intV) {
+ intV
+ } else {
+ throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
+ s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " +
+ s"Value $n was either out of Integer range or contained a fractional part that " +
+ s"could not be converted.")
+ }
+ case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
+ s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.")
}
}
@@ -288,9 +298,9 @@ class ALSModel private[ml] (
}
val predictions = dataset
.join(userFactors,
- checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+ checkedCast(dataset($(userCol))) === userFactors("id"), "left")
.join(itemFactors,
- checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
+ checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
getColdStartStrategy match {
@@ -491,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
- .select(checkedCast(col($(userCol)).cast(DoubleType)),
- checkedCast(col($(itemCol)).cast(DoubleType)), r)
+ .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r)
.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c9e7b505b2..c8228dd004 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.types.{FloatType, IntegerType}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -205,6 +206,70 @@ class ALSSuite
assert(decompressed.toSet === expected)
}
+ test("CheckedCast") {
+ val checkedCast = new ALS().checkedCast
+ val df = spark.range(1)
+
+ withClue("Valid Integer Ids") {
+ df.select(checkedCast(lit(123))).collect()
+ }
+
+ withClue("Valid Long Ids") {
+ df.select(checkedCast(lit(1231L))).collect()
+ }
+
+ withClue("Valid Decimal Ids") {
+ df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect()
+ }
+
+ withClue("Valid Double Ids") {
+ df.select(checkedCast(lit(123.0))).collect()
+ }
+
+ val msg = "either out of Integer range or contained a fractional part"
+ withClue("Invalid Long: out of range") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit(1231000000000L))).collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ withClue("Invalid Decimal: out of range") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ withClue("Invalid Decimal: fractional part") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ withClue("Invalid Double: out of range") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit(1231000000000.0))).collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ withClue("Invalid Double: fractional part") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit(123.1))).collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ withClue("Invalid Type") {
+ val e: SparkException = intercept[SparkException] {
+ df.select(checkedCast(lit("123.1"))).collect()
+ }
+ assert(e.getMessage.contains("was not numeric"))
+ }
+ }
+
/**
* Generates an explicit feedback dataset for testing ALS.
* @param numUsers number of users
@@ -510,34 +575,35 @@ class ALSSuite
(0, big, small, 0, big, small, 2.0),
(1, 1L, 1d, 0, 0L, 0d, 5.0)
).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+ val msg = "either out of Integer range or contained a fractional part"
withClue("fit should fail when ids exceed integer range. ") {
assert(intercept[SparkException] {
als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
- }.getCause.getMessage.contains("was out of Integer range"))
+ }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] {
als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
- }.getCause.getMessage.contains("was out of Integer range"))
+ }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] {
als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
- }.getCause.getMessage.contains("was out of Integer range"))
+ }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] {
als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
- }.getCause.getMessage.contains("was out of Integer range"))
+ }.getCause.getMessage.contains(msg))
}
withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df)
assert(intercept[SparkException] {
model.transform(df.select(df("user_big").as("user"), df("item"))).first
- }.getMessage.contains("was out of Integer range"))
+ }.getMessage.contains(msg))
assert(intercept[SparkException] {
model.transform(df.select(df("user_small").as("user"), df("item"))).first
- }.getMessage.contains("was out of Integer range"))
+ }.getMessage.contains(msg))
assert(intercept[SparkException] {
model.transform(df.select(df("item_big").as("item"), df("user"))).first
- }.getMessage.contains("was out of Integer range"))
+ }.getMessage.contains(msg))
assert(intercept[SparkException] {
model.transform(df.select(df("item_small").as("item"), df("user"))).first
- }.getMessage.contains("was out of Integer range"))
+ }.getMessage.contains(msg))
}
}