aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-02-04 23:57:53 -0800
committerReynold Xin <rxin@databricks.com>2015-02-04 23:57:53 -0800
commitc3ba4d4cd032e376bfdf7ea7eaab65a79a771e7e (patch)
tree16a30493acc37d7e85536940352b12c794c797db /mllib
parent7d789e117d6ddaf66159e708db600f2d8db8d787 (diff)
downloadspark-c3ba4d4cd032e376bfdf7ea7eaab65a79a771e7e.tar.gz
spark-c3ba4d4cd032e376bfdf7ea7eaab65a79a771e7e.tar.bz2
spark-c3ba4d4cd032e376bfdf7ea7eaab65a79a771e7e.zip
[MLlib] Minor: UDF style update.
Author: Reynold Xin <rxin@databricks.com> Closes #4388 from rxin/mllib-style and squashes the following commits: 61d465b [Reynold Xin] oops 3364295 [Reynold Xin] Missed one .. 5e068e3 [Reynold Xin] [MLlib] Minor: UDF style update.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala4
2 files changed, 7 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index df90078de1..b46a5cd8bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -132,12 +132,14 @@ class LogisticRegressionModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
- val scoreFunction = udf((v: Vector) => {
+ val scoreFunction = udf { v: Vector =>
val margin = BLAS.dot(v, weights)
1.0 / (1.0 + math.exp(-margin))
- } : Double)
+ }
val t = map(threshold)
- val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double)
+ val predictFunction = udf { score: Double =>
+ if (score > t) 1.0 else 0.0
+ }
dataset
.select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
.select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 09456b289e..bf5737177c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -129,13 +129,13 @@ class ALSModel private[ml] (
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
- val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => {
+ val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
} else {
Float.NaN
}
- } : Float)
+ }
dataset
.join(users, dataset(map(userCol)) === users("id"), "left")
.join(items, dataset(map(itemCol)) === items("id"), "left")