author     Reynold Xin <rxin@databricks.com>   2016-02-13 21:06:31 -0800
committer  Reynold Xin <rxin@databricks.com>   2016-02-13 21:06:31 -0800
commit     354d4c24be892271bd9a9eab6ceedfbc5d671c9c (patch)
tree       c0503ad0c303e6db4882bdbfa356fb78a8dd32fb
parent     388cd9ea8db2e438ebef9dfb894298f843438c43 (diff)
[SPARK-13296][SQL] Move UserDefinedFunction into sql.expressions.
This pull request has the following changes:

1. Moved UserDefinedFunction into the expressions package. This is more consistent with how we structure the packages for window functions and UDAFs.
2. Moved UserDefinedPythonFunction into the execution.python package, so we don't have a random private class in the top-level sql package.
3. Moved everything in execution/python.scala into the newly created execution.python package. Most of the diffs are just straight copy-paste.

Author: Reynold Xin <rxin@databricks.com>

Closes #11181 from rxin/SPARK-13296.
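For user code the relocation is mostly transparent: functions.udf still returns a UserDefinedFunction, only the package of the returned type changes. A minimal sketch of the new import location (the sqlContext value and the example column are assumptions for illustration, not part of this patch):

  // was: import org.apache.spark.sql.UserDefinedFunction
  import org.apache.spark.sql.expressions.UserDefinedFunction
  import org.apache.spark.sql.functions.udf

  val toUpper: UserDefinedFunction = udf((s: String) => s.toUpperCase)
  val df = sqlContext.range(3).selectExpr("cast(id as string) as name")
  df.select(toUpper(df("name"))).show()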
-rw-r--r--  project/MimaExcludes.scala  8
-rw-r--r--  python/pyspark/sql/dataframe.py  2
-rw-r--r--  python/pyspark/sql/functions.py  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala  104
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala)  187
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala  79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala  44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala  51
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala)  39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala  1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala  2
15 files changed, 320 insertions, 217 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8611106db0..6abab7f126 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -235,7 +235,13 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint")
) ++ Seq(
// SPARK-7889
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$$attachSparkUI")
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$$attachSparkUI"),
+ // SPARK-13296
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3104e41407..83b034fe77 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -262,7 +262,7 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+ port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
self._jdf, num)
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 416d722bba..5fc1cc2cae 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1652,9 +1652,9 @@ class UserDefinedFunction(object):
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
- judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
- sc.pythonExec, sc.pythonVer, broadcast_vars,
- sc._javaAccumulator, jdt)
+ judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
+ name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
+ broadcast_vars, sc._javaAccumulator, jdt)
return judf
def __del__(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index c5b2b7d118..76c09a285d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
+import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d58b99655c..c7d1096a13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -193,7 +193,7 @@ class SQLContext private[sql](
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, conf) {
override val extendedResolutionRules =
- ExtractPythonUDFs ::
+ python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
(if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)
@@ -915,7 +915,7 @@ class SQLContext private[sql](
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
- val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+ val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index f87a88d497..ecfc170bee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql.types.DataType
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 598ddd7161..73fd22b38e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -369,8 +369,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
execution.Exchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
- case e @ EvaluatePython(udf, child, _) =>
- BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+ case e @ python.EvaluatePython(udf, child, _) =>
+ python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
new file mode 100644
index 0000000000..00df019527
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle.{Pickler, Unpickler}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combining the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and eventually running out of memory.
+ */
+case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+ extends SparkPlan {
+
+ def children: Seq[SparkPlan] = child :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val inputRDD = child.execute().map(_.copy())
+ val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+ inputRDD.mapPartitions { iter =>
+ EvaluatePython.registerPicklers() // register pickler for Row
+
+ // The queue used to buffer input rows so we can drain it to
+ // combine input with output from Python.
+ val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+
+ val pickle = new Pickler
+ val currentRow = newMutableProjection(udf.children, child.output)()
+ val fields = udf.children.map(_.dataType)
+ val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+
+ // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+ // For each row, add it to the queue.
+ val inputIterator = iter.grouped(100).map { inputRows =>
+ val toBePickled = inputRows.map { row =>
+ queue.add(row)
+ EvaluatePython.toJava(currentRow(row), schema)
+ }.toArray
+ pickle.dumps(toBePickled)
+ }
+
+ val context = TaskContext.get()
+
+ // Output iterator for results from Python.
+ val outputIterator = new PythonRunner(
+ udf.command,
+ udf.envVars,
+ udf.pythonIncludes,
+ udf.pythonExec,
+ udf.pythonVer,
+ udf.broadcastVars,
+ udf.accumulator,
+ bufferSize,
+ reuseWorker
+ ).compute(inputIterator, context.partitionId(), context)
+
+ val unpickle = new Unpickler
+ val row = new GenericMutableRow(1)
+ val joined = new JoinedRow
+ val resultProj = UnsafeProjection.create(output, output)
+
+ outputIterator.flatMap { pickedResult =>
+ val unpickledBatch = unpickle.loads(pickedResult)
+ unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+ }.map { result =>
+ row(0) = EvaluatePython.fromJava(result, udf.dataType)
+ resultProj(joined(queue.poll(), row))
+ }
+ }
+ }
+}
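The queue-and-drain bookkeeping described in the class comment above can be shown in isolation. A stripped-down sketch of the pattern, with a local transform standing in for the round trip through PythonRunner (illustration only, not Spark code):

  import java.util.concurrent.ConcurrentLinkedQueue

  val queue = new ConcurrentLinkedQueue[String]()
  val inputs = Iterator("a", "b", "c")

  // Send side: remember each original row before shipping it to the external process.
  val sent = inputs.map { row => queue.add(row); row.toUpperCase }  // stand-in for Python

  // Receive side: outputs come back in input order, so each one pairs with queue.poll().
  val joined = sent.map(result => (queue.poll(), result)).toList
  // joined == List(("a", "A"), ("b", "B"), ("c", "C"))

In the real operator the inputs are additionally pickled in batches of 100 and the outputs unpickled, but the pairing of original rows with results works the same way.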
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index bf62bb05c3..8c46516594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -15,106 +15,41 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.python
import java.io.OutputStream
-import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConverters._
-import net.razorvine.pickle._
+import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
-import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext}
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil}
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
- * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
+ * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
*/
-private[spark] case class PythonUDF(
- name: String,
- command: Array[Byte],
- envVars: JMap[String, String],
- pythonIncludes: JList[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]],
- dataType: DataType,
- children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging {
-
- override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
-
- override def nullable: Boolean = true
-}
+case class EvaluatePython(
+ udf: PythonUDF,
+ child: LogicalPlan,
+ resultAttribute: AttributeReference)
+ extends logical.UnaryNode {
-/**
- * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
- * alone in a batch.
- *
- * This has the limitation that the input to the Python UDF is not allowed include attributes from
- * multiple child operators.
- */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Skip EvaluatePython nodes.
- case plan: EvaluatePython => plan
-
- case plan: LogicalPlan if plan.resolved =>
- // Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
- if (udfs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
- // If there is more than one, we will add another evaluation operator in a subsequent pass.
- udfs.find(_.resolved) match {
- case Some(udf) =>
- var evaluation: EvaluatePython = null
-
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- if (udf.references.subsetOf(child.outputSet)) {
- evaluation = EvaluatePython(udf, child)
- evaluation
- } else if (udf.references.intersect(child.outputSet).nonEmpty) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- child
- }
- }
-
- assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
- // Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
- }.withNewChildren(newChildren))
-
- case None =>
- // If there is no Python UDF that is resolved, skip this round.
- plan
- }
- }
- }
+ def output: Seq[Attribute] = child.output :+ resultAttribute
+
+ // References should not include the produced attribute.
+ override def references: AttributeSet = udf.references
}
+
object EvaluatePython {
def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
@@ -221,7 +156,7 @@ object EvaluatePython {
if (array.length != fields.length) {
throw new IllegalStateException(
s"Input row doesn't have expected number of values required by the schema. " +
- s"${fields.length} fields are required while ${array.length} values are provided."
+ s"${fields.length} fields are required while ${array.length} values are provided."
)
}
new GenericInternalRow(array.zip(fields).map {
@@ -235,7 +170,6 @@ object EvaluatePython {
case (c, _) => null
}
-
private val module = "pyspark.sql.types"
/**
@@ -287,7 +221,7 @@ object EvaluatePython {
out.write(Opcodes.MARK)
var i = 0
- while (i < row.values.size) {
+ while (i < row.values.length) {
pickler.save(row.values(i))
i += 1
}
@@ -298,6 +232,7 @@ object EvaluatePython {
}
private[this] var registered = false
+
/**
* This should be called before trying to serialize any of the above classes in cluster mode,
* this should be put in the closure
@@ -324,91 +259,3 @@ object EvaluatePython {
}
}
}
-
-/**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
- udf: PythonUDF,
- child: LogicalPlan,
- resultAttribute: AttributeReference)
- extends logical.UnaryNode {
-
- def output: Seq[Attribute] = child.output :+ resultAttribute
-
- // References should not include the produced attribute.
- override def references: AttributeSet = udf.references
-}
-
-/**
- * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and eventually run out of memory.
- */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
- extends SparkPlan {
-
- def children: Seq[SparkPlan] = child :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = {
- val inputRDD = child.execute().map(_.copy())
- val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
- val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
- inputRDD.mapPartitions { iter =>
- EvaluatePython.registerPicklers() // register pickler for Row
-
- // The queue used to buffer input rows so we can drain it to
- // combine input with output from Python.
- val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
-
- val pickle = new Pickler
- val currentRow = newMutableProjection(udf.children, child.output)()
- val fields = udf.children.map(_.dataType)
- val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
-
- // Input iterator to Python: input rows are grouped so we send them in batches to Python.
- // For each row, add it to the queue.
- val inputIterator = iter.grouped(100).map { inputRows =>
- val toBePickled = inputRows.map { row =>
- queue.add(row)
- EvaluatePython.toJava(currentRow(row), schema)
- }.toArray
- pickle.dumps(toBePickled)
- }
-
- val context = TaskContext.get()
-
- // Output iterator for results from Python.
- val outputIterator = new PythonRunner(
- udf.command,
- udf.envVars,
- udf.pythonIncludes,
- udf.pythonExec,
- udf.pythonVer,
- udf.broadcastVars,
- udf.accumulator,
- bufferSize,
- reuseWorker
- ).compute(inputIterator, context.partitionId(), context)
-
- val unpickle = new Unpickler
- val row = new GenericMutableRow(1)
- val joined = new JoinedRow
- val resultProj = UnsafeProjection.create(output, output)
-
- outputIterator.flatMap { pickedResult =>
- val unpickledBatch = unpickle.loads(pickedResult)
- unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
- }.map { result =>
- row(0) = EvaluatePython.fromJava(result, udf.dataType)
- resultProj(joined(queue.poll(), row))
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
new file mode 100644
index 0000000000..6e76e9569f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
+ * alone in a batch.
+ *
+ * This has the limitation that the input to the Python UDF is not allowed to include attributes from
+ * multiple child operators.
+ */
+private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ // Skip EvaluatePython nodes.
+ case plan: EvaluatePython => plan
+
+ case plan: LogicalPlan if plan.resolved =>
+ // Extract any PythonUDFs from the current operator.
+ val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+ if (udfs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
+ // If there is more than one, we will add another evaluation operator in a subsequent pass.
+ udfs.find(_.resolved) match {
+ case Some(udf) =>
+ var evaluation: EvaluatePython = null
+
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ if (udf.references.subsetOf(child.outputSet)) {
+ evaluation = EvaluatePython(udf, child)
+ evaluation
+ } else if (udf.references.intersect(child.outputSet).nonEmpty) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ child
+ }
+ }
+
+ assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
+
+ // Trim away the new UDF value if it was only used for filtering or something.
+ logical.Project(
+ plan.output,
+ plan.transformExpressions {
+ case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+ }.withNewChildren(newChildren))
+
+ case None =>
+ // If there is no Python UDF that is resolved, skip this round.
+ plan
+ }
+ }
+ }
+}
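Schematically, the rule pulls the Python UDF out of the operator that references it, evaluates it in a dedicated EvaluatePython node under that operator, and finally projects the temporary result attribute away. For a hypothetical relation with columns a and b and a filter on a Python UDF (names invented for illustration):

  Before:  Filter(pyUdf(a) > 0)
             +- Relation(a, b)

  After:   Project(a, b)
             +- Filter(pythonUDF > 0)
                  +- EvaluatePython(pyUdf(a), resultAttribute = pythonUDF)
                       +- Relation(a, b)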
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
new file mode 100644
index 0000000000..0e53a0c473
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.{Accumulator, Logging}
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A serialized version of a Python lambda function.
+ */
+case class PythonUDF(
+ name: String,
+ command: Array[Byte],
+ envVars: java.util.Map[String, String],
+ pythonIncludes: java.util.List[String],
+ pythonExec: String,
+ pythonVer: String,
+ broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
+ accumulator: Accumulator[java.util.List[Array[Byte]]],
+ dataType: DataType,
+ children: Seq[Expression]) extends Expression with Unevaluable with Logging {
+
+ override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+
+ override def nullable: Boolean = true
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
new file mode 100644
index 0000000000..79ac1c85c0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A user-defined Python function. This is used by the Python API.
+ */
+case class UserDefinedPythonFunction(
+ name: String,
+ command: Array[Byte],
+ envVars: java.util.Map[String, String],
+ pythonIncludes: java.util.List[String],
+ pythonExec: String,
+ pythonVer: String,
+ broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
+ accumulator: Accumulator[java.util.List[Array[Byte]]],
+ dataType: DataType) {
+
+ def builder(e: Seq[Expression]): PythonUDF = {
+ PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+ accumulator, dataType, e)
+ }
+
+ /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
+ def apply(exprs: Column*): Column = {
+ val udf = builder(exprs.map(_.expr))
+ Column(udf)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 2fb3bf07aa..bd35d19aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -15,16 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.expressions
-import java.util.{List => JList, Map => JMap}
-
-import org.apache.spark.Accumulator
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions
import org.apache.spark.sql.types.DataType
/**
@@ -50,30 +46,3 @@ case class UserDefinedFunction protected[sql] (
Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
}
}
-
-/**
- * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]].
- * This is used by Python API.
- */
-private[sql] case class UserDefinedPythonFunction(
- name: String,
- command: Array[Byte],
- envVars: JMap[String, String],
- pythonIncludes: JList[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]],
- dataType: DataType) {
-
- def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
- accumulator, dataType, e)
- }
-
- /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
- def apply(exprs: Column*): Column = {
- val udf = builder(exprs.map(_.expr))
- Column(udf)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d34d377ab6..e4ab6b4f23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2433b54ffc..ac174aa6bf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -465,7 +465,7 @@ class HiveContext private[hive](
catalog.ParquetConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
- ExtractPythonUDFs ::
+ python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
(if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)