diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 4 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 14 |
2 files changed, 16 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590..1f3325ad09 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -115,8 +115,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 97c5aed6da..3572f3c3a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2844,6 +2844,20 @@ object functions extends LegacyFunctions { // scalastyle:on line.size.limit /** + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specifcy the output data type, and there is no automatic input type coercion. + * + * @param f A closure in Scala + * @param dataType The output data type of the UDF + * + * @group udf_funcs + * @since 2.0.0 + */ + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) + } + + /** * Call an user-defined function. * Example: * {{{ |