author    Wenchen Fan <cloud0fan@outlook.com>  2015-07-30 10:04:30 -0700
committer Reynold Xin <rxin@databricks.com>    2015-07-30 10:04:30 -0700
commit    c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 (patch)
tree      582bad5631cde3bac3b5c69e1f22b3c4098de684 /sql/core
parent    7492a33fdd074446c30c657d771a69932a00246d (diff)
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7724 from cloud-fan/array-data and squashes the following commits:

d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
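
For orientation, a minimal sketch (not part of the commit) of how the new wrapper is consumed, assuming only the ArrayData/GenericArrayData calls that appear in the hunks below (numElements, isNullAt, get, toArray) and the org.apache.spark.sql.types package per the imports in this diff:

import org.apache.spark.sql.types._

// Sketch only: exercises the wrapper the same way the converters in this diff do.
object ArrayDataSketch {
  def main(args: Array[String]): Unit = {
    // Producers wrap a plain Array[Any] instead of returning a Seq.
    val data: ArrayData = new GenericArrayData(Array[Any](1, null, 3))

    // Consumers iterate by index and check nulls explicitly.
    var i = 0
    while (i < data.numElements()) {
      if (data.isNullAt(i)) println(s"$i -> null") else println(s"$i -> ${data.get(i)}")
      i += 1
    }

    // Bulk conversion back to an Array, as the debug and JSON paths below do.
    assert(data.toArray().length == 3)
  }
}

The pattern is the same throughout the diff: producers construct GenericArrayData from an Array, and consumers switch from Seq operations (size, apply, map) to the indexed accessors above.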
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala       |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala          | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala  |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala              |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala            | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala  |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala      |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala   | 12
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java         |  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala          |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala        | 30
11 files changed, 69 insertions, 50 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index aeeb0e4527..f26f41fb75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -158,8 +158,8 @@ package object debug {
case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
- case (s: Seq[_], ArrayType(elemType, _)) =>
- s.foreach(typeCheck(_, elemType))
+ case (a: ArrayData, ArrayType(elemType, _)) =>
+ a.toArray().foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
m.keys.foreach(typeCheck(_, keyType))
m.values.foreach(typeCheck(_, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 3c38916fd7..ef1c6e57dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -134,8 +134,19 @@ object EvaluatePython {
}
new GenericInternalRowWithSchema(values, struct)
- case (seq: Seq[Any], array: ArrayType) =>
- seq.map(x => toJava(x, array.elementType)).asJava
+ case (a: ArrayData, array: ArrayType) =>
+ val length = a.numElements()
+ val values = new java.util.ArrayList[Any](length)
+ var i = 0
+ while (i < length) {
+ if (a.isNullAt(i)) {
+ values.add(null)
+ } else {
+ values.add(toJava(a.get(i), array.elementType))
+ }
+ i += 1
+ }
+ values
case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
@@ -190,10 +201,10 @@ object EvaluatePython {
case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => fromJava(e, elementType)}.toSeq
+ new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray)
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
+ new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 78da2840da..9329148aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame}
private[sql] object FrequentItems extends Logging {
@@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
- val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
+ val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
val resultRow = InternalRow(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 0eb3b04007..04ab5e2217 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -125,7 +125,7 @@ private[sql] object InferSchema {
* Convert NullType to StringType and remove StructTypes with no fields
*/
private def canonicalizeType: DataType => Option[DataType] = {
- case at@ArrayType(elementType, _) =>
+ case at @ ArrayType(elementType, _) =>
for {
canonicalType <- canonicalizeType(elementType)
} yield {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 381e7ed544..1c309f8794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -110,8 +110,13 @@ private[sql] object JacksonParser {
case (START_OBJECT, st: StructType) =>
convertObject(factory, parser, st)
+ case (START_ARRAY, st: StructType) =>
+ // SPARK-3308: support reading top level JSON arrays and take every element
+ // in such an array as a row
+ convertArray(factory, parser, st)
+
case (START_ARRAY, ArrayType(st, _)) =>
- convertList(factory, parser, st)
+ convertArray(factory, parser, st)
case (START_OBJECT, ArrayType(st, _)) =>
// the business end of SPARK-3308:
@@ -165,16 +170,16 @@ private[sql] object JacksonParser {
builder.result()
}
- private def convertList(
+ private def convertArray(
factory: JsonFactory,
parser: JsonParser,
- schema: DataType): Seq[Any] = {
- val builder = Seq.newBuilder[Any]
+ elementType: DataType): ArrayData = {
+ val values = scala.collection.mutable.ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
- builder += convertField(factory, parser, schema)
+ values += convertField(factory, parser, elementType)
}
- builder.result()
+ new GenericArrayData(values.toArray)
}
private def parseJson(
@@ -201,12 +206,15 @@ private[sql] object JacksonParser {
val parser = factory.createParser(record)
parser.nextToken()
- // to support both object and arrays (see SPARK-3308) we'll start
- // by converting the StructType schema to an ArrayType and let
- // convertField wrap an object into a single value array when necessary.
- convertField(factory, parser, ArrayType(schema)) match {
+ convertField(factory, parser, schema) match {
case null => failedRecord(record)
- case list: Seq[InternalRow @unchecked] => list
+ case row: InternalRow => row :: Nil
+ case array: ArrayData =>
+ if (array.numElements() == 0) {
+ Nil
+ } else {
+ array.toArray().map(_.asInstanceOf[InternalRow])
+ }
case _ =>
sys.error(
s"Failed to parse record $record. Please make sure that each line of the file " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index e00bd90edb..172db8362a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = elementConverter
- override def end(): Unit = updater.set(currentArray)
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
// NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the
// next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index ea51650fe9..2332a36468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.parquet
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.ArrayData
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -32,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = Seq[T]
+ type ArrayScalaType[T] = ArrayData
type StructScalaType[T] = InternalRow
type MapScalaType[K, V] = Map[K, V]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 78ecfad1d5..79dd16b7b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
array: CatalystConverter.ArrayScalaType[_]): Unit = {
val elementType = schema.elementType
writer.startGroup()
- if (array.size > 0) {
+ if (array.numElements() > 0) {
if (schema.containsNull) {
writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
+ while (i < array.numElements()) {
writer.startGroup()
- if (array(i) != null) {
+ if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array(i))
+ writeValue(elementType, array.get(i))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
} else {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
- while (i < array.size) {
- writeValue(elementType, array(i))
+ while (i < array.numElements()) {
+ writeValue(elementType, array.get(i))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 72c42f4fe3..9e61d06f40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -30,7 +30,6 @@ import org.junit.*;
import scala.collection.JavaConversions;
import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
import java.io.Serializable;
import java.util.Arrays;
@@ -168,10 +167,10 @@ public class JavaDataFrameSuite {
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
- Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+ Seq<Integer> outputBuffer = (Seq<Integer>) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
- Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer)));
+ Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer)));
Seq<String> d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 45c9f06941..77ed4a9c0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
- override def serialize(obj: Any): Seq[Double] = {
+ override def serialize(obj: Any): ArrayData = {
obj match {
case features: MyDenseVector =>
- features.data.toSeq
+ new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}
}
override def deserialize(datum: Any): MyDenseVector = {
datum match {
- case data: Seq[_] =>
- new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+ case data: ArrayData =>
+ new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 5e189c3563..cfb03ff485 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -67,12 +67,12 @@ case class AllDataTypesScan(
override def schema: StructType = userSpecifiedSchema
- override def needConversion: Boolean = false
+ override def needConversion: Boolean = true
override def buildScan(): RDD[Row] = {
sqlContext.sparkContext.parallelize(from to to).map { i =>
- InternalRow(
- UTF8String.fromString(s"str_$i"),
+ Row(
+ s"str_$i",
s"str_$i".getBytes(),
i % 2 == 0,
i.toByte,
@@ -81,19 +81,19 @@ case class AllDataTypesScan(
i.toLong,
i.toFloat,
i.toDouble,
- Decimal(new java.math.BigDecimal(i)),
- Decimal(new java.math.BigDecimal(i)),
- DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)),
- DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
- UTF8String.fromString(s"varchar_$i"),
+ new java.math.BigDecimal(i),
+ new java.math.BigDecimal(i),
+ new Date(1970, 1, 1),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
Seq(i, i + 1),
- Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
- Map(i -> UTF8String.fromString(i.toString)),
- Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
- InternalRow(i, UTF8String.fromString(i.toString)),
- InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
- InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
- }.asInstanceOf[RDD[Row]]
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"),
+ Row(Seq(new Date(1970, 1, i + 1)))))
+ }
}
}