diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8..eeb6301c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.dsl._ import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val score: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) @@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] ( val predict: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol))) } } |