commit     c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08
author     Wenchen Fan <cloud0fan@outlook.com>    2015-07-30 10:04:30 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-07-30 10:04:30 -0700
tree       582bad5631cde3bac3b5c69e1f22b3c4098de684 /sql/hive
parent     7492a33fdd074446c30c657d771a69932a00246d
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7724 from cloud-fan/array-data and squashes the following commits:
d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
Diffstat (limited to 'sql/hive')
4 files changed, 28 insertions, 16 deletions
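
In miniature, the change this diff makes everywhere: code that used to pass arrays around as scala.collection.immutable.Seq now wraps them in the new catalyst array wrapper. The sketch below is illustrative only; ArrayDataSketch and GenericArrayDataSketch are hypothetical stand-ins for the real org.apache.spark.sql.types.ArrayData / GenericArrayData, and they model just the operations the hunks below rely on: construction from an Array, positional get, and toArray().

trait ArrayDataSketch {
  // positional access, as in outputBuffer.asInstanceOf[ArrayData].get(index)
  def get(ordinal: Int): Any
  // bulk access, as in o.asInstanceOf[ArrayData].toArray().map(wrapper)
  def toArray(): Array[Any]
}

// hypothetical stand-in for GenericArrayData: a thin wrapper over Array[Any]
class GenericArrayDataSketch(values: Array[Any]) extends ArrayDataSketch {
  def get(ordinal: Int): Any = values(ordinal)
  def toArray(): Array[Any] = values
}

object ArrayDataSketchDemo extends App {
  // producers (e.g. the unwrap paths in HiveInspectors) build the wrapper
  // from the raw unwrapped values
  val unwrapped: Array[Any] = Array("a", "b", "c")
  val data: ArrayDataSketch = new GenericArrayDataSketch(unwrapped)
  // consumers read it back positionally or as a whole
  assert(data.get(1) == "b")
  assert(data.toArray().sameElements(unwrapped))
}
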
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index f467500259..5926ef9aa3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -52,9 +52,8 @@ import scala.collection.JavaConversions._
  *     java.sql.Timestamp
  *   Complex Types =>
  *     Map: scala.collection.immutable.Map
- *     List: scala.collection.immutable.Seq
- *     Struct:
- *       [[org.apache.spark.sql.catalyst.InternalRow]]
+ *     List: [[org.apache.spark.sql.types.ArrayData]]
+ *     Struct: [[org.apache.spark.sql.catalyst.InternalRow]]
  *     Union: NOT SUPPORTED YET
  *   The Complex types plays as a container, which can hold arbitrary data types.
  *
@@ -297,7 +296,10 @@ private[hive] trait HiveInspectors {
       }.toMap
     case li: StandardConstantListObjectInspector =>
       // take the value from the list inspector object, rather than the input data
-      li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
+      val values = li.getWritableConstantValue
+        .map(unwrap(_, li.getListElementObjectInspector))
+        .toArray
+      new GenericArrayData(values)
     // if the value is null, we don't care about the object inspector type
     case _ if data == null => null
     case poi: VoidObjectInspector => null // always be null for void object inspector
@@ -339,7 +341,10 @@ private[hive] trait HiveInspectors {
       }
     case li: ListObjectInspector =>
       Option(li.getList(data))
-        .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
+        .map { l =>
+          val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray
+          new GenericArrayData(values)
+        }
         .orNull
     case mi: MapObjectInspector =>
       Option(mi.getMap(data)).map(
@@ -391,7 +396,13 @@ private[hive] trait HiveInspectors {

     case loi: ListObjectInspector =>
       val wrapper = wrapperFor(loi.getListElementObjectInspector)
-      (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null
+      (o: Any) => {
+        if (o != null) {
+          seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper))
+        } else {
+          null
+        }
+      }

     case moi: MapObjectInspector =>
       // The Predef.Map is scala.collection.immutable.Map.
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors {
     case x: ListObjectInspector =>
       val list = new java.util.ArrayList[Object]
       val tpe = dataType.asInstanceOf[ArrayType].elementType
-      a.asInstanceOf[Seq[_]].foreach {
+      a.asInstanceOf[ArrayData].toArray().foreach {
         v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
       }
       list
@@ -634,7 +645,8 @@ private[hive] trait HiveInspectors {
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
       } else {
         val list = new java.util.ArrayList[Object]()
-        value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
+        value.asInstanceOf[ArrayData].toArray()
+          .foreach(v => list.add(wrap(v, listObjectInspector, dt)))
         ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
       }
     case Literal(value, MapType(keyType, valueType, _)) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 741c705e2a..7e3342cc84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -176,13 +176,13 @@ case class ScriptTransformation(
           val prevLine = curLine
           curLine = reader.readLine()
           if (!ioschema.schemaLess) {
-            new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
-              prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")))
-              .asInstanceOf[Array[Any]])
+            new GenericInternalRow(
+              prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+                .map(CatalystTypeConverters.convertToCatalyst))
           } else {
-            new GenericInternalRow(CatalystTypeConverters.convertToCatalyst(
-              prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2))
-              .asInstanceOf[Array[Any]])
+            new GenericInternalRow(
+              prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+                .map(CatalystTypeConverters.convertToCatalyst))
           }
         } else {
           val ret = deserialize()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8732e9abf8..4a13022edd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction(
       // if pivotResult is true, we will get a Seq having the same size with the size
       // of the window frame. At here, we will return the result at the position of
       // index in the output buffer.
-      outputBuffer.asInstanceOf[Seq[Any]].get(index)
+      outputBuffer.asInstanceOf[ArrayData].get(index)
     }
   }

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 0330013f53..f719f2e06a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {

   test("wrap / unwrap Array Type") {
     val dt = ArrayType(dataTypes(0))
-    val d = row(0) :: row(0) :: Nil
+    val d = new GenericArrayData(Array(row(0), row(0)))
     checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
     checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
     checkValue(d,
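
All four files converge on the same pattern: producers (the unwrap paths in HiveInspectors, the test fixture above) construct a GenericArrayData from a plain Array, while consumers (wrap, wrapperFor, HiveWindowFunction) cast to ArrayData and read through get or toArray() instead of pattern-matching on Seq.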