| author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-08-02 10:08:18 -0700 |
|---|---|---|
| committer | Davies Liu <davies.liu@gmail.com> | 2016-08-02 10:08:18 -0700 |
| commit | 146001a9ffefc7aaedd3d888d68c7a9b80bca545 | |
| tree | 194b8ce9975f93ea574541c7b643edd05fe70521 /sql/catalyst/src/main/scala | |
| parent | 1dab63d8d3c59a3d6b4ee8e777810c44849e58b8 | |
[SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs
## What changes were proposed in this pull request?
There are two related bugs involving Python-only UDTs. Because the test case for the second one needs the first fix too, I put them into one PR. If that is not appropriate, please let me know.
### First bug: When MapObjects works on Python-only UDTs
`RowEncoder` uses `PythonUserDefinedType.sqlType` for its deserializer expression. If the SQL type is `ArrayType`, `MapObjects` ends up operating on it. But `MapObjects` doesn't handle `PythonUserDefinedType` as an input data type, which causes an error like the following (a stand-in sketch of the failure mode follows the trace):
```python
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *

schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = spark.createDataFrame([(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema)
df.show()
```
File "/home/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o36.showString.
: java.lang.RuntimeException: Error while decoding: scala.MatchError: org.apache.spark.sql.types.PythonUserDefinedTypef4ceede8 (of class org.apache.spark.sql.types.PythonUserDefinedType)
...
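For intuition, here is a self-contained sketch of the failure mode. The `DataType` hierarchy and `loopVarFor` below are hypothetical stand-ins, not Spark's real classes:

```scala
// Stand-in model of the first bug (hypothetical names, NOT Spark internals).
// A Python-only UDT physically stores its data as its sqlType, so code that pattern
// matches on the input data type must look through the wrapper; a match without such
// a case fails at runtime with scala.MatchError, as in the trace above.
sealed trait DataType
case object DoubleType extends DataType
case class ArrayType(elementType: DataType) extends DataType
case class PythonUserDefinedType(sqlType: DataType) extends DataType

object BugSketch {
  // Mimics MapObjects choosing how to read elements, based on the input's data type.
  def loopVarFor(dt: DataType): String = dt match {
    case ArrayType(_) => "read the element at the loop index"
    case DoubleType   => "read a double"
    // No case for PythonUserDefinedType: passing one in throws scala.MatchError.
    // The fix below unwraps the wrapper to its sqlType before matching.
  }

  def main(args: Array[String]): Unit = {
    println(loopVarFor(ArrayType(DoubleType)))                         // fine
    println(loopVarFor(PythonUserDefinedType(ArrayType(DoubleType)))) // scala.MatchError
  }
}
```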
### Second bug: When a Python-only UDT is the element type of ArrayType
Running the snippet below also fails while decoding: the deserializer built for the array's elements sees the Python-only UDT as their data type, so it needs the same look-through to `sqlType`, which the `RowEncoder` change below adds as a recursive case (a sketch of that shape follows the snippet).
```python
import pyspark.sql.group
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import *

schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
df = spark.createDataFrame([(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], schema=schema)
df.show()
```
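A minimal sketch of the shape of the fix, again with hypothetical stand-in types rather than Spark internals (it mirrors the `deserializerFor` recursion added in the diff below):

```scala
// Same stand-in types as the previous sketch; hypothetical names, NOT Spark internals.
sealed trait DataType
case object DoubleType extends DataType
case class ArrayType(elementType: DataType) extends DataType
case class PythonUserDefinedType(sqlType: DataType) extends DataType

object FixSketch {
  // Dispatch on an explicit DataType argument; the Python-only UDT case recurses on
  // sqlType, so a UDT at the top level or as an ArrayType element is unwrapped first.
  def deserializerFor(dt: DataType): String = dt match {
    case PythonUserDefinedType(sqlType) => deserializerFor(sqlType)
    case ArrayType(et)                  => s"map over elements: [${deserializerFor(et)}]"
    case DoubleType                     => "read a double"
  }

  def main(args: Array[String]): Unit = {
    // The second repro's shape: an array whose element type is a Python-only UDT.
    println(deserializerFor(ArrayType(PythonUserDefinedType(DoubleType))))
    // => "map over elements: [read a double]"
  }
}
```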
## How was this patch tested?
PySpark's sql tests.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #13778 from viirya/fix-pyudt.
Diffstat (limited to 'sql/catalyst/src/main/scala')
| -rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 9 |
|---|---|---|
| -rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 17 |

2 files changed, 23 insertions, 3 deletions
```diff
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67fca153b5..2a6fcd03a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -206,6 +206,7 @@ object RowEncoder {
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
+    case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }
@@ -220,9 +221,15 @@ object RowEncoder {
     CreateExternalRow(fields, schema)
   }

-  private def deserializerFor(input: Expression): Expression = input.dataType match {
+  private def deserializerFor(input: Expression): Expression = {
+    deserializerFor(input, input.dataType)
+  }
+
+  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
     case dt if ScalaReflection.isNativeType(dt) => input

+    case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
+
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
       val udtClass: Class[_] = if (annotation != null) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 06589411cf..952a5f3b04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -359,6 +359,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
 object MapObjects {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()

+  /**
+   * Construct an instance of MapObjects case class.
+   *
+   * @param function The function applied on the collection elements.
+   * @param inputData An expression that when evaluated returns a collection object.
+   * @param elementType The data type of elements in the collection.
+   */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
@@ -446,8 +453,14 @@ case class MapObjects private(
       case _ => ""
     }

+    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+    // When we want to apply MapObjects on it, we have to use it.
+    val inputDataType = inputData.dataType match {
+      case p: PythonUserDefinedType => p.sqlType
+      case _ => inputData.dataType
+    }
+
-    val (getLength, getLoopVar) = inputData.dataType match {
+    val (getLength, getLoopVar) = inputDataType match {
       case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
         s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
       case ObjectType(cls) if cls.isArray =>
@@ -461,7 +474,7 @@ case class MapObjects private(
         s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
     }

-    val loopNullCheck = inputData.dataType match {
+    val loopNullCheck = inputDataType match {
       case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
```