Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  37
1 file changed, 17 insertions(+), 20 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2d89e76a4c..f6437c7fbc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
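
The import hunk is the heart of the change: the catalyst-internal types (SchemaRDD, the catalyst DSL, Cast, LeftOuter) give way to the public DataFrame API. Note that org.apache.spark.sql.dsl._ appears to be a short-lived, development-era location; in released Spark the $"..." column interpolator comes from sqlContext.implicits._ and the column helpers from org.apache.spark.sql.functions._. A minimal sketch against that released API (the ratings frame and its user/item/rating column names are assumptions for illustration):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.FloatType

    def normalizeRatings(ratings: DataFrame): DataFrame = {
      // $"..." builds a Column from a name, replacing the old "...".attr DSL.
      import ratings.sqlContext.implicits._
      ratings.select($"user", $"item", $"rating".cast(FloatType))
    }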
@@ -112,7 +110,7 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext._
import org.apache.spark.ml.recommendation.ALSModel.Factor
val map = this.paramMap ++ paramMap
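
Only the I/O type of transform changes here: SchemaRDD was renamed DataFrame during the 1.3 cycle, so the override migrates mechanically while the ParamMap plumbing stays put. A hypothetical minimal shape for the migrated method (not Spark's ALSModel, just the signature and merge it keeps):

    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.sql.DataFrame

    trait DataFrameTransformer {
      def paramMap: ParamMap // the model's stored defaults

      def transform(dataset: DataFrame, extra: ParamMap): DataFrame = {
        // ++ merges the maps; entries in `extra` override the stored defaults.
        val map = paramMap ++ extra
        dataset // a real model would append its output column here
      }
    }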
@@ -120,13 +118,13 @@ class ALSModel private[ml] (
val instanceTable = s"instance_$uid"
val userTable = s"user_$uid"
val itemTable = s"item_$uid"
- val instances = dataset.as(Symbol(instanceTable))
+ val instances = dataset.as(instanceTable)
val users = userFactors.map { case (id, features) =>
Factor(id, features)
- }.as(Symbol(userTable))
+ }.as(userTable)
val items = itemFactors.map { case (id, features) =>
Factor(id, features)
- }.as(Symbol(itemTable))
+ }.as(itemTable)
val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
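
Two mechanical changes in this hunk: DataFrame.as now takes a plain String alias instead of a Symbol, and the predict closure survives untouched. That closure is just a rank-k dot product over the factor vectors. A self-contained sketch of the same kernel (netlib-java BLAS; the null guard covers users or items absent from the factor tables, where the model yields NaN):

    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    def predict(k: Int)(userFeatures: Seq[Float], itemFeatures: Seq[Float]): Float =
      if (userFeatures != null && itemFeatures != null) {
        // sdot(n, x, incx, y, incy): single-precision dot product of length k.
        blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
      } else {
        Float.NaN // no factors known for this user or item
      }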
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
}
}
val inputColumns = dataset.schema.fieldNames
- val prediction =
- predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
- val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+ val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+ .as(map(predictionCol))
+ val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
instances
- .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
- .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+ .join(users, Column(map(userCol)) === $"$userTable.id", "left")
+ .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
.select(outputColumns: _*)
}
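
The join rewrite drops catalyst's LeftOuter plan node and the .attr expressions in favor of the public Column API: equality predicates built with ===, a "left" join-type string, and callUDF applied to the predict closure. On released Spark the function-taking callUDF overloads were later deprecated in favor of functions.udf, so here is a hedged re-creation of the pipeline in that style (table aliases and column names follow the hunk; k is the factor rank):

    import com.github.fommil.netlib.BLAS.{getInstance => blas}
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, udf}

    def addPredictions(instances: DataFrame, users: DataFrame, items: DataFrame,
        userCol: String, itemCol: String, k: Int): DataFrame = {
      val predictUDF = udf { (u: Seq[Float], i: Seq[Float]) =>
        if (u != null && i != null) blas.sdot(k, u.toArray, 1, i.toArray, 1)
        else Float.NaN
      }
      instances.as("instance")
        .join(users.as("user"), col(userCol) === col("user.id"), "left")
        .join(items.as("item"), col(itemCol) === col("item.id"), "left")
        .withColumn("prediction", predictUDF(col("user.features"), col("item.features")))
    }

The left outer joins are what make the NaN guard necessary: an instance whose user or item has no learned factors still flows through, with null feature columns.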
@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
- import dataset.sqlContext._
+ override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
- val ratings =
- dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
- .map { row =>
- new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
- }
+ val ratings = dataset
+ .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+ .map { row =>
+ new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
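
The fit() hunk makes the same move on the input path: the catalyst Cast expression becomes Column.cast, and the Symbol-free select feeds a map into mllib's Rating. A hedged sketch of that path on current Spark (on 2.x+ the map goes through .rdd, whereas the 1.3-era DataFrame in this diff mapped directly; the Rating case class here is a stand-in for the real one):

    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.FloatType

    case class Rating(user: Int, item: Int, rating: Float) // stand-in for ALS.Rating

    def toRatings(dataset: DataFrame, userCol: String, itemCol: String,
        ratingCol: String) =
      dataset
        .select(col(userCol), col(itemCol), col(ratingCol).cast(FloatType))
        .rdd
        .map { case Row(user: Int, item: Int, rating: Float) =>
          Rating(user, item, rating)
        }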