Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  55
1 file changed, 41 insertions(+), 14 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944fed..f257382d22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
*/
private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
/**
- * Param for the column name for user ids.
+ * Param for the column name for user ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "user"
* @group param
*/
- val userCol = new Param[String](this, "userCol", "column name for user ids")
+ val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getUserCol: String = $(userCol)
/**
- * Param for the column name for item ids.
+ * Param for the column name for item ids. Ids must be integers. Other
+ * numeric types are supported for this column, but will be cast to integers as long as they
+ * fall within the integer value range.
* Default: "item"
* @group param
*/
- val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+ "the integer value range.")
/** @group getParam */
def getItemCol: String = $(itemCol)
+
+ /**
+ * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+ * out of integer range.
+ */
+ protected val checkedCast = udf { (n: Double) =>
+ if (n > Int.MaxValue || n < Int.MinValue) {
+ throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+ s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+ } else {
+ n.toInt
+ }
+ }
}
/**
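For readers skimming the patch, here is a minimal standalone sketch of the checkedCast behavior added above: in-range values are truncated to Int, out-of-range values fail at execution time. The session setup, column name, and simplified error message are illustrative and not part of the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[*]").appName("checkedCastSketch").getOrCreate()
import spark.implicits._

// Same shape as the checkedCast UDF above, minus the userCol/itemCol interpolation.
val checkedCastSketch = udf { (n: Double) =>
  if (n > Int.MaxValue || n < Int.MinValue) {
    throw new IllegalArgumentException(s"Value $n is out of Integer range.")
  } else {
    n.toInt
  }
}

val ok = Seq(1L, 42L).toDF("user")             // Long ids that fit in an Int
ok.select(checkedCastSketch($"user".cast(DoubleType))).show()

val tooBig = Seq(Long.MaxValue).toDF("user")   // does not fit in an Int
// tooBig.select(checkedCastSketch($"user".cast(DoubleType))).show()
// -> fails at execution time with the IllegalArgumentException above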
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
- val ratingType = schema($(ratingCol)).dataType
- require(ratingType == FloatType || ratingType == DoubleType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
+ // rating will be cast to Float
+ SchemaUtils.checkNumericType(schema, $(ratingCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
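SchemaUtils is an internal Spark helper, so here is a rough public-API equivalent of what the relaxed checkNumericType-based validation amounts to; the method and column names are illustrative.

import org.apache.spark.sql.types.{FloatType, IntegerType, LongType, NumericType, StructType}

// Rough equivalent of SchemaUtils.checkNumericType: any numeric column passes.
def checkNumeric(schema: StructType, colName: String): Unit = {
  val dataType = schema(colName).dataType
  require(dataType.isInstanceOf[NumericType],
    s"Column $colName must be of numeric type but was $dataType.")
}

val schema = new StructType()
  .add("user", LongType)      // rejected by the old IntegerType check, accepted now
  .add("item", IntegerType)
  .add("rating", FloatType)   // the old check also required FloatType or DoubleType ratings

checkNumeric(schema, "user")
checkNumeric(schema, "item")
checkNumeric(schema, "rating")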
@@ -232,6 +252,7 @@ class ALSModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema)
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@ class ALSModel private[ml] (
}
}
dataset
- .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
- .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+ .join(userFactors,
+ checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+ .join(itemFactors,
+ checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
- SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+ // user and item will be cast to Int
+ SchemaUtils.checkNumericType(schema, $(userCol))
+ SchemaUtils.checkNumericType(schema, $(itemCol))
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
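The left-join prediction pattern above can be tried in isolation with tiny hand-built factor tables. The DataFrames and the plain dot-product UDF below are illustrative stand-ins (ALSModel uses BLAS for the dot product and additionally range-checks the ids through checkedCast before casting).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().master("local[*]").appName("alsJoinSketch").getOrCreate()
import spark.implicits._

// Tiny stand-in factor tables; ALSModel gets these from training.
val userFactors = Seq((0, Seq(1.0f, 0.5f)), (1, Seq(0.2f, 0.1f))).toDF("id", "features")
val itemFactors = Seq((10, Seq(0.3f, 0.7f))).toDF("id", "features")

// Plain dot product standing in for the BLAS call used by ALSModel.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
  if (userFeatures != null && itemFeatures != null) {
    userFeatures.zip(itemFeatures).map { case (u, i) => u * i }.sum
  } else {
    Float.NaN
  }
}

val toScore = Seq((0L, 10L), (1L, 99L)).toDF("user", "item")   // Long ids, as now allowed

toScore
  .join(userFactors, toScore("user").cast("int") === userFactors("id"), "left")
  .join(itemFactors, toScore("item").cast("int") === itemFactors("id"), "left")
  .select(toScore("*"),
    predict(userFactors("features"), itemFactors("features")).as("prediction"))
  .show()   // item 99 has no factors, so its prediction is NaN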
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
+ transformSchema(dataset.schema)
import dataset.sparkSession.implicits._
+
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
- .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+ .select(checkedCast(col($(userCol)).cast(DoubleType)),
+ checkedCast(col($(itemCol)).cast(DoubleType)), r)
.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
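End to end, the relaxed validation plus checkedCast means ALS can be fit and applied directly on numeric (for example Long-typed) id columns, as long as every id fits in an Int. A minimal sketch, with illustrative data and parameter values:

import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("alsLongIds").getOrCreate()
import spark.implicits._

// Long ids and an Int rating column: all numeric, so the new checks accept them.
// The old validateAndTransformSchema required IntegerType ids and a Float/Double rating.
val training = Seq(
  (0L, 10L, 4),
  (0L, 11L, 1),
  (1L, 10L, 5),
  (1L, 11L, 2)
).toDF("user", "item", "rating")

val als = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
  .setRank(4)
  .setMaxIter(5)

val model = als.fit(training)      // ids are range-checked and cast to Int, ratings to Float
model.transform(training).show()   // appends a Float "prediction" column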