author    Reynold Xin <rxin@databricks.com>    2015-02-03 20:07:46 -0800
committer Reynold Xin <rxin@databricks.com>    2015-02-03 20:07:46 -0800
commit    1077f2e1def6266aee6ad6f0640a8f46cd273e21 (patch)
tree      b843507b12b1e45148cb4239a1bd6ee452953144 /mllib
parent    e380d2d46c92b319eafe30974ac7c1509081fca4 (diff)
[SPARK-5578][SQL][DataFrame] Provide a convenient way for Scala users to use UDFs
A more convenient way to define user-defined functions.

Author: Reynold Xin <rxin@databricks.com>

Closes #4345 from rxin/defineUDF and squashes the following commits:

639c0f8 [Reynold Xin] udf tests.
0a0b339 [Reynold Xin] defineUDF -> udf.
b452b8d [Reynold Xin] Fix UDF registration.
d2e42c3 [Reynold Xin] SQLContext.udf.register() returns a UserDefinedFunction also.
4333605 [Reynold Xin] [SQL][DataFrame] defineUDF.
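For context, a minimal sketch of the API change in caller code. The DataFrame `df` and its "margin" column are assumptions for illustration, not part of the patch; `udf`, `callUDF`, `col`, and `$` are the DataFrame DSL helpers used throughout the diffs below, assumed to be in scope:

    // Before this patch: define a plain Scala function, then wrap it in
    // callUDF at every call site.
    val sigmoidFn: Double => Double = m => 1.0 / (1.0 + math.exp(-m))
    df.select($"*", callUDF(sigmoidFn, col("margin")).as("score"))

    // After this patch: udf(...) returns a UserDefinedFunction that is
    // applied directly to Columns, like an ordinary function.
    val sigmoid = udf((m: Double) => 1.0 / (1.0 + math.exp(-m)))
    df.select($"*", sigmoid(col("margin")).as("score"))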
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala            |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                | 14

3 files changed, 13 insertions(+), 17 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 18be35ad59..df90078de1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -132,14 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val scoreFunction: Vector => Double = (v) => {
+    val scoreFunction = udf((v: Vector) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
-    }
+    } : Double)
     val t = map(threshold)
-    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double)
     dataset
-      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
+      .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
   }
 }
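The LogisticRegression change above also shows that a UserDefinedFunction is a first-class value: the column produced by one UDF can feed the next in a chained select, and the `: Double` ascription makes each closure's return type explicit. Sketched in isolation (the `df` value, column names, and the 0.5 threshold are assumptions):

    // Score each row, then threshold the score into a prediction.
    val sigmoid = udf((margin: Double) => { 1.0 / (1.0 + math.exp(-margin)) } : Double)
    val decide = udf((score: Double) => { if (score > 0.5) 1.0 else 0.0 } : Double)

    df.select($"*", sigmoid(col("margin")).as("score"))
      .select($"*", decide(col("score")).as("prediction"))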
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 01a4f5eb20..4745a7ae95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -81,10 +81,8 @@ class StandardScalerModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val scale: (Vector) => Vector = (v) => {
-      scaler.transform(v)
-    }
-    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
+    val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
+    dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
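The StandardScaler change compresses to a single line: the UDF closes over the fitted `scaler` and maps one mllib Vector column to another. Used against a hypothetical DataFrame `df` with a "features" column (both names are assumptions):

    // Wrap the fitted model's transform in a UDF and append the scaled column.
    val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
    df.select($"*", scale(col("features")).as("scaledFeatures"))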
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 511cb2fe40..c7bec7a845 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -126,22 +126,20 @@ class ALSModel private[ml] (
     val map = this.paramMap ++ paramMap
     val users = userFactors.toDataFrame("id", "features")
     val items = itemFactors.toDataFrame("id", "features")
-    val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
+
+    // Register a UDF for DataFrame, and then
+    // create a new column named map(predictionCol) by running the predict UDF.
+    val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
       } else {
         Float.NaN
       }
-    }
-    val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    } : Float)
     dataset
       .join(users, dataset(map(userCol)) === users("id"), "left")
       .join(items, dataset(map(itemCol)) === items("id"), "left")
-      .select(outputColumns: _*)
-      // TODO: Just use a dataset("*")
-      // .select(dataset("*"), prediction)
+      .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
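The ALS change also retires the manual column enumeration that the old TODO pointed at: `dataset("*")` expands to all of the dataset's columns, so appending a derived column no longer requires collecting field names by hand. The two spellings side by side (`df` and `newCol` are assumptions):

    // Before: enumerate every existing column by name, then append the new one.
    val outputColumns = df.schema.fieldNames.map(f => df(f)) :+ newCol
    df.select(outputColumns: _*)

    // After: df("*") selects all existing columns in one step.
    df.select(df("*"), newCol)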