Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Transformer.scala')
 mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index af56f9c435..b233bff083 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,9 +22,9 @@ import scala.annotation.varargs
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql._
+import org.apache.spark.sql.dsl._
 import org.apache.spark.sql.types._
 
 /**
@@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params {
    * @return transformed dataset
    */
   @varargs
-  def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+  def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
     val map = new ParamMap()
     paramPairs.foreach(map.put(_))
     transform(dataset, map)
@@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params {
    * @param paramMap additional parameters, overwrite embedded params
    * @return transformed dataset
    */
-  def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+  def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
 }
 
 /**
@@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
     StructType(outputFields)
   }
 
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
-    dataset.select(Star(None), udf as map(outputCol))
+    dataset.select($"*", callUDF(
+      this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
   }
 }
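
To make the new surface concrete, here is a minimal sketch of a transformer written against the DataFrame-based signatures above. It is illustrative only, not part of this patch: the class and column names are invented, and it assumes that Param's -> method builds the ParamPair the varargs transform overload expects, and that, because UnaryTransformer is private[ml], the subclass lives in a subpackage of org.apache.spark.ml, as the built-in feature transformers do.

// Hypothetical sketch, not from this patch: a UnaryTransformer that
// upper-cases a string column, using the new DataFrame signatures.
package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.{DataType, StringType}

class Uppercaser extends UnaryTransformer[String, String, Uppercaser] {

  // Per-value function; UnaryTransformer.transform wraps it in a UDF
  // (via callUDF after this patch) and applies it to inputCol.
  protected def createTransformFunc(paramMap: ParamMap): String => String =
    _.toUpperCase

  // Type of the generated output column, checked by transformSchema.
  protected def outputDataType: DataType = StringType
}

Call-time ParamPairs override embedded params, exactly as the varargs overload collects them into a ParamMap before delegating:

// Assuming df is an existing DataFrame with a string column "text".
val up = new Uppercaser()
val result = up.transform(df, up.inputCol -> "text", up.outputCol -> "text_upper")
// result keeps every column of df (the $"*" projection) plus "text_upper".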