aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Transformer.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Transformer.scala18
1 files changed, 10 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 490e6609ad..23fbd228d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,16 +18,14 @@
package org.apache.spark.ml
import scala.annotation.varargs
-import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.catalyst.types._
/**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -100,6 +98,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
protected def createTransformFunc(paramMap: ParamMap): IN => OUT
/**
+ * Returns the data type of the output column.
+ */
+ protected def outputDataType: DataType
+
+ /**
* Validates the input type. Throw an exception if it is invalid.
*/
protected def validateInputType(inputType: DataType): Unit = {}
@@ -111,9 +114,8 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
if (schema.fieldNames.contains(map(outputCol))) {
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
}
- val output = ScalaReflection.schemaFor[OUT]
val outputFields = schema.fields :+
- StructField(map(outputCol), output.dataType, output.nullable)
+ StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
StructType(outputFields)
}
@@ -121,7 +123,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
transformSchema(dataset.schema, paramMap, logging = true)
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = this.createTransformFunc(map)
- dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+ dataset.select(Star(None), udf as map(outputCol))
}
}