author    Davies Liu <davies@databricks.com>    2014-10-28 19:38:16 -0700
committer Michael Armbrust <michael@databricks.com>    2014-10-28 19:38:16 -0700
commit    8c0bfd08fc19fa5de7d77bf8306d19834f907ec0 (patch)
tree      96d0424f06e2c20d9ee34cc482f792fdbff473a6 /sql
parent    b5e79bf889700159d490cdac1f6322dff424b1d9 (diff)
[SPARK-4133] [SQL] [PySpark] type conversion for Python UDF

Python UDFs can now be called on columns of ArrayType, MapType, or primitive types, and the returnType can likewise be ArrayType, MapType, or a primitive type. A StructType argument is passed to the UDF as a tuple (without attributes); if the returnType is a StructType, the returned value should also be a tuple.

Author: Davies Liu <davies@databricks.com>

Closes #2973 from davies/udf_array and squashes the following commits:

306956e [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
2c00e43 [Davies Liu] fix merge
11395fa [Davies Liu] Merge branch 'master' of github.com:apache/spark into udf_array
9df50a2 [Davies Liu] address comments
79afb4e [Davies Liu] type conversion for python udf
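As a rough illustration of the return-value handling this patch centralizes, the standalone Scala sketch below mirrors a few of the `fromJava` cases added in `pythonUdfs.scala` (list conversion and numeric narrowing). The `FromJavaSketch` and `SimpleType` names are invented for the example; this is not the Spark code itself.

    import scala.collection.JavaConversions._

    // Simplified mirror of EvaluatePython.fromJava: values unpickled from Python
    // are converted into the JVM types that the declared returnType expects.
    object FromJavaSketch {
      sealed trait SimpleType
      case object IntType extends SimpleType
      case object ByteT   extends SimpleType
      case class ArrayT(element: SimpleType) extends SimpleType

      def fromJava(obj: Any, dt: SimpleType): Any = (obj, dt) match {
        case (null, _) => null
        // Python lists arrive as java.util.List; convert each element recursively.
        case (c: java.util.List[_], ArrayT(elem)) => c.map(e => fromJava(e, elem)): Seq[Any]
        // Python ints may unpickle as Int or Long; narrow to the declared type.
        case (c: Int, ByteT)    => c.toByte
        case (c: Long, IntType) => c.toInt
        case (c, _) => c
      }

      def main(args: Array[String]): Unit = {
        val fromPython = java.util.Arrays.asList(1, 2, 3)
        println(fromJava(fromPython, ArrayT(IntType)))  // a Scala Seq: Buffer(1, 2, 3)
      }
    }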
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala           | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala            | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 91
3 files changed, 89 insertions(+), 87 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ca8706ee68..a41a500c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -438,7 +438,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schema: StructType): SchemaRDD = {
- import scala.collection.JavaConversions._
def needsConversion(dataType: DataType): Boolean = dataType match {
case ByteType => true
@@ -452,49 +451,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
case other => false
}
- // Converts value to the type specified by the data type.
- // Because Python does not have data types for DateType, TimestampType, FloatType, ShortType,
- // and ByteType, we need to explicitly convert values in columns of these data types to the
- // desired JVM data types.
- def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- // TODO: We should check nullable
- case (null, _) => null
-
- case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => convert(e, elementType)}: Seq[Any]
-
- case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any]
-
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
- case (key, value) => (convert(key, keyType), convert(value, valueType))
- }.toMap
-
- case (c, StructType(fields)) if c.getClass.isArray =>
- new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
- case (e, f) => convert(e, f.dataType)
- }): Row
-
- case (c: java.util.Calendar, DateType) =>
- new java.sql.Date(c.getTime().getTime())
-
- case (c: java.util.Calendar, TimestampType) =>
- new java.sql.Timestamp(c.getTime().getTime())
-
- case (c: Int, ByteType) => c.toByte
- case (c: Long, ByteType) => c.toByte
- case (c: Int, ShortType) => c.toShort
- case (c: Long, ShortType) => c.toShort
- case (c: Long, IntegerType) => c.toInt
- case (c: Double, FloatType) => c.toFloat
- case (c, StringType) if !c.isInstanceOf[String] => c.toString
-
- case (c, _) => c
- }
-
val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
rdd.map(m => m.zip(schema.fields).map {
- case (value, field) => convert(value, field.dataType)
+ case (value, field) => EvaluatePython.fromJava(value, field.dataType)
})
} else {
rdd
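Note on the hunk above: the inlined `convert` helper is removed and `applySchemaToPythonRDD` now delegates per-field conversion to `EvaluatePython.fromJava`, keeping the zip-with-schema pattern. A minimal sketch of that pattern is shown below; the field names and the `toByte` narrowing are invented for illustration.

    // Sketch of the per-row conversion pattern: each value is paired with its
    // field's declared type and converted individually. Purely illustrative.
    object ZipConvertSketch {
      case class Field(name: String, needsConversion: Boolean)

      def main(args: Array[String]): Unit = {
        val fields = Seq(Field("age", needsConversion = true), Field("name", needsConversion = false))
        val row: Array[Any] = Array(42, "Alice")

        val converted = row.zip(fields).map {
          case (v: Int, Field(_, true)) => v.toByte  // stand-in for fromJava(value, field.dataType)
          case (v, _)                   => v
        }
        println(converted.toSeq)  // the first value is now a Byte
      }
    }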
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 948122d42f..8b96df1096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.api.java.JavaRDD
/**
@@ -378,46 +378,14 @@ class SchemaRDD(
def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
/**
- * Helper for converting a Row to a simple Array suitable for pyspark serialization.
- */
- private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
- import scala.collection.Map
-
- def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- case (null, _) => null
-
- case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
-
- case (seq: Seq[Any], array: ArrayType) =>
- seq.map(x => toJava(x, array.elementType)).asJava
- case (list: JList[_], array: ArrayType) =>
- list.map(x => toJava(x, array.elementType)).asJava
- case (arr, array: ArrayType) if arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
-
- case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
- }.asJava
-
- // Pyrolite can handle Timestamp
- case (other, _) => other
- }
-
- val fields = structType.fields.map(field => field.dataType)
- row.zip(fields).map {
- case (obj, dataType) => toJava(obj, dataType)
- }.toArray
- }
-
- /**
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+ val fieldTypes = schema.fields.map(_.dataType)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
- rowToJArray(row, rowSchema)
+ EvaluatePython.rowToArray(row, fieldTypes)
}.grouped(100).map(batched => pickle.dumps(batched.toArray))
}
}
@@ -427,10 +395,10 @@ class SchemaRDD(
* format as javaToPython. It is used by pyspark.
*/
private[sql] def collectToPython: JList[Array[Byte]] = {
- val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+ val fieldTypes = schema.fields.map(_.dataType)
val pickle = new Pickler
new java.util.ArrayList(collect().map { row =>
- rowToJArray(row, rowSchema)
+ EvaluatePython.rowToArray(row, fieldTypes)
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}
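The two hunks above share the same serialization path: each row is first flattened to an `Array[Any]` with `EvaluatePython.rowToArray`, then batched in groups of 100 and pickled with Pyrolite. A minimal sketch of that batching step follows; the row contents are made up.

    import net.razorvine.pickle.Pickler

    // Sketch of the batch-and-pickle step used by javaToPython / collectToPython:
    // already-converted rows are grouped 100 at a time and each group is pickled
    // into a single byte array for the Python side to unpickle.
    object PickleBatchSketch {
      def main(args: Array[String]): Unit = {
        val rows: Iterator[Array[Any]] = Iterator.tabulate(250)(i => Array[Any](i, s"row$i"))
        val pickle = new Pickler
        val batches = rows.grouped(100).map(batch => pickle.dumps(batch.toArray))
        println(batches.size)  // 3 batches: 100 + 100 + 50 rows
      }
    }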
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index be729e5d24..a1961bba18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.execution
import java.util.{List => JList, Map => JMap}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -31,8 +34,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.{Accumulator, Logging => SparkLogging}
-import scala.collection.JavaConversions._
-
/**
* A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
*/
@@ -108,6 +109,80 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
object EvaluatePython {
def apply(udf: PythonUDF, child: LogicalPlan) =
new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+
+ /**
+ * Helper for converting a Scala object to a java suitable for pyspark serialization.
+ */
+ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ case (null, _) => null
+
+ case (row: Row, struct: StructType) =>
+ val fields = struct.fields.map(field => field.dataType)
+ row.zip(fields).map {
+ case (obj, dataType) => toJava(obj, dataType)
+ }.toArray
+
+ case (seq: Seq[Any], array: ArrayType) =>
+ seq.map(x => toJava(x, array.elementType)).asJava
+ case (list: JList[_], array: ArrayType) =>
+ list.map(x => toJava(x, array.elementType)).asJava
+ case (arr, array: ArrayType) if arr.getClass.isArray =>
+ arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+ case (obj: Map[_, _], mt: MapType) => obj.map {
+ case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+ }.asJava
+
+ // Pyrolite can handle Timestamp
+ case (other, _) => other
+ }
+
+ /**
+ * Convert Row into Java Array (for pickled into Python)
+ */
+ def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+ row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+ }
+
+ // Converts value to the type specified by the data type.
+ // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+ // ByteType, we need to explicitly convert values in columns of these data types to the desired
+ // JVM data types.
+ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+ // TODO: We should check nullable
+ case (null, _) => null
+
+ case (c: java.util.List[_], ArrayType(elementType, _)) =>
+ c.map { e => fromJava(e, elementType)}: Seq[Any]
+
+ case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+ c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
+ case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
+ }.toMap
+
+ case (c, StructType(fields)) if c.getClass.isArray =>
+ new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
+ case (e, f) => fromJava(e, f.dataType)
+ }): Row
+
+ case (c: java.util.Calendar, DateType) =>
+ new java.sql.Date(c.getTime().getTime())
+
+ case (c: java.util.Calendar, TimestampType) =>
+ new java.sql.Timestamp(c.getTime().getTime())
+
+ case (c: Int, ByteType) => c.toByte
+ case (c: Long, ByteType) => c.toByte
+ case (c: Int, ShortType) => c.toShort
+ case (c: Long, ShortType) => c.toShort
+ case (c: Long, IntegerType) => c.toInt
+ case (c: Double, FloatType) => c.toFloat
+ case (c, StringType) if !c.isInstanceOf[String] => c.toString
+
+ case (c, _) => c
+ }
}
/**
@@ -141,8 +216,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val parent = childResults.mapPartitions { iter =>
val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
+ val fields = udf.children.map(_.dataType)
iter.grouped(1000).map { inputRows =>
- val toBePickled = inputRows.map(currentRow(_).toArray).toArray
+ val toBePickled = inputRows.map { row =>
+ EvaluatePython.rowToArray(currentRow(row), fields)
+ }.toArray
pickle.dumps(toBePickled)
}
}
@@ -165,10 +243,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
}.mapPartitions { iter =>
val row = new GenericMutableRow(1)
iter.map { result =>
- row(0) = udf.dataType match {
- case StringType => result.toString
- case other => result
- }
+ row(0) = EvaluatePython.fromJava(result, udf.dataType)
row: Row
}
}
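For the outbound direction, `BatchPythonEvaluation` now runs each projected input row through `EvaluatePython.rowToArray`, which relies on `toJava` to turn Scala collections into `java.util` collections that Pyrolite can pickle, and the UDF result comes back through `fromJava` before being stored in the one-column output row. Below is a rough standalone mirror of the `toJava` idea; unlike the real helper, it dispatches on runtime type only and ignores the Catalyst DataType.

    import scala.collection.JavaConverters._

    // Simplified mirror of EvaluatePython.toJava: Scala Seqs and Maps become
    // java.util collections so Pyrolite can pickle them for the Python worker.
    object ToJavaSketch {
      def toJava(obj: Any): Any = obj match {
        case null           => null
        case seq: Seq[_]    => seq.map(toJava).asJava                           // ArrayType values
        case map: Map[_, _] => map.map { case (k, v) => (k, toJava(v)) }.asJava // MapType values; keys stay primitive
        case other          => other                                            // primitives, Strings, Timestamps pass through
      }

      def main(args: Array[String]): Unit = {
        println(toJava(Seq(1, 2, 3)))                        // [1, 2, 3]
        println(toJava(Map("a" -> Seq(1), "b" -> Seq(2))))   // {a=[1], b=[2]}
      }
    }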