aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala14
1 files changed, 6 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8c570812f8..eeb6301c3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ val instances = dataset.select(map(labelCol), map(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.persist(StorageLevel.MEMORY_AND_DISK)
@@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] (
validateAndTransformSchema(schema, paramMap, fitting = false)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val score: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
@@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] (
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
- dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
- .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
+ .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
}
}