author    Wenchen Fan <cloud0fan@outlook.com>    2015-08-01 00:17:15 -0700
committer    Reynold Xin <rxin@databricks.com>    2015-08-01 00:17:15 -0700
commit    1d59a4162bf5142af270ed7f4b3eab42870c87b7 (patch)
tree    98e1c51aafb41c2c64042b30d3ddcf2205da1414 /sql/core
parent    d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302 (diff)
[SPARK-9480][SQL] add MapData and cleanup internal row stuff
This PR adds `MapData` as the internal representation of the map type in Spark SQL, and provides a default implementation backed by just two `ArrayData`. After that, we have specialized getters for all internal types, so I removed the generic getter from `ArrayData` and added a specialized `toArray` for it. Also did some refactoring and cleanup of `InternalRow` and its subclasses.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7799 from cloud-fan/map-data and squashes the following commits:

77d482f [Wenchen Fan] fix python
e8f6682 [Wenchen Fan] skip MapData equality check in HiveInspectorSuite
40cc9db [Wenchen Fan] add toString
6e06ec9 [Wenchen Fan] some more cleanup
a90aca1 [Wenchen Fan] add MapData
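For orientation, a minimal usage sketch of the surface added here, assuming the class names and package locations exactly as they appear in this commit (they moved in later Spark versions):

    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    // A map value in the internal representation: two parallel ArrayData,
    // one holding the keys and one holding the values.
    val m: MapData = ArrayBasedMapData(
      Array[Any](UTF8String.fromString("a"), UTF8String.fromString("b")),
      Array[Any](1, 2))

    val keys: ArrayData = m.keyArray()     // ArrayData of keys
    val values: ArrayData = m.valueArray() // ArrayData of values

    // Type-aware traversal; the callback receives one (key, value) pair per entry.
    m.foreach(StringType, IntegerType, (k, v) => println(s"$k -> $v"))

Because `ArrayData` now exposes only type-aware accessors, the call sites below stop using the removed generic getters and instead iterate with `foreach` or convert with the typed `toArray[T](elementType)` / `toDoubleArray()` forms.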
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala  31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala  23
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala  2
11 files changed, 78 insertions, 80 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f26f41fb75..c37007f1ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -159,10 +159,16 @@ package object debug {
case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (a: ArrayData, ArrayType(elemType, _)) =>
- a.toArray().foreach(typeCheck(_, elemType))
- case (m: Map[_, _], MapType(keyType, valueType, _)) =>
- m.keys.foreach(typeCheck(_, keyType))
- m.values.foreach(typeCheck(_, valueType))
+ a.foreach(elemType, (_, e) => {
+ typeCheck(e, elemType)
+ })
+ case (m: MapData, MapType(keyType, valueType, _)) =>
+ m.keyArray().foreach(keyType, (_, e) => {
+ typeCheck(e, keyType)
+ })
+ m.valueArray().foreach(valueType, (_, e) => {
+ typeCheck(e, valueType)
+ })
case (_: Long, LongType) =>
case (_: Int, IntegerType) =>
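For reference, a hedged sketch of the (index, element) callback shape relied on in the hunk above, assuming `GenericArrayData` wraps a plain `Array[Any]` as it does elsewhere in this diff:

    import org.apache.spark.sql.types._

    val arr: ArrayData = new GenericArrayData(Array[Any](1, 2, 3))
    arr.foreach(IntegerType, (i, e) => {
      // i is the ordinal position, e is the element read with the given element type
      println(s"element $i = $e")
    })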
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ef1c6e57dc..aade2e769c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -135,22 +135,18 @@ object EvaluatePython {
new GenericInternalRowWithSchema(values, struct)
case (a: ArrayData, array: ArrayType) =>
- val length = a.numElements()
- val values = new java.util.ArrayList[Any](length)
- var i = 0
- while (i < length) {
- if (a.isNullAt(i)) {
- values.add(null)
- } else {
- values.add(toJava(a.get(i), array.elementType))
- }
- i += 1
- }
+ val values = new java.util.ArrayList[Any](a.numElements())
+ a.foreach(array.elementType, (_, e) => {
+ values.add(toJava(e, array.elementType))
+ })
values
- case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
- }.asJava
+ case (map: MapData, mt: MapType) =>
+ val jmap = new java.util.HashMap[Any, Any](map.numElements())
+ map.foreach(mt.keyType, mt.valueType, (k, v) => {
+ jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
+ })
+ jmap
case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
@@ -206,9 +202,10 @@ object EvaluatePython {
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
- case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
- }.toMap
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+ val keys = c.keysIterator.map(fromJava(_, keyType)).toArray
+ val values = c.valuesIterator.map(fromJava(_, valueType)).toArray
+ ArrayBasedMapData(keys, values)
case (c, StructType(fields)) if c.getClass.isArray =>
new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
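A hedged sketch of the two map conversions above in isolation, assuming only the `MapData`, `MapType`, and `ArrayBasedMapData` APIs visible in this diff:

    import org.apache.spark.sql.types._

    // Internal map -> Java map, driven by the declared key/value types.
    def toJavaMap(map: MapData, mt: MapType): java.util.Map[Any, Any] = {
      val jmap = new java.util.HashMap[Any, Any](map.numElements())
      map.foreach(mt.keyType, mt.valueType, (k, v) => {
        jmap.put(k, v)
      })
      jmap
    }

    // Java-side keys and values, once converted to internal values, become a
    // MapData by wrapping two parallel arrays.
    val rebuilt: MapData = ArrayBasedMapData(Array[Any](1, 2), Array[Any](10, 20))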
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 1c309f8794..bf0448ee96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.json
import java.io.ByteArrayOutputStream
-import scala.collection.Map
-
import com.fasterxml.jackson.core._
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -31,7 +31,6 @@ import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-
private[sql] object JacksonParser {
def apply(
json: RDD[String],
@@ -160,21 +159,21 @@ private[sql] object JacksonParser {
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
- valueType: DataType): Map[UTF8String, Any] = {
- val builder = Map.newBuilder[UTF8String, Any]
+ valueType: DataType): MapData = {
+ val keys = ArrayBuffer.empty[UTF8String]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
- builder +=
- UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType)
+ keys += UTF8String.fromString(parser.getCurrentName)
+ values += convertField(factory, parser, valueType)
}
-
- builder.result()
+ ArrayBasedMapData(keys.toArray, values.toArray)
}
private def convertArray(
factory: JsonFactory,
parser: JsonParser,
elementType: DataType): ArrayData = {
- val values = scala.collection.mutable.ArrayBuffer.empty[Any]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
values += convertField(factory, parser, elementType)
}
@@ -213,7 +212,7 @@ private[sql] object JacksonParser {
if (array.numElements() == 0) {
Nil
} else {
- array.toArray().map(_.asInstanceOf[InternalRow])
+ array.toArray[InternalRow](schema)
}
case _ =>
sys.error(
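A hedged sketch of the specialized `toArray` call used above, assuming its signature takes the element `DataType` plus a type parameter for the element class, as this diff suggests:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types._

    def rowsOf(array: ArrayData, elementSchema: StructType): Seq[InternalRow] =
      if (array.numElements() == 0) Nil
      else array.toArray[InternalRow](elementSchema)   // typed, no per-element casting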
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index 172db8362a..6938b07106 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -385,7 +385,8 @@ private[parquet] class CatalystRowConverter(
updater: ParentContainerUpdater)
extends GroupConverter {
- private var currentMap: mutable.Map[Any, Any] = _
+ private var currentKeys: ArrayBuffer[Any] = _
+ private var currentValues: ArrayBuffer[Any] = _
private val keyValueConverter = {
val repeatedType = parquetType.getType(0).asGroupType()
@@ -398,12 +399,16 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = keyValueConverter
- override def end(): Unit = updater.set(currentMap)
+ override def end(): Unit =
+ updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray))
// NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next
// value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row
// cells.
- override def start(): Unit = currentMap = mutable.Map.empty[Any, Any]
+ override def start(): Unit = {
+ currentKeys = ArrayBuffer.empty[Any]
+ currentValues = ArrayBuffer.empty[Any]
+ }
/** Parquet converter for key-value pairs within the map. */
private final class KeyValueConverter(
@@ -430,7 +435,10 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
- override def end(): Unit = currentMap(currentKey) = currentValue
+ override def end(): Unit = {
+ currentKeys += currentKey
+ currentValues += currentValue
+ }
override def start(): Unit = {
currentKey = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2332a36468..6ed3580af0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.parquet
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.ArrayData
+import org.apache.spark.sql.types.{MapData, ArrayData}
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -33,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = ArrayData
- type StructScalaType[T] = InternalRow
- type MapScalaType[K, V] = Map[K, V]
+ type ArrayScalaType = ArrayData
+ type StructScalaType = InternalRow
+ type MapScalaType = MapData
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index ec8da38a3d..9cd0250f9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -88,13 +88,13 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case t @ ArrayType(_, _) => writeArray(
t,
- value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
+ value.asInstanceOf[CatalystConverter.ArrayScalaType])
case t @ MapType(_, _, _) => writeMap(
t,
- value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
+ value.asInstanceOf[CatalystConverter.MapScalaType])
case t @ StructType(_) => writeStruct(
t,
- value.asInstanceOf[CatalystConverter.StructScalaType[_]])
+ value.asInstanceOf[CatalystConverter.StructScalaType])
case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
}
}
@@ -124,7 +124,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeStruct(
schema: StructType,
- struct: CatalystConverter.StructScalaType[_]): Unit = {
+ struct: CatalystConverter.StructScalaType): Unit = {
if (struct != null) {
val fields = schema.fields.toArray
writer.startGroup()
@@ -143,7 +143,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeArray(
schema: ArrayType,
- array: CatalystConverter.ArrayScalaType[_]): Unit = {
+ array: CatalystConverter.ArrayScalaType): Unit = {
val elementType = schema.elementType
writer.startGroup()
if (array.numElements() > 0) {
@@ -154,7 +154,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startGroup()
if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -165,7 +165,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
while (i < array.numElements()) {
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
@@ -176,11 +176,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeMap(
schema: MapType,
- map: CatalystConverter.MapScalaType[_, _]): Unit = {
+ map: CatalystConverter.MapScalaType): Unit = {
writer.startGroup()
- if (map.size > 0) {
+ val length = map.numElements()
+ if (length > 0) {
writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
- for ((key, value) <- map) {
+ map.foreach(schema.keyType, schema.valueType, (key, value) => {
writer.startGroup()
writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
writeValue(schema.keyType, key)
@@ -191,7 +192,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
}
writer.endGroup()
- }
+ })
writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
}
writer.endGroup()
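For comparison, a hedged sketch of the equivalent index-based traversal over a `MapData`, using only accessors that appear elsewhere in this diff (`keyArray`, `valueArray`, `numElements`, `isNullAt`, `get(i, dataType)`):

    import org.apache.spark.sql.types._

    def dumpMap(map: MapData, keyType: DataType, valueType: DataType): Unit = {
      val keys = map.keyArray()
      val values = map.valueArray()
      var i = 0
      while (i < map.numElements()) {
        val k = keys.get(i, keyType)
        val v = if (values.isNullAt(i)) null else values.get(i, valueType)
        println(s"$k -> $v")
        i += 1
      }
    }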
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 97beae2f85..aef940a526 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -620,6 +620,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1)
+ assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1)
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 01b7c21e84..8a679c7865 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
-
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class RowSuite extends SparkFunSuite {
@@ -31,7 +31,7 @@ class RowSuite extends SparkFunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
expected.setInt(0, 2147483647)
- expected.setString(1, "this is a string")
+ expected.update(1, UTF8String.fromString("this is a string"))
expected.setBoolean(2, false)
expected.setNullAt(3)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 535011fe3d..51fe9d9d98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -581,42 +581,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("sorting") {
- val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
- sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") {
+ sortTest()
+ }
}
test("external sorting") {
- val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
- sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") {
+ sortTest()
+ }
}
test("SPARK-6927 sorting with codegen on") {
- val externalbefore = sqlContext.conf.externalSortEnabled
- val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
- try{
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false",
+ SQLConf.CODEGEN_ENABLED.key -> "true") {
sortTest()
- } finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
test("SPARK-6927 external sorting with codegen on") {
- val externalbefore = sqlContext.conf.externalSortEnabled
- val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
- try {
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true",
+ SQLConf.CODEGEN_ENABLED.key -> "true") {
sortTest()
- } finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
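For readers unfamiliar with `withSQLConf`: it is a SQLTestUtils helper that sets the given entries, runs the body, and restores the previous values, which is what the manual save/set/restore code above was doing by hand. A minimal sketch of that set-and-restore pattern against a plain mutable map (illustrative only, not the Spark implementation):

    import scala.collection.mutable

    def withTempSettings(conf: mutable.Map[String, String])
                        (pairs: (String, String)*)(body: => Unit): Unit = {
      val saved = pairs.map { case (k, _) => k -> conf.get(k) }   // remember old values
      pairs.foreach { case (k, v) => conf(k) = v }                // apply overrides
      try body finally {
        saved.foreach {
          case (k, Some(old)) => conf(k) = old                    // restore previous value
          case (k, None)      => conf.remove(k)                   // key was previously unset
        }
      }
    }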
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index e340f54850..bd9729c431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -190,8 +190,8 @@ object TestData {
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
- ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
- :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
+ ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
+ :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 77ed4a9c0d..f29935224e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -57,7 +57,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def deserialize(datum: Any): MyDenseVector = {
datum match {
case data: ArrayData =>
- new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
+ new MyDenseVector(data.toDoubleArray())
}
}
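Finally, a hedged sketch of the specialized primitive conversion used above, assuming `GenericArrayData` wraps an `Array[Any]` of boxed doubles as elsewhere in this diff:

    import org.apache.spark.sql.types._

    val data: ArrayData = new GenericArrayData(Array[Any](1.0, 2.0, 3.0))
    val doubles: Array[Double] = data.toDoubleArray()   // replaces toArray().map(_.asInstanceOf[Double])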