From db9513789756da4f16bb1fe8cf1d19500f231f54 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 28 May 2015 22:38:38 -0700 Subject: [SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel Expose user/item factors in DataFrames. This is to be more consistent with the pipeline API. It also helps maintain consistent APIs across languages. This PR also removed fitting params from `ALSModel`. coderxiang Author: Xiangrui Meng Closes #6468 from mengxr/SPARK-7922 and squashes the following commits: 7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark 1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS --- .../org/apache/spark/ml/recommendation/ALS.scala | 101 ++++++++++++--------- 1 file changed, 57 insertions(+), 44 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 900b637ff8..df009d855e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -35,21 +35,46 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval with HasSeed { /** @@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** @group getParam */ def getAlpha: Double = $(alpha) - /** - * Param for the column name for user ids. - * Default: "user" - * @group param - */ - val userCol = new Param[String](this, "userCol", "column name for user ids") - - /** @group getParam */ - def getUserCol: String = $(userCol) - - /** - * Param for the column name for item ids. - * Default: "item" - * @group param - */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") - - /** @group getParam */ - def getItemCol: String = $(itemCol) - /** * Param for the column name for ratings. * Default: "rating" @@ -156,55 +161,60 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - require(schema($(userCol)).dataType == IntegerType) - require(schema($(itemCol)).dataType== IntegerType) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType require(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = $(predictionCol) - require(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) - StructType(newFields) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ @Experimental class ALSModel private[ml] ( override val uid: String, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame): DataFrame = { - import dataset.sqlContext.implicits._ - val users = userFactors.toDF("id", "features") - val items = itemFactors.toDF("id", "features") - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } dataset - .join(users, dataset($(userCol)) === users("id"), "left") - .join(items, dataset($(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } @@ -299,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), col($(ratingCol)).cast(FloatType)) @@ -310,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) copyValues(model) } -- cgit v1.2.3