aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Transformer.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala14
2 files changed, 16 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 3c7bcf7590..1f3325ad09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -115,8 +115,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
- dataset.withColumn($(outputCol),
- callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
+ val transformUDF = udf(this.createTransformFunc, outputDataType)
+ dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
}
override def copy(extra: ParamMap): T = defaultCopy(extra)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 97c5aed6da..3572f3c3a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2844,6 +2844,20 @@ object functions extends LegacyFunctions {
// scalastyle:on line.size.limit
/**
+ * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must
+ * specify the output data type, and there is no automatic input type coercion.
+ *
+ * @param f A closure in Scala
+ * @param dataType The output data type of the UDF
+ *
+ * @group udf_funcs
+ * @since 2.0.0
+ */
+ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
+ UserDefinedFunction(f, dataType, None)
+ }
+
+ /**
* Call an user-defined function.
* Example:
* {{{