From 354d4c24be892271bd9a9eab6ceedfbc5d671c9c Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 13 Feb 2016 21:06:31 -0800
Subject: [SPARK-13296][SQL] Move UserDefinedFunction into sql.expressions.

This pull request has the following changes:

1. Moved UserDefinedFunction into the expressions package. This is more consistent
   with how we structure the packages for window functions and UDAFs.
2. Moved UserDefinedPythonFunction into the execution.python package, so we don't
   have a random private class in the top-level sql package.
3. Moved everything in execution/python.scala into the newly created
   execution.python package.

Most of the diffs are just straight copy-paste.

Author: Reynold Xin

Closes #11181 from rxin/SPARK-13296.
---
 .../scala/org/apache/spark/sql/DataFrame.scala      |   3 +-
 .../scala/org/apache/spark/sql/SQLContext.scala     |   4 +-
 .../org/apache/spark/sql/UDFRegistration.scala      |   3 +-
 .../org/apache/spark/sql/UserDefinedFunction.scala  |  79 ----
 .../spark/sql/execution/SparkStrategies.scala       |   4 +-
 .../org/apache/spark/sql/execution/python.scala     | 414 ---------------------
 .../execution/python/BatchPythonEvaluation.scala    | 104 ++++++
 .../sql/execution/python/EvaluatePython.scala       | 261 +++++++++++++
 .../sql/execution/python/ExtractPythonUDFs.scala    |  79 ++++
 .../spark/sql/execution/python/PythonUDF.scala      |  44 +++
 .../python/UserDefinedPythonFunction.scala          |  51 +++
 .../sql/expressions/UserDefinedFunction.scala       |  48 +++
 .../scala/org/apache/spark/sql/functions.scala      |   1 +
 .../org/apache/spark/sql/hive/HiveContext.scala     |   2 +-
 14 files changed, 597 insertions(+), 500 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index c5b2b7d118..76c09a285d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
+import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import
org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d58b99655c..c7d1096a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -193,7 +193,7 @@ class SQLContext private[sql]( protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUDFs :: + python.ExtractPythonUDFs :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) @@ -915,7 +915,7 @@ class SQLContext private[sql]( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f87a88d497..ecfc170bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala deleted file mode 100644 index 2fb3bf07aa..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList, Map => JMap} - -import org.apache.spark.Accumulator -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF -import org.apache.spark.sql.types.DataType - -/** - * A user-defined function. To create one, use the `udf` functions in [[functions]]. 
- * As an example: - * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) - * - * // Projects a column that adds a prediction column based on the score column. - * df.select( predict(df("score")) ) - * }}} - * - * @since 1.3.0 - */ -@Experimental -case class UserDefinedFunction protected[sql] ( - f: AnyRef, - dataType: DataType, - inputTypes: Option[Seq[DataType]]) { - - def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) - } -} - -/** - * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. - * This is used by Python API. - */ -private[sql] case class UserDefinedPythonFunction( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType) { - - def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) - } - - /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ - def apply(exprs: Column*): Column = { - val udf = builder(exprs.map(_.expr)) - Column(udf) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 598ddd7161..73fd22b38e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -369,8 +369,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child, _) => - BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case e @ python.EvaluatePython(udf, child, _) => + python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala deleted file mode 100644 index bf62bb05c3..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ /dev/null @@ -1,414 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution - -import java.io.OutputStream -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConverters._ - -import net.razorvine.pickle._ - -import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext} -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. - */ -private[spark] case class PythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { - - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - - override def nullable: Boolean = true -} - -/** - * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated - * alone in a batch. - * - * This has the limitation that the input to the Python UDF is not allowed include attributes from - * multiple child operators. - */ -private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan - - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. 
- logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan - } - } - } -} - -object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = - new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) - - def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() - df.withNewExecutionId { - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") - } - } - - /** - * Helper for converting from Catalyst type to java type suitable for Pyrolite. - */ - def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.numFields) - var i = 0 - while (i < row.numFields) { - values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) - i += 1 - } - new GenericRowWithSchema(values, struct) - - case (a: ArrayData, array: ArrayType) => - val values = new java.util.ArrayList[Any](a.numElements()) - a.foreach(array.elementType, (_, e) => { - values.add(toJava(e, array.elementType)) - }) - values - - case (map: MapData, mt: MapType) => - val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType)) - }) - jmap - - case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) - - case (d: Decimal, _) => d.toJavaBigDecimal - - case (s: UTF8String, StringType) => s.toString - - case (other, _) => other - } - - /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. 
- */ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt - - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c - - case (c: Double, FloatType) => c.toFloat - - case (c: Double, DoubleType) => c - - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) - - case (c: Int, DateType) => c - - case (c: Long, TimestampType) => c - - case (c, StringType) => UTF8String.fromString(c.toString) - - case (c: String, BinaryType) => c.getBytes("utf-8") - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c - - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) - - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keyValues = c.asScala.toSeq - val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray - val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray - ArrayBasedMapData(keys, values) - - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) - } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) - - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null - } - - - private val module = "pyspark.sql.types" - - /** - * Pickler for StructType - */ - private class StructTypePickler extends IObjectPickler { - - private val cls = classOf[StructType] - - def register(): Unit = { - Pickler.registerCustomPickler(cls, this) - } - - def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { - out.write(Opcodes.GLOBAL) - out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8")) - val schema = obj.asInstanceOf[StructType] - pickler.save(schema.json) - out.write(Opcodes.TUPLE1) - out.write(Opcodes.REDUCE) - } - } - - /** - * Pickler for external row. 
- */ - private class RowPickler extends IObjectPickler { - - private val cls = classOf[GenericRowWithSchema] - - // register this to Pickler and Unpickler - def register(): Unit = { - Pickler.registerCustomPickler(this.getClass, this) - Pickler.registerCustomPickler(cls, this) - } - - def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { - if (obj == this) { - out.write(Opcodes.GLOBAL) - out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8")) - } else { - // it will be memorized by Pickler to save some bytes - pickler.save(this) - val row = obj.asInstanceOf[GenericRowWithSchema] - // schema should always be same object for memoization - pickler.save(row.schema) - out.write(Opcodes.TUPLE1) - out.write(Opcodes.REDUCE) - - out.write(Opcodes.MARK) - var i = 0 - while (i < row.values.size) { - pickler.save(row.values(i)) - i += 1 - } - out.write(Opcodes.TUPLE) - out.write(Opcodes.REDUCE) - } - } - } - - private[this] var registered = false - /** - * This should be called before trying to serialize any above classes un cluster mode, - * this should be put in the closure - */ - def registerPicklers(): Unit = { - synchronized { - if (!registered) { - SerDeUtil.initialize() - new StructTypePickler().register() - new RowPickler().register() - registered = true - } - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { - rdd.mapPartitions { iter => - registerPicklers() // let it called in executor - new SerDeUtil.AutoBatchedPickler(iter) - } - } -} - -/** - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. - */ -case class EvaluatePython( - udf: PythonUDF, - child: LogicalPlan, - resultAttribute: AttributeReference) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output :+ resultAttribute - - // References should not include the produced attribute. - override def references: AttributeSet = udf.references -} - -/** - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * - * Python evaluation works by sending the necessary (projected) input data via a socket to an - * external Python process, and combine the result from the Python process with the original row. - * - * For each row we send to Python, we also put it in a queue. For each output row from Python, - * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. - */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. 
- val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() - - val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) - val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - queue.add(row) - EvaluatePython.toJava(currentRow(row), schema) - }.toArray - pickle.dumps(toBePickled) - } - - val context = TaskContext.get() - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) - - val unpickle = new Unpickler - val row = new GenericMutableRow(1) - val joined = new JoinedRow - val resultProj = UnsafeProjection.create(output, output) - - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - resultProj(joined(queue.poll(), row)) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala new file mode 100644 index 0000000000..00df019527 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -0,0 +1,104 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.PythonRunner +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructField, StructType} + + +/** + * A physical plan that evalutes a [[PythonUDF]], one partition of tuples at a time. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. 
For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. + */ +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + inputRDD.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + val fields = udf.children.map(_.dataType) + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { row => + queue.add(row) + EvaluatePython.toJava(currentRow(row), schema) + }.toArray + pickle.dumps(toBePickled) + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val row = new GenericMutableRow(1) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + resultProj(joined(queue.poll(), row)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala new file mode 100644 index 0000000000..8c46516594 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -0,0 +1,261 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.python + +import java.io.OutputStream + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} + +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { + + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references +} + + +object EvaluatePython { + def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) + + def takeAndServe(df: DataFrame, n: Int): Int = { + registerPicklers() + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } + } + + /** + * Helper for converting from Catalyst type to java type suitable for Pyrolite. + */ + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (row: InternalRow, struct: StructType) => + val values = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) + i += 1 + } + new GenericRowWithSchema(values, struct) + + case (a: ArrayData, array: ArrayType) => + val values = new java.util.ArrayList[Any](a.numElements()) + a.foreach(array.elementType, (_, e) => { + values.add(toJava(e, array.elementType)) + }) + values + + case (map: MapData, mt: MapType) => + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => { + jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType)) + }) + jmap + + case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) + + case (d: Decimal, _) => d.toJavaBigDecimal + + case (s: UTF8String, StringType) => s.toString + + case (other, _) => other + } + + /** + * Converts `obj` to the type specified by the data type, or returns null if the type of obj is + * unexpected. Because Python doesn't enforce the type. 
+ */ + def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (c: Boolean, BooleanType) => c + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + + case (c: Int, IntegerType) => c + case (c: Long, IntegerType) => c.toInt + + case (c: Int, LongType) => c.toLong + case (c: Long, LongType) => c + + case (c: Double, FloatType) => c.toFloat + + case (c: Double, DoubleType) => c + + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + + case (c: Int, DateType) => c + + case (c: Long, TimestampType) => c + + case (c, StringType) => UTF8String.fromString(c.toString) + + case (c: String, BinaryType) => c.getBytes("utf-8") + case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + + case (c: java.util.List[_], ArrayType(elementType, _)) => + new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val keyValues = c.asScala.toSeq + val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray + val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray + ArrayBasedMapData(keys, values) + + case (c, StructType(fields)) if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + new GenericInternalRow(array.zip(fields).map { + case (e, f) => fromJava(e, f.dataType) + }) + + case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + case (c, _) => null + } + + private val module = "pyspark.sql.types" + + /** + * Pickler for StructType + */ + private class StructTypePickler extends IObjectPickler { + + private val cls = classOf[StructType] + + def register(): Unit = { + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8")) + val schema = obj.asInstanceOf[StructType] + pickler.save(schema.json) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + } + } + + /** + * Pickler for external row. 
+ */ + private class RowPickler extends IObjectPickler { + + private val cls = classOf[GenericRowWithSchema] + + // register this to Pickler and Unpickler + def register(): Unit = { + Pickler.registerCustomPickler(this.getClass, this) + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8")) + } else { + // it will be memorized by Pickler to save some bytes + pickler.save(this) + val row = obj.asInstanceOf[GenericRowWithSchema] + // schema should always be same object for memoization + pickler.save(row.schema) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + + out.write(Opcodes.MARK) + var i = 0 + while (i < row.values.length) { + pickler.save(row.values(i)) + i += 1 + } + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } + + private[this] var registered = false + + /** + * This should be called before trying to serialize any above classes un cluster mode, + * this should be put in the closure + */ + def registerPicklers(): Unit = { + synchronized { + if (!registered) { + SerDeUtil.initialize() + new StructTypePickler().register() + new RowPickler().register() + registered = true + } + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { + rdd.mapPartitions { iter => + registerPicklers() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala new file mode 100644 index 0000000000..6e76e9569f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -0,0 +1,79 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // Skip EvaluatePython nodes. 
+ case plan: EvaluatePython => plan + + case plan: LogicalPlan if plan.resolved => + // Extract any PythonUDFs from the current operator. + val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) + if (udfs.isEmpty) { + // If there aren't any, we are done. + plan + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala new file mode 100644 index 0000000000..0e53a0c473 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.types.DataType + +/** + * A serialized version of a Python lambda function. 
+ */ +case class PythonUDF( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with Unevaluable with Logging { + + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + + override def nullable: Boolean = true +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala new file mode 100644 index 0000000000..79ac1c85c0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.Column +import org.apache.spark.sql.types.DataType + +/** + * A user-defined Python function. This is used by the Python API. + */ +case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: java.util.Map[String, String], + pythonIncludes: java.util.List[String], + pythonExec: String, + pythonVer: String, + broadcastVars: java.util.List[Broadcast[PythonBroadcast]], + accumulator: Accumulator[java.util.List[Array[Byte]]], + dataType: DataType) { + + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, + accumulator, dataType, e) + } + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = builder(exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala new file mode 100644 index 0000000000..bd35d19aa2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -0,0 +1,48 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A user-defined function. To create one, use the `udf` functions in [[functions]].
+ * As an example:
+ * {{{
+ *   // Defined a UDF that returns true or false based on some numeric score.
+ *   val predict = udf((score: Double) => if (score > 0.5) true else false)
+ *
+ *   // Projects a column that adds a prediction column based on the score column.
+ *   df.select( predict(df("score")) )
+ * }}}
+ *
+ * @since 1.3.0
+ */
+@Experimental
+case class UserDefinedFunction protected[sql] (
+    f: AnyRef,
+    dataType: DataType,
+    inputTypes: Option[Seq[DataType]]) {
+
+  def apply(exprs: Column*): Column = {
+    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d34d377ab6..e4ab6b4f23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2433b54ffc..ac174aa6bf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -465,7 +465,7 @@ class HiveContext private[hive](
       catalog.ParquetConversions ::
       catalog.CreateTables ::
       catalog.PreInsertionCasts ::
-      ExtractPythonUDFs ::
+      python.ExtractPythonUDFs ::
       PreInsertCastAndRename ::
       (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)
--
cgit v1.2.3
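For code that refers to UserDefinedFunction by name, the only user-visible effect of this refactoring is the import path; construction through functions.udf and application to columns are unchanged. The following is a minimal Scala sketch, not part of the patch: the DataFrame `df` and its numeric `score` column are illustrative, borrowed from the class's own scaladoc example above.

{{{
// Before this change the type lived in the top-level sql package:
//   import org.apache.spark.sql.UserDefinedFunction
// After this change it lives in the expressions package:
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

// udf(...) still returns the (now relocated) UserDefinedFunction case class.
val predict: UserDefinedFunction = udf((score: Double) => if (score > 0.5) true else false)

// Applying the UDF to a column is unchanged; `df` is assumed to be an
// existing DataFrame with a numeric "score" column.
val withPrediction = df.select(predict(df("score")))
}}}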