From 5ad78f62056f2560cd371ee964111a646806d0ff Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Thu, 29 Jan 2015 00:01:10 -0800
Subject: [SQL] Various DataFrame DSL update.

1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.
---
 .../scala/org/apache/spark/ml/Transformer.scala     |  3 +-
 .../ml/classification/LogisticRegression.scala      | 12 ++++----
 .../apache/spark/ml/feature/StandardScaler.scala    |  3 +-
 .../org/apache/spark/ml/recommendation/ALS.scala    | 35 +++++++---------------
 .../org/apache/spark/mllib/linalg/Vectors.scala     |  3 +-
 5 files changed, 20 insertions(+), 36 deletions(-)

(limited to 'mllib')

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 29cd981078..6eb7ea639c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._
 
@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     dataset.select($"*", callUDF(
-      this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
+      this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 101f6c8114..d82360dcce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val score: Vector => Double = (v) => {
+    val scoreFunction: Vector => Double = (v) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
     }
     val t = map(threshold)
-    val predict: Double => Double = (score) => {
-      if (score > t) 1.0 else 0.0
-    }
-    dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
+    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    dataset
+      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c456beb65d..78a48561dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
     val scale: (Vector) => Vector = (v) => {
       scaler.transform(v)
     }
-    dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
+    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 738b1844b5..474d4731ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    import dataset.sqlContext._
-    import org.apache.spark.ml.recommendation.ALSModel.Factor
+    import dataset.sqlContext.createDataFrame
     val map = this.paramMap ++ paramMap
-    // TODO: Add DSL to simplify the code here.
-    val instanceTable = s"instance_$uid"
-    val userTable = s"user_$uid"
-    val itemTable = s"item_$uid"
-    val instances = dataset.as(instanceTable)
-    val users = userFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(userTable)
-    val items = itemFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(itemTable)
+    val users = userFactors.toDataFrame("id", "features")
+    val items = itemFactors.toDataFrame("id", "features")
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,13 +123,14 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
-      .as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
-    instances
-      .join(users, Column(map(userCol)) === $"$userTable.id", "left")
-      .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    dataset
+      .join(users, dataset(map(userCol)) === users("id"), "left")
+      .join(items, dataset(map(itemCol)) === items("id"), "left")
       .select(outputColumns: _*)
+    // TODO: Just use a dataset("*")
+    // .select(dataset("*"), prediction)
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
@@ -147,10 +138,6 @@ class ALSModel private[ml] (
   }
 }
 
-private object ALSModel {
-  /** Case class to convert factors to [[DataFrame]]s */
-  private case class Factor(id: Int, features: Seq[Float])
-}
 
 /**
  * Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
     val ratings = dataset
-      .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+      .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
       .map { row =>
         new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 31c33f1bf6..567a8a6c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**
--
cgit v1.2.3
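
The diffs above lean on a handful of new DSL entry points: col() for building a column by name, callUDF(...).as(...) for appending a computed column, toDataFrame(...) for renaming columns while converting an RDD, and RDD-style methods such as foreachPartition on DataFrame. The following standalone sketch is not part of the patch; it shows how those pieces could fit together, assuming the Spark 1.3-era org.apache.spark.sql.api.scala.dsl._ package exactly as the diffs use it, and with invented identifiers and data.

    // Hypothetical usage sketch (not from the patch); names and data are made up.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.api.scala.dsl._
    import org.apache.spark.sql.types.FloatType

    object DataFrameDslSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("dsl-sketch").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        // Implicit RDD -> DataFrame conversion, as imported in ALSModel.transform.
        import sqlContext.createDataFrame

        // (3) Rename columns while converting an RDD of pairs, as the new ALSModel code does.
        val factors = sc.parallelize(Seq((1, Seq(0.1f, 0.2f)), (2, Seq(0.3f, 0.4f))))
          .toDataFrame("id", "features")

        // (2) col() builds a column by name; cast(...) mirrors the ALS.fit change.
        val ratings = sc.parallelize(Seq((1, 1, 4.0f), (2, 1, 3.0f)))
          .toDataFrame("user", "item", "rating")
          .select(col("user"), col("item"), col("rating").cast(FloatType))

        // (5) Keep all existing columns and append a UDF output, mirroring LogisticRegressionModel.
        val threshold = 3.5f
        val predict: Float => Double = r => if (r > threshold) 1.0 else 0.0
        val predictions = ratings.select($"*", callUDF(predict, col("rating")).as("prediction"))

        // (1) foreachPartition is one of the RDD-style methods this commit adds to DataFrame.
        predictions.foreachPartition(_.foreach(println))

        sc.stop()
      }
    }

One design point visible in the diffs: Column(name) is replaced by col(name) or dataset(name), and the latter resolves the column against a specific DataFrame, which is what lets the rewritten ALSModel.transform join dataset, users, and items without the temporary table aliases it used before.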