diff options
author | Reynold Xin <rxin@databricks.com> | 2016-02-13 21:06:31 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-02-13 21:06:31 -0800 |
commit | 354d4c24be892271bd9a9eab6ceedfbc5d671c9c (patch) | |
tree | c0503ad0c303e6db4882bdbfa356fb78a8dd32fb /sql | |
parent | 388cd9ea8db2e438ebef9dfb894298f843438c43 (diff) | |
download | spark-354d4c24be892271bd9a9eab6ceedfbc5d671c9c.tar.gz spark-354d4c24be892271bd9a9eab6ceedfbc5d671c9c.tar.bz2 spark-354d4c24be892271bd9a9eab6ceedfbc5d671c9c.zip |
[SPARK-13296][SQL] Move UserDefinedFunction into sql.expressions.
This pull request has the following changes:
1. Moved UserDefinedFunction into expressions package. This is more consistent with how we structure the packages for window functions and UDAFs.
2. Moved UserDefinedPythonFunction into execution.python package, so we don't have a random private class in the top level sql package.
3. Move everything in execution/python.scala into the newly created execution.python package.
Most of the diffs are just straight copy-paste.
Author: Reynold Xin <rxin@databricks.com>
Closes #11181 from rxin/SPARK-13296.
Diffstat (limited to 'sql')
12 files changed, 309 insertions, 212 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c5b2b7d118..76c09a285d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator +import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d58b99655c..c7d1096a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -193,7 +193,7 @@ class SQLContext private[sql]( protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUDFs :: + python.ExtractPythonUDFs :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) @@ -915,7 +915,7 @@ class SQLContext private[sql]( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f87a88d497..ecfc170bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 598ddd7161..73fd22b38e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -369,8 +369,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child, _) => - BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case e @ python.EvaluatePython(udf, child, _) => + python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala new file mode 100644 index 0000000000..00df019527 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -0,0 +1,104 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.PythonRunner +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructField, StructType} + + +/** + * A physical plan that evalutes a [[PythonUDF]], one partition of tuples at a time. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. + */ +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + inputRDD.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + val fields = udf.children.map(_.dataType) + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + queue.add(row) + EvaluatePython.toJava(currentRow(row), schema) + }.toArray + pickle.dumps(toBePickled) + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val row = new GenericMutableRow(1) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + resultProj(joined(queue.poll(), row)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index bf62bb05c3..8c46516594 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -15,106 +15,41 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.python import java.io.OutputStream -import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ -import net.razorvine.pickle._ +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext} -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil} -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ -private[spark] case class PythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { - - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - - override def nullable: Boolean = true -} +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { -/** - * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated - * alone in a batch. - * - * This has the limitation that the input to the Python UDF is not allowed include attributes from - * multiple child operators. - */ -private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan - - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan - } - } - } + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references } + object EvaluatePython { def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) @@ -221,7 +156,7 @@ object EvaluatePython { if (array.length != fields.length) { throw new IllegalStateException( s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." + s"${fields.length} fields are required while ${array.length} values are provided." ) } new GenericInternalRow(array.zip(fields).map { @@ -235,7 +170,6 @@ object EvaluatePython { case (c, _) => null } - private val module = "pyspark.sql.types" /** @@ -287,7 +221,7 @@ object EvaluatePython { out.write(Opcodes.MARK) var i = 0 - while (i < row.values.size) { + while (i < row.values.length) { pickler.save(row.values(i)) i += 1 } @@ -298,6 +232,7 @@ object EvaluatePython { } private[this] var registered = false + /** * This should be called before trying to serialize any above classes un cluster mode, * this should be put in the closure @@ -324,91 +259,3 @@ object EvaluatePython { } } } - -/** - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. - */ -case class EvaluatePython( - udf: PythonUDF, - child: LogicalPlan, - resultAttribute: AttributeReference) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output :+ resultAttribute - - // References should not include the produced attribute. - override def references: AttributeSet = udf.references -} - -/** - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * - * Python evaluation works by sending the necessary (projected) input data via a socket to an - * external Python process, and combine the result from the Python process with the original row. - * - * For each row we send to Python, we also put it in a queue. For each output row from Python, - * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. - */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() - - val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) - val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - queue.add(row) - EvaluatePython.toJava(currentRow(row), schema) - }.toArray - pickle.dumps(toBePickled) - } - - val context = TaskContext.get() - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) - - val unpickle = new Unpickler - val row = new GenericMutableRow(1) - val joined = new JoinedRow - val resultProj = UnsafeProjection.create(output, output) - - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - resultProj(joined(queue.poll(), row)) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala new file mode 100644 index 0000000000..6e76e9569f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -0,0 +1,79 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Skip EvaluatePython nodes. + case plan: EvaluatePython => plan + + case plan: LogicalPlan if plan.resolved => + // Extract any PythonUDFs from the current operator. + val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) + if (udfs.isEmpty) { + // If there aren't any, we are done. + plan + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala new file mode 100644 index 0000000000..0e53a0c473 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.types.DataType + +/** + * A serialized version of a Python lambda function. + */ +case class PythonUDF( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with Unevaluable with Logging { + + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + + override def nullable: Boolean = true +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala new file mode 100644 index 0000000000..79ac1c85c0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.Column +import org.apache.spark.sql.types.DataType + +/** + * A user-defined Python function. This is used by the Python API. + */ +case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType) { + + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, + accumulator, dataType, e) + } + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = builder(exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 2fb3bf07aa..bd35d19aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -15,16 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.expressions -import java.util.{List => JList, Map => JMap} - -import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions import org.apache.spark.sql.types.DataType /** @@ -50,30 +46,3 @@ case class UserDefinedFunction protected[sql] ( Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } } - -/** - * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. - * This is used by Python API. - */ -private[sql] case class UserDefinedPythonFunction( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType) { - - def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) - } - - /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ - def apply(exprs: Column*): Column = { - val udf = builder(exprs.map(_.expr)) - Column(udf) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d34d377ab6..e4ab6b4f23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2433b54ffc..ac174aa6bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -465,7 +465,7 @@ class HiveContext private[hive]( catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUDFs :: + python.ExtractPythonUDFs :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) |