author     Xiangrui Meng <meng@databricks.com>    2015-05-28 22:38:38 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-28 22:38:38 -0700
commit     db9513789756da4f16bb1fe8cf1d19500f231f54 (patch)
tree       aaef83386cdad3975181b554d68527abf41407cb /mllib
parent     cd3d9a5c0c3e77098a72c85dffe4a27737009ae7 (diff)
[SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel
Expose user/item factors in DataFrames. This is more consistent with the pipeline API, and it helps keep the APIs consistent across languages. This PR also removes the fitting params from `ALSModel`.

coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #6468 from mengxr/SPARK-7922 and squashes the following commits:

7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark
1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS
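For context, a minimal sketch of the API after this change, assuming a `ratings` DataFrame with the default `user`, `item`, and `rating` columns (the variable names below are illustrative, not part of the patch):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
      .setRank(10)
      .setMaxIter(5)
      .setRegParam(0.1)
    val model = als.fit(ratings)

    // The factors are now exposed as DataFrames with `id` and `features`
    // columns, instead of RDD[(Int, Array[Float])] as before this patch.
    model.userFactors.show()
    model.itemFactors.show()

    // transform appends a FloatType prediction column (default: "prediction").
    val predictions = model.transform(ratings)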
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  101
1 file changed, 57 insertions(+), 44 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 900b637ff8..df009d855e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -35,21 +35,46 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom
/**
+ * Common params for ALS and ALSModel.
+ */
+private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
+ /**
+ * Param for the column name for user ids.
+ * Default: "user"
+ * @group param
+ */
+ val userCol = new Param[String](this, "userCol", "column name for user ids")
+
+ /** @group getParam */
+ def getUserCol: String = $(userCol)
+
+ /**
+ * Param for the column name for item ids.
+ * Default: "item"
+ * @group param
+ */
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+
+ /** @group getParam */
+ def getItemCol: String = $(itemCol)
+}
+
+/**
* Common params for ALS.
*/
-private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
with HasPredictionCol with HasCheckpointInterval with HasSeed {
/**
@@ -106,26 +131,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
def getAlpha: Double = $(alpha)
/**
- * Param for the column name for user ids.
- * Default: "user"
- * @group param
- */
- val userCol = new Param[String](this, "userCol", "column name for user ids")
-
- /** @group getParam */
- def getUserCol: String = $(userCol)
-
- /**
- * Param for the column name for item ids.
- * Default: "item"
- * @group param
- */
- val itemCol = new Param[String](this, "itemCol", "column name for item ids")
-
- /** @group getParam */
- def getItemCol: String = $(itemCol)
-
- /**
* Param for the column name for ratings.
* Default: "rating"
* @group param
@@ -156,55 +161,60 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- require(schema($(userCol)).dataType == IntegerType)
- require(schema($(itemCol)).dataType == IntegerType)
+ SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+ SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
require(ratingType == FloatType || ratingType == DoubleType)
- val predictionColName = $(predictionCol)
- require(!schema.fieldNames.contains(predictionColName),
- s"Prediction column $predictionColName already exists.")
- val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
- StructType(newFields)
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
/**
* :: Experimental ::
* Model fitted by ALS.
+ *
+ * @param rank rank of the matrix factorization model
+ * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
+ * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
*/
@Experimental
class ALSModel private[ml] (
override val uid: String,
- k: Int,
- userFactors: RDD[(Int, Array[Float])],
- itemFactors: RDD[(Int, Array[Float])])
- extends Model[ALSModel] with ALSParams {
+ val rank: Int,
+ @transient val userFactors: DataFrame,
+ @transient val itemFactors: DataFrame)
+ extends Model[ALSModel] with ALSModelParams {
+
+ /** @group setParam */
+ def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = set(itemCol, value)
/** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame): DataFrame = {
- import dataset.sqlContext.implicits._
- val users = userFactors.toDF("id", "features")
- val items = itemFactors.toDF("id", "features")
-
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
if (userFeatures != null && itemFeatures != null) {
- blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+ blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
} else {
Float.NaN
}
}
dataset
- .join(users, dataset($(userCol)) === users("id"), "left")
- .join(items, dataset($(itemCol)) === items("id"), "left")
- .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
+ .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
+ .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+ .select(dataset("*"),
+ predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+ SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
@@ -299,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
}
override def fit(dataset: DataFrame): ALSModel = {
+ import dataset.sqlContext.implicits._
val ratings = dataset
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
col($(ratingCol)).cast(FloatType))
@@ -310,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
checkpointInterval = $(checkpointInterval), seed = $(seed))
- val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this)
+ val userDF = userFactors.toDF("id", "features")
+ val itemDF = itemFactors.toDF("id", "features")
+ val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
copyValues(model)
}
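For readers who want to see the new join-and-predict pattern in isolation, here is a small self-contained sketch of what `ALSModel.transform` now does, using hand-built factor DataFrames. Everything below (the `sqlContext` name, the toy data, the zip-based dot product standing in for `blas.sdot`) is illustrative, not part of the patch:

    import org.apache.spark.sql.functions.udf
    import sqlContext.implicits._  // assumes a SQLContext named `sqlContext` is in scope

    // Toy factors in the new layout: two columns, `id` and `features`.
    val userFactors = Seq((1, Array(0.1f, 0.2f)), (2, Array(0.3f, 0.4f))).toDF("id", "features")
    val itemFactors = Seq((10, Array(0.5f, 0.6f)), (11, Array(0.7f, 0.8f))).toDF("id", "features")
    val dataset = Seq((1, 10), (2, 11)).toDF("user", "item")

    // Null-safe dot product, shaped like the patch's predict UDF; the patch
    // itself calls blas.sdot, but a plain zip/sum keeps this sketch self-contained.
    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
      if (userFeatures != null && itemFeatures != null) {
        userFeatures.zip(itemFeatures).map { case (u, i) => u * i }.sum
      } else {
        Float.NaN  // the left joins leave nulls for unknown users/items
      }
    }

    dataset
      .join(userFactors, dataset("user") === userFactors("id"), "left")
      .join(itemFactors, dataset("item") === itemFactors("id"), "left")
      .select(dataset("*"),
        predict(userFactors("features"), itemFactors("features")).as("prediction"))
      .show()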