author	Reynold Xin <rxin@databricks.com>	2015-01-29 00:01:10 -0800
committer	Reynold Xin <rxin@databricks.com>	2015-01-29 00:01:10 -0800
commit	5ad78f62056f2560cd371ee964111a646806d0ff (patch)
tree	c5db8104a00b4a835db77bf7f7116622b47c8cc3 /mllib
parent	a63be1a18f7b7d77f7deef2abc9a5be6ad24ae28 (diff)
[SQL] Various DataFrame DSL update.
1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin <rxin@databricks.com>

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.
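For orientation, here is a minimal sketch of the new DSL surface described above, written against the API names as they exist at this commit (the org.apache.spark.sql.api.scala.dsl package and toDataFrame were renamed in later releases); the sqlContext and the input RDD are assumptions:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.api.scala.dsl._   // the DSL package as named at this commit

    val sqlContext: SQLContext = ???              // assumed to exist
    import sqlContext.createDataFrame             // implicit RDD-of-tuples -> DataFrame

    val factors: RDD[(Int, Array[Float])] = ???   // assumed input

    // (3) toDataFrame can rename columns while converting, and
    // (4) the Array[Float] field is now inferred (previously only Seq worked).
    val users = factors.toDataFrame("id", "features")

    // (2) col() builds a Column from a name.
    val ids = users.select(col("id"))

    // (1) foreach, foreachPartition, and flatMap now exist directly on DataFrame.
    users.foreach(row => println(row))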
Diffstat (limited to 'mllib')
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/Transformer.scala	3
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala	12
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala	3
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala	35
-rw-r--r--	mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala	3
5 files changed, 20 insertions(+), 36 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 29cd981078..6eb7ea639c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
import org.apache.spark.sql.types._
@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
dataset.select($"*", callUDF(
- this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
+ this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
}
}
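The change above replaces Column(map(inputCol)) with dataset(map(inputCol)), i.e. a column resolved against the input DataFrame rather than a free-standing one; per the squashed commits, failed resolution now reports a better error. A hypothetical two-line contrast (df and the column name are assumptions):

    val resolved   = df("features")    // resolved against df's schema, so a missing column fails with the improved resolve error
    val unresolved = col("features")   // bound to an actual column only later, during analysis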
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 101f6c8114..d82360dcce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
- val score: Vector => Double = (v) => {
+ val scoreFunction: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
1.0 / (1.0 + math.exp(-margin))
}
val t = map(threshold)
- val predict: Double => Double = (score) => {
- if (score > t) 1.0 else 0.0
- }
- dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
- .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
+ val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+ dataset
+ .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
+ .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
}
}
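The rewritten transform chains two select($"*", callUDF(...)) calls: the first appends the sigmoid score, the second thresholds it into a prediction. The same pattern in isolation, as a sketch where df, weights, and the column names are assumptions:

    val score: Vector => Double = v => 1.0 / (1.0 + math.exp(-BLAS.dot(v, weights)))
    val toLabel: Double => Double = s => if (s > 0.5) 1.0 else 0.0

    val predictions = df
      .select($"*", callUDF(score, col("features")).as("score"))       // append score column
      .select($"*", callUDF(toLabel, col("score")).as("prediction"))   // then threshold it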
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c456beb65d..78a48561dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
- dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
+ dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 738b1844b5..474d4731ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- import dataset.sqlContext._
- import org.apache.spark.ml.recommendation.ALSModel.Factor
+ import dataset.sqlContext.createDataFrame
val map = this.paramMap ++ paramMap
- // TODO: Add DSL to simplify the code here.
- val instanceTable = s"instance_$uid"
- val userTable = s"user_$uid"
- val itemTable = s"item_$uid"
- val instances = dataset.as(instanceTable)
- val users = userFactors.map { case (id, features) =>
- Factor(id, features)
- }.as(userTable)
- val items = itemFactors.map { case (id, features) =>
- Factor(id, features)
- }.as(itemTable)
+ val users = userFactors.toDataFrame("id", "features")
+ val items = itemFactors.toDataFrame("id", "features")
val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,13 +123,14 @@ class ALSModel private[ml] (
}
}
val inputColumns = dataset.schema.fieldNames
- val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
- .as(map(predictionCol))
- val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
- instances
- .join(users, Column(map(userCol)) === $"$userTable.id", "left")
- .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+ val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
+ val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+ dataset
+ .join(users, dataset(map(userCol)) === users("id"), "left")
+ .join(items, dataset(map(itemCol)) === items("id"), "left")
.select(outputColumns: _*)
+ // TODO: Just use a dataset("*")
+ // .select(dataset("*"), prediction)
}
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
@@ -147,10 +138,6 @@ class ALSModel private[ml] (
}
}
-private object ALSModel {
- /** Case class to convert factors to [[DataFrame]]s */
- private case class Factor(id: Int, features: Seq[Float])
-}
/**
* Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
val ratings = dataset
- .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+ .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
.map { row =>
new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
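The ALS changes drop the Factor case class and the string table aliases: factor RDDs become DataFrames via toDataFrame("id", "features"), and join keys are disambiguated by resolving each column against its own DataFrame instead of via "$table.column" strings. A condensed sketch of that join pattern (the DataFrame names and columns are assumptions):

    val users  = userFactors.toDataFrame("id", "features")             // tuple RDD, columns renamed
    val joined = ratings.join(users, ratings("user") === users("id"), "left")
    val out    = joined.select(ratings("user"), users("features"))     // no table aliases needed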
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 31c33f1bf6..567a8a6c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._
/**