diff options
author | Reynold Xin <rxin@databricks.com> | 2015-02-04 23:57:53 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-02-04 23:58:13 -0800 |
commit | 40746749a6670f43aece69fe1482e92fa87decf5 (patch) | |
tree | 16a30493acc37d7e85536940352b12c794c797db /mllib | |
parent | 0040b6128d1ab6f5bb1b87629a9b324b5e802b47 (diff) | |
download | spark-40746749a6670f43aece69fe1482e92fa87decf5.tar.gz spark-40746749a6670f43aece69fe1482e92fa87decf5.tar.bz2 spark-40746749a6670f43aece69fe1482e92fa87decf5.zip |
[MLlib] Minor: UDF style update.
Author: Reynold Xin <rxin@databricks.com>
Closes #4388 from rxin/mllib-style and squashes the following commits:
61d465b [Reynold Xin] oops
3364295 [Reynold Xin] Missed one ..
5e068e3 [Reynold Xin] [MLlib] Minor: UDF style update.
(cherry picked from commit c3ba4d4cd032e376bfdf7ea7eaab65a79a771e7e)
Signed-off-by: Reynold Xin <rxin@databricks.com>
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 8 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 4 |
2 files changed, 7 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index df90078de1..b46a5cd8bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -132,12 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val scoreFunction = udf((v: Vector) => {
+    val scoreFunction = udf { v: Vector =>
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
-    } : Double)
+    }
     val t = map(threshold)
-    val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double)
+    val predictFunction = udf { score: Double =>
+      if (score > t) 1.0 else 0.0
+    }
     dataset
       .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
       .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 09456b289e..bf5737177c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -129,13 +129,13 @@ class ALSModel private[ml] (
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
-    val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => {
+    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
       } else {
         Float.NaN
       }
-    } : Float)
+    }
     dataset
       .join(users, dataset(map(userCol)) === users("id"), "left")
       .join(items, dataset(map(itemCol)) === items("id"), "left")