author    Wenchen Fan <cloud0fan@outlook.com>    2015-08-01 00:17:15 -0700
committer Reynold Xin <rxin@databricks.com>      2015-08-01 00:17:15 -0700
commit    1d59a4162bf5142af270ed7f4b3eab42870c87b7 (patch)
tree      98e1c51aafb41c2c64042b30d3ddcf2205da1414 /sql
parent    d90f2cf7a2a1d1e69f9ab385f35f62d4091b5302 (diff)
[SPARK-9480][SQL] add MapData and cleanup internal row stuff
This PR adds `MapData` as the internal representation of the map type in Spark SQL, and provides a default implementation backed by just two `ArrayData` (one for keys, one for values). After that, we have specialized getters for all internal types, so I removed the generic getter in `ArrayData` and added a specialized `toArray` for it. I also did some refactoring and cleanup of `InternalRow` and its subclasses.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7799 from cloud-fan/map-data and squashes the following commits:

77d482f [Wenchen Fan] fix python
e8f6682 [Wenchen Fan] skip MapData equality check in HiveInspectorSuite
40cc9db [Wenchen Fan] add toString
6e06ec9 [Wenchen Fan] some more cleanup
a90aca1 [Wenchen Fan] add MapData
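For illustration only (a minimal sketch against the classes added in this diff, not code from the patch): a map value is now a `MapData` backed by two parallel `ArrayData`, and elements are read back through the specialized accessors instead of a generic getter.

object MapDataSketch {
  import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, MapData, StringType}
  import org.apache.spark.unsafe.types.UTF8String

  def main(args: Array[String]): Unit = {
    // Internally, string keys are UTF8String and a map is just two parallel arrays.
    val keys = new GenericArrayData(Array[Any](UTF8String.fromString("a"), UTF8String.fromString("b")))
    val values = new GenericArrayData(Array[Any](1, 2))
    val map: MapData = new ArrayBasedMapData(keys, values) // default MapData implementation: 2 ArrayData

    println(map.numElements()) // 2
    // Specialized extraction replaces the removed generic ArrayData.get/toArray():
    val ks = map.keyArray().toArray[Any](StringType) // Array[Any] holding the UTF8String keys
    val vs = map.valueArray().toIntArray()           // primitive-specialized copy of the values
    println(ks.mkString(",") + " -> " + vs.mkString(","))
  }
}

The converter and expression changes below follow the same `keyArray()`/`valueArray()`/`toArray` pattern.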
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java | 6
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 79
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala | 117
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 101
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala | 61
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 69
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 81
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 44
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala | 51
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala | 155
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala | 116
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala | 38
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 23
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 89
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 4
38 files changed, 744 insertions, 526 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index e3d3ba7a9c..8f1027f316 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.ArrayData;
+import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.MapData;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -52,4 +54,8 @@ public interface SpecializedGetters {
InternalRow getStruct(int ordinal, int numFields);
ArrayData getArray(int ordinal);
+
+ MapData getMap(int ordinal);
+
+ Object get(int ordinal, DataType dataType);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 24dc80b1a7..5a19aa8920 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -118,6 +118,11 @@ public final class UnsafeRow extends MutableRow {
return baseOffset + bitSetWidthInBytes + ordinal * 8L;
}
+ private void assertIndexIsValid(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < numFields : "index (" + index + ") should < " + numFields;
+ }
+
//////////////////////////////////////////////////////////////////////////////
// Public methods
//////////////////////////////////////////////////////////////////////////////
@@ -163,11 +168,6 @@ public final class UnsafeRow extends MutableRow {
pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
}
- private void assertIndexIsValid(int index) {
- assert index >= 0 : "index (" + index + ") should >= 0";
- assert index < numFields : "index (" + index + ") should < " + numFields;
- }
-
@Override
public void setNullAt(int i) {
assertIndexIsValid(i);
@@ -254,7 +254,7 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public Object get(int ordinal) {
+ public Object genericGet(int ordinal) {
throw new UnsupportedOperationException();
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 7ca20fe97f..c666864e43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -23,7 +23,6 @@ import java.sql.{Date, Timestamp}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
-import scala.collection.mutable.HashMap
import scala.language.existentials
import org.apache.spark.sql.Row
@@ -53,12 +52,6 @@ object CatalystTypeConverters {
}
}
- private def isWholePrimitive(dt: DataType): Boolean = dt match {
- case dt if isPrimitive(dt) => true
- case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
- case _ => false
- }
-
private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
val converter = dataType match {
case udt: UserDefinedType[_] => UDTConverter(udt)
@@ -157,8 +150,6 @@ object CatalystTypeConverters {
private[this] val elementConverter = getConverterForType(elementType)
- private[this] val isNoChange = isWholePrimitive(elementType)
-
override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
case a: Array[_] =>
@@ -179,10 +170,14 @@ object CatalystTypeConverters {
override def toScala(catalystValue: ArrayData): Seq[Any] = {
if (catalystValue == null) {
null
- } else if (isNoChange) {
- catalystValue.toArray()
+ } else if (isPrimitive(elementType)) {
+ catalystValue.toArray[Any](elementType)
} else {
- catalystValue.toArray().map(elementConverter.toScala)
+ val result = new Array[Any](catalystValue.numElements())
+ catalystValue.foreach(elementType, (i, e) => {
+ result(i) = elementConverter.toScala(e)
+ })
+ result
}
}
@@ -193,44 +188,58 @@ object CatalystTypeConverters {
private case class MapConverter(
keyType: DataType,
valueType: DataType)
- extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] {
+ extends CatalystTypeConverter[Any, Map[Any, Any], MapData] {
private[this] val keyConverter = getConverterForType(keyType)
private[this] val valueConverter = getConverterForType(valueType)
- private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)
-
- override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
+ override def toCatalystImpl(scalaValue: Any): MapData = scalaValue match {
case m: Map[_, _] =>
- m.map { case (k, v) =>
- keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v)
+ val length = m.size
+ val convertedKeys = new Array[Any](length)
+ val convertedValues = new Array[Any](length)
+
+ var i = 0
+ for ((key, value) <- m) {
+ convertedKeys(i) = keyConverter.toCatalyst(key)
+ convertedValues(i) = valueConverter.toCatalyst(value)
+ i += 1
}
+ ArrayBasedMapData(convertedKeys, convertedValues)
case jmap: JavaMap[_, _] =>
+ val length = jmap.size()
+ val convertedKeys = new Array[Any](length)
+ val convertedValues = new Array[Any](length)
+
+ var i = 0
val iter = jmap.entrySet.iterator
- val convertedMap: HashMap[Any, Any] = HashMap()
while (iter.hasNext) {
val entry = iter.next()
- val key = keyConverter.toCatalyst(entry.getKey)
- convertedMap(key) = valueConverter.toCatalyst(entry.getValue)
+ convertedKeys(i) = keyConverter.toCatalyst(entry.getKey)
+ convertedValues(i) = valueConverter.toCatalyst(entry.getValue)
+ i += 1
}
- convertedMap
+ ArrayBasedMapData(convertedKeys, convertedValues)
}
- override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
+ override def toScala(catalystValue: MapData): Map[Any, Any] = {
if (catalystValue == null) {
null
- } else if (isNoChange) {
- catalystValue
} else {
- catalystValue.map { case (k, v) =>
- keyConverter.toScala(k) -> valueConverter.toScala(v)
- }
+ val keys = catalystValue.keyArray().toArray[Any](keyType)
+ val values = catalystValue.valueArray().toArray[Any](valueType)
+ val convertedKeys =
+ if (isPrimitive(keyType)) keys else keys.map(keyConverter.toScala)
+ val convertedValues =
+ if (isPrimitive(valueType)) values else values.map(valueConverter.toScala)
+
+ convertedKeys.zip(convertedValues).toMap
}
}
override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] =
- toScala(row.get(column, MapType(keyType, valueType)).asInstanceOf[Map[Any, Any]])
+ toScala(row.getMap(column))
}
private case class StructConverter(
@@ -410,7 +419,17 @@ object CatalystTypeConverters {
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
case m: Map[_, _] =>
- m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
+ val length = m.size
+ val convertedKeys = new Array[Any](length)
+ val convertedValues = new Array[Any](length)
+
+ var i = 0
+ for ((key, value) <- m) {
+ convertedKeys(i) = convertToCatalyst(key)
+ convertedValues(i) = convertToCatalyst(value)
+ i += 1
+ }
+ ArrayBasedMapData(convertedKeys, convertedValues)
case other => other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index b19bf4386b..7656d054dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,71 +19,25 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
-abstract class InternalRow extends Serializable with SpecializedGetters {
+// todo: make InternalRow just extends SpecializedGetters, remove generic getter
+abstract class InternalRow extends GenericSpecializedGetters with Serializable {
def numFields: Int
- def get(ordinal: Int): Any = get(ordinal, null)
-
- def genericGet(ordinal: Int): Any = get(ordinal, null)
-
- def get(ordinal: Int, dataType: DataType): Any
-
- def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T]
-
- override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
-
- override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType)
-
- override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType)
-
- override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType)
-
- override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType)
-
- override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType)
-
- override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType)
-
- override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType)
-
- override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType)
-
- override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
-
- override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
- getAs[Decimal](ordinal, DecimalType(precision, scale))
-
- override def getInterval(ordinal: Int): CalendarInterval =
- getAs[CalendarInterval](ordinal, CalendarIntervalType)
-
// This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString
- /**
- * Returns a struct from ordinal position.
- *
- * @param ordinal position to get the struct from.
- * @param numFields number of fields the struct type has
- */
- override def getStruct(ordinal: Int, numFields: Int): InternalRow =
- getAs[InternalRow](ordinal, null)
-
- override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)
-
- override def toString: String = s"[${this.mkString(",")}]"
+ override def toString: String = mkString("[", ",", "]")
/**
* Make a copy of the current [[InternalRow]] object.
*/
- def copy(): InternalRow = this
+ def copy(): InternalRow
/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
@@ -117,8 +71,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
return false
}
if (!isNullAt(i)) {
- val o1 = get(i)
- val o2 = other.get(i)
+ val o1 = genericGet(i)
+ val o2 = other.genericGet(i)
o1 match {
case b1: Array[Byte] =>
if (!o2.isInstanceOf[Array[Byte]] ||
@@ -143,34 +97,6 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
true
}
- /* ---------------------- utility methods for Scala ---------------------- */
-
- /**
- * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
- */
- def toSeq: Seq[Any] = {
- val n = numFields
- val values = new Array[Any](n)
- var i = 0
- while (i < n) {
- values.update(i, get(i))
- i += 1
- }
- values.toSeq
- }
-
- /** Displays all elements of this sequence in a string (without a separator). */
- def mkString: String = toSeq.mkString
-
- /** Displays all elements of this sequence in a string using a separator string. */
- def mkString(sep: String): String = toSeq.mkString(sep)
-
- /**
- * Displays all elements of this traversable or iterator in a string using
- * start, end, and separator strings.
- */
- def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
-
// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
@@ -181,7 +107,7 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
if (isNullAt(i)) {
0
} else {
- get(i) match {
+ genericGet(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
@@ -200,6 +126,35 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
}
result
}
+
+ /* ---------------------- utility methods for Scala ---------------------- */
+
+ /**
+ * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq.
+ */
+ // todo: remove this as it needs the generic getter
+ def toSeq: Seq[Any] = {
+ val n = numFields
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ values.update(i, genericGet(i))
+ i += 1
+ }
+ values
+ }
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
object InternalRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 45709c1c8f..473b9b7870 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -49,7 +49,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
case StringType => input.getUTF8String(ordinal)
case BinaryType => input.getBinary(ordinal)
case CalendarIntervalType => input.getInterval(ordinal)
+ case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
case t: StructType => input.getStruct(ordinal, t.size)
+ case _: ArrayType => input.getArray(ordinal)
+ case _: MapType => input.getMap(ordinal)
case _ => input.get(ordinal, dataType)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 43be11c48a..88429bb84b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -361,30 +361,29 @@ case class Cast(child: Expression, dataType: DataType)
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}
- private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
- val elementCast = cast(from.elementType, to.elementType)
+ private[this] def castArray(fromType: DataType, toType: DataType): Any => Any = {
+ val elementCast = cast(fromType, toType)
// TODO: Could be faster?
buildCast[ArrayData](_, array => {
- val length = array.numElements()
- val values = new Array[Any](length)
- var i = 0
- while (i < length) {
- if (array.isNullAt(i)) {
+ val values = new Array[Any](array.numElements())
+ array.foreach(fromType, (i, e) => {
+ if (e == null) {
values(i) = null
} else {
- values(i) = elementCast(array.get(i))
+ values(i) = elementCast(e)
}
- i += 1
- }
+ })
new GenericArrayData(values)
})
}
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
- val keyCast = cast(from.keyType, to.keyType)
- val valueCast = cast(from.valueType, to.valueType)
- buildCast[Map[Any, Any]](_, _.map {
- case (key, value) => (keyCast(key), if (value == null) null else valueCast(value))
+ val keyCast = castArray(from.keyType, to.keyType)
+ val valueCast = castArray(from.valueType, to.valueType)
+ buildCast[MapData](_, map => {
+ val keys = keyCast(map.keyArray()).asInstanceOf[ArrayData]
+ val values = valueCast(map.valueArray()).asInstanceOf[ArrayData]
+ new ArrayBasedMapData(keys, values)
})
}
@@ -420,7 +419,7 @@ case class Cast(child: Expression, dataType: DataType)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
- case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array)
+ case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
}
@@ -461,7 +460,8 @@ case class Cast(child: Expression, dataType: DataType)
case LongType => castToLongCode(from)
case DoubleType => castToDoubleCode(from)
- case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx)
+ case array: ArrayType =>
+ castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
}
@@ -801,8 +801,8 @@ case class Cast(child: Expression, dataType: DataType)
}
private[this] def castArrayCode(
- from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
- val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
+ fromType: DataType, toType: DataType, ctx: CodeGenContext): CastFunction = {
+ val elementCast = nullSafeCastFunction(fromType, toType, ctx)
val arrayClass = classOf[GenericArrayData].getName
val fromElementNull = ctx.freshName("feNull")
val fromElementPrim = ctx.freshName("fePrim")
@@ -821,10 +821,10 @@ case class Cast(child: Expression, dataType: DataType)
$values[$j] = null;
} else {
boolean $fromElementNull = false;
- ${ctx.javaType(from.elementType)} $fromElementPrim =
- ${ctx.getValue(c, from.elementType, j)};
+ ${ctx.javaType(fromType)} $fromElementPrim =
+ ${ctx.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
- fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
+ fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
if ($toElementNull) {
$values[$j] = null;
} else {
@@ -837,48 +837,29 @@ case class Cast(child: Expression, dataType: DataType)
}
private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = {
- val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx)
- val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx)
-
- val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName
- val fromKeyPrim = ctx.freshName("fkp")
- val fromKeyNull = ctx.freshName("fkn")
- val fromValuePrim = ctx.freshName("fvp")
- val fromValueNull = ctx.freshName("fvn")
- val toKeyPrim = ctx.freshName("tkp")
- val toKeyNull = ctx.freshName("tkn")
- val toValuePrim = ctx.freshName("tvp")
- val toValueNull = ctx.freshName("tvn")
- val result = ctx.freshName("result")
+ val keysCast = castArrayCode(from.keyType, to.keyType, ctx)
+ val valuesCast = castArrayCode(from.valueType, to.valueType, ctx)
+
+ val mapClass = classOf[ArrayBasedMapData].getName
+
+ val keys = ctx.freshName("keys")
+ val convertedKeys = ctx.freshName("convertedKeys")
+ val convertedKeysNull = ctx.freshName("convertedKeysNull")
+
+ val values = ctx.freshName("values")
+ val convertedValues = ctx.freshName("convertedValues")
+ val convertedValuesNull = ctx.freshName("convertedValuesNull")
(c, evPrim, evNull) =>
s"""
- final $hashMapClass $result = new $hashMapClass();
- scala.collection.Iterator iter = $c.iterator();
- while (iter.hasNext()) {
- scala.Tuple2 kv = (scala.Tuple2) iter.next();
- boolean $fromKeyNull = false;
- ${ctx.javaType(from.keyType)} $fromKeyPrim =
- (${ctx.boxedType(from.keyType)}) kv._1();
- ${castCode(ctx, fromKeyPrim,
- fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)}
-
- boolean $fromValueNull = kv._2() == null;
- if ($fromValueNull) {
- $result.put($toKeyPrim, null);
- } else {
- ${ctx.javaType(from.valueType)} $fromValuePrim =
- (${ctx.boxedType(from.valueType)}) kv._2();
- ${castCode(ctx, fromValuePrim,
- fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)}
- if ($toValueNull) {
- $result.put($toKeyPrim, null);
- } else {
- $result.put($toKeyPrim, $toValuePrim);
- }
- }
- }
- $evPrim = $result;
+ final ArrayData $keys = $c.keyArray();
+ final ArrayData $values = $c.valueArray();
+ ${castCode(ctx, keys, "false",
+ convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)}
+ ${castCode(ctx, values, "false",
+ convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)}
+
+ $evPrim = new $mapClass($convertedKeys, $convertedValues);
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
new file mode 100644
index 0000000000..6e957928e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+trait GenericSpecializedGetters extends SpecializedGetters {
+
+ def genericGet(ordinal: Int): Any
+
+ private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+
+ override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
+
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 7c7664e4c1..d79325aea8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types.{Decimal, StructType, DataType}
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -190,45 +190,55 @@ class JoinedRow extends InternalRow {
override def numFields: Int = row1.numFields + row2.numFields
- override def getUTF8String(i: Int): UTF8String = {
- if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
- }
-
- override def getBinary(i: Int): Array[Byte] = {
- if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
- }
-
- override def get(i: Int, dataType: DataType): Any =
- if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
+ override def genericGet(i: Int): Any =
+ if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields)
override def isNullAt(i: Int): Boolean =
if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
+
+ override def getByte(i: Int): Byte =
+ if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
+
+ override def getShort(i: Int): Short =
+ if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
+
override def getInt(i: Int): Int =
if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields)
override def getLong(i: Int): Long =
if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields)
+ override def getFloat(i: Int): Float =
+ if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
+
override def getDouble(i: Int): Double =
if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields)
- override def getBoolean(i: Int): Boolean =
- if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields)
+ override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
+ if (i < row1.numFields) {
+ row1.getDecimal(i, precision, scale)
+ } else {
+ row2.getDecimal(i - row1.numFields, precision, scale)
+ }
+ }
- override def getShort(i: Int): Short =
- if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields)
+ override def getUTF8String(i: Int): UTF8String =
+ if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields)
- override def getByte(i: Int): Byte =
- if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields)
+ override def getBinary(i: Int): Array[Byte] =
+ if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields)
- override def getFloat(i: Int): Float =
- if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
+ override def getArray(i: Int): ArrayData =
+ if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields)
- override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
- if (i < row1.numFields) row1.getDecimal(i, precision, scale)
- else row2.getDecimal(i - row1.numFields, precision, scale)
- }
+ override def getInterval(i: Int): CalendarInterval =
+ if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields)
+
+ override def getMap(i: Int): MapData =
+ if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields)
override def getStruct(i: Int, numFields: Int): InternalRow = {
if (i < row1.numFields) {
@@ -239,14 +249,9 @@ class JoinedRow extends InternalRow {
}
override def copy(): InternalRow = {
- val totalSize = row1.numFields + row2.numFields
- val copiedValues = new Array[Any](totalSize)
- var i = 0
- while(i < totalSize) {
- copiedValues(i) = get(i)
- i += 1
- }
- new GenericInternalRow(copiedValues)
+ val copy1 = row1.copy()
+ val copy2 = row2.copy()
+ new JoinedRow(copy1, copy2)
}
override def toString: String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index b877ce47c0..d149a5b179 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -213,18 +213,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def numFields: Int = values.length
- override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
+ override def toSeq: Seq[Any] = values.map(_.boxed)
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
- override def get(i: Int, dataType: DataType): Any = values(i).boxed
-
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
- values(ordinal).boxed.asInstanceOf[InternalRow]
- }
-
override def isNullAt(i: Int): Boolean = values(i).isNull
override def copy(): InternalRow = {
@@ -238,6 +232,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new GenericInternalRow(newValues)
}
+ override def genericGet(i: Int): Any = values(i).boxed
+
override def update(ordinal: Int, value: Any) {
if (value == null) {
setNullAt(ordinal)
@@ -246,9 +242,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
}
}
- override def setString(ordinal: Int, value: String): Unit =
- update(ordinal, UTF8String.fromString(value))
-
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
currentValue.isNull = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 36f4e9c6be..fc7cfee989 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -112,8 +112,10 @@ class CodeGenContext {
case BinaryType => s"$getter.getBinary($ordinal)"
case CalendarIntervalType => s"$getter.getInterval($ordinal)"
case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
- case a: ArrayType => s"$getter.getArray($ordinal)"
- case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
+ case _: ArrayType => s"$getter.getArray($ordinal)"
+ case _: MapType => s"$getter.getMap($ordinal)"
+ case NullType => "null"
+ case _ => s"($jt)$getter.get($ordinal, null)"
}
}
@@ -156,7 +158,7 @@ class CodeGenContext {
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
case _: ArrayType => "ArrayData"
- case _: MapType => "scala.collection.Map"
+ case _: MapType => "MapData"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
@@ -300,7 +302,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[UTF8String].getName,
classOf[Decimal].getName,
classOf[CalendarInterval].getName,
- classOf[ArrayData].getName
+ classOf[ArrayData].getName,
+ classOf[MapData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 3592014710..6f9acda071 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -183,7 +183,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
- public Object get(int i, ${classOf[DataType].getName} dataType) {
+ public Object genericGet(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0a530596a9..1156797b2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -31,15 +31,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def nullSafeEval(value: Any): Int = child.dataType match {
case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
- case _: MapType => value.asInstanceOf[Map[Any, Any]].size
+ case _: MapType => value.asInstanceOf[MapData].numElements()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val sizeCall = child.dataType match {
- case _: ArrayType => "numElements()"
- case _: MapType => "size()"
- }
- nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
+ nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).numElements();")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 99393c9c76..9927da21b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.Map
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
@@ -41,7 +39,7 @@ object ExtractValue {
* Struct | Literal String | GetStructField
* Array[Struct] | Literal String | GetArrayStructFields
* Array | Integral type | GetArrayItem
- * Map | Any type | GetMapValue
+ * Map | map key type | GetMapValue
*/
def apply(
child: Expression,
@@ -60,18 +58,14 @@ object ExtractValue {
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
ordinal, fields.length, containsNull)
- case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
- GetArrayItem(child, extraction)
+ case (_: ArrayType, _) => GetArrayItem(child, extraction)
- case (_: MapType, _) =>
- GetMapValue(child, extraction)
+ case (MapType(kt, _, _), _) => GetMapValue(child, extraction)
case (otherType, _) =>
val errorMsg = otherType match {
- case StructType(_) | ArrayType(StructType(_), _) =>
+ case StructType(_) =>
s"Field name should be String Literal, but it's $extraction"
- case _: ArrayType =>
- s"Array index should be integral type, but it's ${extraction.dataType}"
case other =>
s"Can't extract value from $child"
}
@@ -190,9 +184,13 @@ case class GetArrayStructFields(
/**
* Returns the field at `ordinal` in the Array `child`.
*
- * No need to do type checking since it is handled by [[ExtractValue]].
+ * We need to do type checking here as `ordinal` expression maybe unresolved.
*/
-case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression {
+case class GetArrayItem(child: Expression, ordinal: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
override def toString: String = s"$child[$ordinal]"
@@ -205,14 +203,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
- // TODO: consider using Array[_] for ArrayType child to avoid
- // boxing of primitives
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.numElements() || index < 0) {
null
} else {
- baseValue.get(index)
+ baseValue.get(index, dataType)
}
}
@@ -233,9 +229,15 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx
/**
* Returns the value of key `key` in Map `child`.
*
- * No need to do type checking since it is handled by [[ExtractValue]].
+ * We need to do type checking here as `key` expression maybe unresolved.
*/
-case class GetMapValue(child: Expression, key: Expression) extends BinaryExpression {
+case class GetMapValue(child: Expression, key: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ private def keyType = child.dataType.asInstanceOf[MapType].keyType
+
+ // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
override def toString: String = s"$child[$key]"
@@ -247,16 +249,53 @@ case class GetMapValue(child: Expression, key: Expression) extends BinaryExpress
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+ // todo: current search is O(n), improve it.
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
- val baseValue = value.asInstanceOf[Map[Any, _]]
- baseValue.get(ordinal).orNull
+ val map = value.asInstanceOf[MapData]
+ val length = map.numElements()
+ val keys = map.keyArray()
+
+ var i = 0
+ var found = false
+ while (i < length && !found) {
+ if (keys.get(i, keyType) == ordinal) {
+ found = true
+ } else {
+ i += 1
+ }
+ }
+
+ if (!found) {
+ null
+ } else {
+ map.valueArray().get(i, dataType)
+ }
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val index = ctx.freshName("index")
+ val length = ctx.freshName("length")
+ val keys = ctx.freshName("keys")
+ val found = ctx.freshName("found")
+ val key = ctx.freshName("key")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- if ($eval1.contains($eval2)) {
- ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2);
+ final int $length = $eval1.numElements();
+ final ArrayData $keys = $eval1.keyArray();
+
+ int $index = 0;
+ boolean $found = false;
+ while ($index < $length && !$found) {
+ final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)};
+ if (${ctx.genEqual(keyType, key, eval2)}) {
+ $found = true;
+ } else {
+ $index++;
+ }
+ }
+
+ if ($found) {
+ ${ev.primitive} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)};
} else {
${ev.isNull} = true;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 8064235c64..d474853355 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -120,13 +120,30 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
child.dataType match {
- case ArrayType(_, _) =>
+ case ArrayType(et, _) =>
val inputArray = child.eval(input).asInstanceOf[ArrayData]
- if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v))
- case MapType(_, _, _) =>
- val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
- if (inputMap == null) Nil
- else inputMap.map { case (k, v) => InternalRow(k, v) }
+ if (inputArray == null) {
+ Nil
+ } else {
+ val rows = new Array[InternalRow](inputArray.numElements())
+ inputArray.foreach(et, (i, e) => {
+ rows(i) = InternalRow(e)
+ })
+ rows
+ }
+ case MapType(kt, vt, _) =>
+ val inputMap = child.eval(input).asInstanceOf[MapData]
+ if (inputMap == null) {
+ Nil
+ } else {
+ val rows = new Array[InternalRow](inputMap.numElements())
+ var i = 0
+ inputMap.foreach(kt, vt, (k, v) => {
+ rows(i) = InternalRow(k, v)
+ i += 1
+ })
+ rows
+ }
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index df6ea586c8..73f6b7a550 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -32,28 +32,14 @@ abstract class MutableRow extends InternalRow {
def update(i: Int, value: Any)
// default implementation (slow)
- def setInt(i: Int, value: Int): Unit = { update(i, value) }
- def setLong(i: Int, value: Long): Unit = { update(i, value) }
- def setDouble(i: Int, value: Double): Unit = { update(i, value) }
def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
- def setShort(i: Int, value: Short): Unit = { update(i, value) }
def setByte(i: Int, value: Byte): Unit = { update(i, value) }
+ def setShort(i: Int, value: Short): Unit = { update(i, value) }
+ def setInt(i: Int, value: Int): Unit = { update(i, value) }
+ def setLong(i: Int, value: Long): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
+ def setDouble(i: Int, value: Double): Unit = { update(i, value) }
def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
- def setString(i: Int, value: String): Unit = {
- update(i, UTF8String.fromString(value))
- }
-
- override def copy(): InternalRow = {
- val n = numFields
- val arr = new Array[Any](n)
- var i = 0
- while (i < n) {
- arr(i) = get(i)
- i += 1
- }
- new GenericInternalRow(arr)
- }
}
/**
@@ -96,17 +82,13 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal
def this(size: Int) = this(new Array[Any](size))
- override def toSeq: Seq[Any] = values.toSeq
-
- override def numFields: Int = values.length
+ override def genericGet(ordinal: Int): Any = values(ordinal)
- override def get(i: Int, dataType: DataType): Any = values(i)
+ override def toSeq: Seq[Any] = values
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
- values(ordinal).asInstanceOf[InternalRow]
- }
+ override def numFields: Int = values.length
- override def copy(): InternalRow = this
+ override def copy(): InternalRow = new GenericInternalRow(values.clone())
}
/**
@@ -127,15 +109,11 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow {
def this(size: Int) = this(new Array[Any](size))
- override def toSeq: Seq[Any] = values.toSeq
-
- override def numFields: Int = values.length
+ override def genericGet(ordinal: Int): Any = values(ordinal)
- override def get(i: Int, dataType: DataType): Any = values(i)
+ override def toSeq: Seq[Any] = values
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
- values(ordinal).asInstanceOf[InternalRow]
- }
+ override def numFields: Int = values.length
override def setNullAt(i: Int): Unit = { values(i) = null}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5dd387a418..3ce5d6a9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -95,7 +95,7 @@ case class ConcatWs(children: Seq[Expression])
val flatInputs = children.flatMap { child =>
child.eval(input) match {
case s: UTF8String => Iterator(s)
- case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String])
+ case arr: ArrayData => arr.toArray[UTF8String](StringType)
case null => Iterator(null.asInstanceOf[UTF8String])
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
new file mode 100644
index 0000000000..db4876355d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
+ require(keyArray.numElements() == valueArray.numElements())
+
+ override def numElements(): Int = keyArray.numElements()
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[ArrayBasedMapData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[ArrayBasedMapData]
+ if (other eq null) {
+ return false
+ }
+
+ this.keyArray == other.keyArray && this.valueArray == other.valueArray
+ }
+
+ override def hashCode: Int = {
+ keyArray.hashCode() * 37 + valueArray.hashCode()
+ }
+
+ override def toString(): String = {
+ s"keys: $keyArray\nvalues: $valueArray"
+ }
+}
+
+object ArrayBasedMapData {
+ def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
+ new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index 14a7285877..c99fc23325 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -20,102 +20,111 @@ package org.apache.spark.sql.types
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
abstract class ArrayData extends SpecializedGetters with Serializable {
- // todo: remove this after we handle all types.(map type need special getter)
- def get(ordinal: Int): Any
-
def numElements(): Int
- // todo: need a more efficient way to iterate array type.
- def toArray(): Array[Any] = {
- val n = numElements()
- val values = new Array[Any](n)
+ def toBooleanArray(): Array[Boolean] = {
+ val size = numElements()
+ val values = new Array[Boolean](size)
var i = 0
- while (i < n) {
- if (isNullAt(i)) {
- values(i) = null
- } else {
- values(i) = get(i)
- }
+ while (i < size) {
+ values(i) = getBoolean(i)
i += 1
}
values
}
- override def toString(): String = toArray.mkString("[", ",", "]")
+ def toByteArray(): Array[Byte] = {
+ val size = numElements()
+ val values = new Array[Byte](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getByte(i)
+ i += 1
+ }
+ values
+ }
- override def equals(o: Any): Boolean = {
- if (!o.isInstanceOf[ArrayData]) {
- return false
+ def toShortArray(): Array[Short] = {
+ val size = numElements()
+ val values = new Array[Short](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getShort(i)
+ i += 1
}
+ values
+ }
- val other = o.asInstanceOf[ArrayData]
- if (other eq null) {
- return false
+ def toIntArray(): Array[Int] = {
+ val size = numElements()
+ val values = new Array[Int](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getInt(i)
+ i += 1
}
+ values
+ }
- val len = numElements()
- if (len != other.numElements()) {
- return false
+ def toLongArray(): Array[Long] = {
+ val size = numElements()
+ val values = new Array[Long](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getLong(i)
+ i += 1
}
+ values
+ }
+ def toFloatArray(): Array[Float] = {
+ val size = numElements()
+ val values = new Array[Float](size)
var i = 0
- while (i < len) {
- if (isNullAt(i) != other.isNullAt(i)) {
- return false
- }
- if (!isNullAt(i)) {
- val o1 = get(i)
- val o2 = other.get(i)
- o1 match {
- case b1: Array[Byte] =>
- if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
- return false
- }
- case f1: Float if java.lang.Float.isNaN(f1) =>
- if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
- return false
- }
- case d1: Double if java.lang.Double.isNaN(d1) =>
- if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
- return false
- }
- case _ => if (o1 != o2) {
- return false
- }
- }
+ while (i < size) {
+ values(i) = getFloat(i)
+ i += 1
+ }
+ values
+ }
+
+ def toDoubleArray(): Array[Double] = {
+ val size = numElements()
+ val values = new Array[Double](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getDouble(i)
+ i += 1
+ }
+ values
+ }
+
+ def toArray[T](elementType: DataType): Array[T] = {
+ val size = numElements()
+ val values = new Array[Any](size)
+ var i = 0
+ while (i < size) {
+ if (isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = get(i, elementType)
}
i += 1
}
- true
+ values.asInstanceOf[Array[T]]
}
- override def hashCode: Int = {
- var result: Int = 37
+ // todo: specialize this.
+ def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = {
+ val size = numElements()
var i = 0
- val len = numElements()
- while (i < len) {
- val update: Int =
- if (isNullAt(i)) {
- 0
- } else {
- get(i) match {
- case b: Boolean => if (b) 0 else 1
- case b: Byte => b.toInt
- case s: Short => s.toInt
- case i: Int => i
- case l: Long => (l ^ (l >>> 32)).toInt
- case f: Float => java.lang.Float.floatToIntBits(f)
- case d: Double =>
- val b = java.lang.Double.doubleToLongBits(d)
- (b ^ (b >>> 32)).toInt
- case a: Array[Byte] => java.util.Arrays.hashCode(a)
- case other => other.hashCode()
- }
- }
- result = 37 * result + update
+ while (i < size) {
+ if (isNullAt(i)) {
+ f(i, null)
+ } else {
+ f(i, get(i, elementType))
+ }
i += 1
}
- result
}
}
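(For illustration only, not part of the patch: a minimal REPL-style sketch of the iteration pattern the new typed `foreach` enables, now that the generic `toArray()` is gone.)

import org.apache.spark.sql.types.{GenericArrayData, IntegerType}

val arr = new GenericArrayData(Array[Any](1, null, 3))
val copied = new Array[Any](arr.numElements())
// foreach hands the callback the index and either null or the value read via get(i, elementType)
arr.foreach(IntegerType, (i, e) => copied(i) = e) // copied ends up as Array(1, null, 3)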
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index 35ace673fb..b3e75f8bad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,43 +17,91 @@
package org.apache.spark.sql.types
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
-class GenericArrayData(array: Array[Any]) extends ArrayData {
- private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+class GenericArrayData(array: Array[Any]) extends ArrayData with GenericSpecializedGetters {
- override def toArray(): Array[Any] = array
+ override def genericGet(ordinal: Int): Any = array(ordinal)
- override def get(ordinal: Int): Any = array(ordinal)
-
- override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
-
- override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
- override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
- override def getShort(ordinal: Int): Short = getAs(ordinal)
-
- override def getInt(ordinal: Int): Int = getAs(ordinal)
-
- override def getLong(ordinal: Int): Long = getAs(ordinal)
-
- override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
- override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
- override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
-
- override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
- override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
- override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
- override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
- override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+ override def toArray[T](elementType: DataType): Array[T] = array.asInstanceOf[Array[T]]
override def numElements(): Int = array.length
+
+ override def toString(): String = array.mkString("[", ",", "]")
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[GenericArrayData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[GenericArrayData]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numElements()
+ if (len != other.numElements()) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = genericGet(i)
+ val o2 = other.genericGet(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || !java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || !java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numElements()
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ genericGet(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
}
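A minimal sketch of the equality semantics implemented above: elements compare positionally, byte arrays by content, and NaN floats/doubles compare equal to NaN (illustrative only):

  import org.apache.spark.sql.types.GenericArrayData

  val a = new GenericArrayData(Array[Any](Array[Byte](1, 2), 3, Double.NaN))
  val b = new GenericArrayData(Array[Any](Array[Byte](1, 2), 3, Double.NaN))
  assert(a == b)                    // byte arrays via Arrays.equals, NaN treated as equal to NaN
  assert(a.hashCode == b.hashCode)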
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
new file mode 100644
index 0000000000..5514c3cd85
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+abstract class MapData extends Serializable {
+
+ def numElements(): Int
+
+ def keyArray(): ArrayData
+
+ def valueArray(): ArrayData
+
+ def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
+ val length = numElements()
+ val keys = keyArray()
+ val values = valueArray()
+ var i = 0
+ while (i < length) {
+ f(keys.get(i, keyType), values.get(i, valueType))
+ i += 1
+ }
+ }
+}
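A small usage sketch, assuming the ArrayBasedMapData companion that the rest of this patch uses to pair a key array with a value array:

  import org.apache.spark.sql.types.{ArrayBasedMapData, IntegerType, StringType}
  import org.apache.spark.unsafe.types.UTF8String

  val keys: Array[Any] = Array(UTF8String.fromString("a"), UTF8String.fromString("b"))
  val values: Array[Any] = Array(1, 2)
  val map = ArrayBasedMapData(keys, values)
  map.foreach(StringType, IntegerType, (k, v) => println(s"$k -> $v"))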
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3fa246b69d..e60990aeb4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -171,8 +171,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("error message of ExtractValue") {
val structType = StructType(StructField("a", StringType, true) :: Nil)
- val arrayStructType = ArrayType(structType)
- val arrayType = ArrayType(StringType)
val otherType = StringType
def checkErrorMessage(
@@ -189,8 +187,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
- checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
- checkErrorMessage(arrayType, StringType, "Array index should be integral type")
checkErrorMessage(otherType, StringType, "Can't extract value from")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index a0e1701339..44f845620a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -87,7 +87,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
- row.setString(1, "Hello")
+ row.update(1, UTF8String.fromString("Hello"))
row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f26f41fb75..c37007f1ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -159,10 +159,16 @@ package object debug {
case (row: InternalRow, StructType(fields)) =>
row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (a: ArrayData, ArrayType(elemType, _)) =>
- a.toArray().foreach(typeCheck(_, elemType))
- case (m: Map[_, _], MapType(keyType, valueType, _)) =>
- m.keys.foreach(typeCheck(_, keyType))
- m.values.foreach(typeCheck(_, valueType))
+ a.foreach(elemType, (_, e) => {
+ typeCheck(e, elemType)
+ })
+ case (m: MapData, MapType(keyType, valueType, _)) =>
+ m.keyArray().foreach(keyType, (_, e) => {
+ typeCheck(e, keyType)
+ })
+ m.valueArray().foreach(valueType, (_, e) => {
+ typeCheck(e, valueType)
+ })
case (_: Long, LongType) =>
case (_: Int, IntegerType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ef1c6e57dc..aade2e769c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -135,22 +135,18 @@ object EvaluatePython {
new GenericInternalRowWithSchema(values, struct)
case (a: ArrayData, array: ArrayType) =>
- val length = a.numElements()
- val values = new java.util.ArrayList[Any](length)
- var i = 0
- while (i < length) {
- if (a.isNullAt(i)) {
- values.add(null)
- } else {
- values.add(toJava(a.get(i), array.elementType))
- }
- i += 1
- }
+ val values = new java.util.ArrayList[Any](a.numElements())
+ a.foreach(array.elementType, (_, e) => {
+ values.add(toJava(e, array.elementType))
+ })
values
- case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
- }.asJava
+ case (map: MapData, mt: MapType) =>
+ val jmap = new java.util.HashMap[Any, Any](map.numElements())
+ map.foreach(mt.keyType, mt.valueType, (k, v) => {
+ jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
+ })
+ jmap
case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
@@ -206,9 +202,10 @@ object EvaluatePython {
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
- case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
- }.toMap
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+ val keys = c.keysIterator.map(fromJava(_, keyType)).toArray
+ val values = c.valuesIterator.map(fromJava(_, valueType)).toArray
+ ArrayBasedMapData(keys, values)
case (c, StructType(fields)) if c.getClass.isArray =>
new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 1c309f8794..bf0448ee96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.json
import java.io.ByteArrayOutputStream
-import scala.collection.Map
-
import com.fasterxml.jackson.core._
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -31,7 +31,6 @@ import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-
private[sql] object JacksonParser {
def apply(
json: RDD[String],
@@ -160,21 +159,21 @@ private[sql] object JacksonParser {
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
- valueType: DataType): Map[UTF8String, Any] = {
- val builder = Map.newBuilder[UTF8String, Any]
+ valueType: DataType): MapData = {
+ val keys = ArrayBuffer.empty[UTF8String]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
- builder +=
- UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType)
+ keys += UTF8String.fromString(parser.getCurrentName)
+ values += convertField(factory, parser, valueType)
}
-
- builder.result()
+ ArrayBasedMapData(keys.toArray, values.toArray)
}
private def convertArray(
factory: JsonFactory,
parser: JsonParser,
elementType: DataType): ArrayData = {
- val values = scala.collection.mutable.ArrayBuffer.empty[Any]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
values += convertField(factory, parser, elementType)
}
@@ -213,7 +212,7 @@ private[sql] object JacksonParser {
if (array.numElements() == 0) {
Nil
} else {
- array.toArray().map(_.asInstanceOf[InternalRow])
+ array.toArray[InternalRow](schema)
}
case _ =>
sys.error(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index 172db8362a..6938b07106 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -385,7 +385,8 @@ private[parquet] class CatalystRowConverter(
updater: ParentContainerUpdater)
extends GroupConverter {
- private var currentMap: mutable.Map[Any, Any] = _
+ private var currentKeys: ArrayBuffer[Any] = _
+ private var currentValues: ArrayBuffer[Any] = _
private val keyValueConverter = {
val repeatedType = parquetType.getType(0).asGroupType()
@@ -398,12 +399,16 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = keyValueConverter
- override def end(): Unit = updater.set(currentMap)
+ override def end(): Unit =
+ updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray))
// NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next
// value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row
// cells.
- override def start(): Unit = currentMap = mutable.Map.empty[Any, Any]
+ override def start(): Unit = {
+ currentKeys = ArrayBuffer.empty[Any]
+ currentValues = ArrayBuffer.empty[Any]
+ }
/** Parquet converter for key-value pairs within the map. */
private final class KeyValueConverter(
@@ -430,7 +435,10 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
- override def end(): Unit = currentMap(currentKey) = currentValue
+ override def end(): Unit = {
+ currentKeys += currentKey
+ currentValues += currentValue
+ }
override def start(): Unit = {
currentKey = null
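The accumulate-then-freeze pattern above in isolation (a sketch, not the converter itself): keys and values are buffered in lockstep and only materialized into a MapData when the group ends.

  import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.types.ArrayBasedMapData
  import org.apache.spark.unsafe.types.UTF8String

  val currentKeys = ArrayBuffer.empty[Any]
  val currentValues = ArrayBuffer.empty[Any]
  currentKeys += UTF8String.fromString("k")   // what one KeyValueConverter.end() contributes
  currentValues += 1
  val frozen = ArrayBasedMapData(currentKeys.toArray, currentValues.toArray)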
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2332a36468..6ed3580af0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.parquet
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.ArrayData
+import org.apache.spark.sql.types.{MapData, ArrayData}
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -33,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = ArrayData
- type StructScalaType[T] = InternalRow
- type MapScalaType[K, V] = Map[K, V]
+ type ArrayScalaType = ArrayData
+ type StructScalaType = InternalRow
+ type MapScalaType = MapData
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index ec8da38a3d..9cd0250f9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -88,13 +88,13 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case t @ ArrayType(_, _) => writeArray(
t,
- value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
+ value.asInstanceOf[CatalystConverter.ArrayScalaType])
case t @ MapType(_, _, _) => writeMap(
t,
- value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
+ value.asInstanceOf[CatalystConverter.MapScalaType])
case t @ StructType(_) => writeStruct(
t,
- value.asInstanceOf[CatalystConverter.StructScalaType[_]])
+ value.asInstanceOf[CatalystConverter.StructScalaType])
case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
}
}
@@ -124,7 +124,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeStruct(
schema: StructType,
- struct: CatalystConverter.StructScalaType[_]): Unit = {
+ struct: CatalystConverter.StructScalaType): Unit = {
if (struct != null) {
val fields = schema.fields.toArray
writer.startGroup()
@@ -143,7 +143,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeArray(
schema: ArrayType,
- array: CatalystConverter.ArrayScalaType[_]): Unit = {
+ array: CatalystConverter.ArrayScalaType): Unit = {
val elementType = schema.elementType
writer.startGroup()
if (array.numElements() > 0) {
@@ -154,7 +154,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startGroup()
if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -165,7 +165,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
while (i < array.numElements()) {
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
@@ -176,11 +176,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeMap(
schema: MapType,
- map: CatalystConverter.MapScalaType[_, _]): Unit = {
+ map: CatalystConverter.MapScalaType): Unit = {
writer.startGroup()
- if (map.size > 0) {
+ val length = map.numElements()
+ if (length > 0) {
writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
- for ((key, value) <- map) {
+ map.foreach(schema.keyType, schema.valueType, (key, value) => {
writer.startGroup()
writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
writeValue(schema.keyType, key)
@@ -191,7 +192,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
}
writer.endGroup()
- }
+ })
writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
}
writer.endGroup()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 97beae2f85..aef940a526 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -620,6 +620,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1)
+ assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1)
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 01b7c21e84..8a679c7865 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
-
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class RowSuite extends SparkFunSuite {
@@ -31,7 +31,7 @@ class RowSuite extends SparkFunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
expected.setInt(0, 2147483647)
- expected.setString(1, "this is a string")
+ expected.update(1, UTF8String.fromString("this is a string"))
expected.setBoolean(2, false)
expected.setNullAt(3)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 535011fe3d..51fe9d9d98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -581,42 +581,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("sorting") {
- val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
- sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") {
+ sortTest()
+ }
}
test("external sorting") {
- val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
- sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") {
+ sortTest()
+ }
}
test("SPARK-6927 sorting with codegen on") {
- val externalbefore = sqlContext.conf.externalSortEnabled
- val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
- try{
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false",
+ SQLConf.CODEGEN_ENABLED.key -> "true") {
sortTest()
- } finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
test("SPARK-6927 external sorting with codegen on") {
- val externalbefore = sqlContext.conf.externalSortEnabled
- val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
- try {
+ withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true",
+ SQLConf.CODEGEN_ENABLED.key -> "true") {
sortTest()
- } finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
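These tests now lean on a withSQLConf helper, presumably provided by the SQLTestUtils mixin in the class declaration. A simplified, hypothetical sketch of what such a helper has to do, assuming a SQLContext in scope as sqlContext (the real helper also handles keys that were previously unset):

  def withSQLConfSketch(pairs: (String, String)*)(body: => Unit): Unit = {
    // save current values, override, run the body, then restore
    val saved = pairs.map { case (k, _) => k -> sqlContext.getConf(k, "") }
    pairs.foreach { case (k, v) => sqlContext.setConf(k, v) }
    try body finally saved.foreach { case (k, v) => sqlContext.setConf(k, v) }
  }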
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index e340f54850..bd9729c431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -190,8 +190,8 @@ object TestData {
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
- ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
- :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
+ ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
+ :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 77ed4a9c0d..f29935224e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -57,7 +57,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
override def deserialize(datum: Any): MyDenseVector = {
datum match {
case data: ArrayData =>
- new MyDenseVector(data.toArray.map(_.asInstanceOf[Double]))
+ new MyDenseVector(data.toDoubleArray())
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 5926ef9aa3..39d798d072 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -51,7 +51,7 @@ import scala.collection.JavaConversions._
* java.sql.Date
* java.sql.Timestamp
* Complex Types =>
- * Map: scala.collection.immutable.Map
+ * Map: [[org.apache.spark.sql.types.MapData]]
* List: [[org.apache.spark.sql.types.ArrayData]]
* Struct: [[org.apache.spark.sql.catalyst.InternalRow]]
* Union: NOT SUPPORTED YET
@@ -290,10 +290,10 @@ private[hive] trait HiveInspectors {
DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
case mi: StandardConstantMapObjectInspector =>
// take the value from the map inspector object, rather than the input data
- mi.getWritableConstantValue.map { case (k, v) =>
- (unwrap(k, mi.getMapKeyObjectInspector),
- unwrap(v, mi.getMapValueObjectInspector))
- }.toMap
+ val map = mi.getWritableConstantValue
+ val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray
+ val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray
+ ArrayBasedMapData(keys, values)
case li: StandardConstantListObjectInspector =>
// take the value from the list inspector object, rather than the input data
val values = li.getWritableConstantValue
@@ -347,12 +347,14 @@ private[hive] trait HiveInspectors {
}
.orNull
case mi: MapObjectInspector =>
- Option(mi.getMap(data)).map(
- _.map {
- case (k, v) =>
- (unwrap(k, mi.getMapKeyObjectInspector),
- unwrap(v, mi.getMapValueObjectInspector))
- }.toMap).orNull
+ val map = mi.getMap(data)
+ if (map == null) {
+ null
+ } else {
+ val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray
+ val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray
+ ArrayBasedMapData(keys, values)
+ }
// currently, hive doesn't provide the ConstantStructObjectInspector
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
@@ -365,7 +367,7 @@ private[hive] trait HiveInspectors {
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
- protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
+ protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match {
case _: JavaHiveVarcharObjectInspector =>
(o: Any) =>
val s = o.asInstanceOf[UTF8String].toString
@@ -381,7 +383,10 @@ private[hive] trait HiveInspectors {
(o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])
case soi: StandardStructObjectInspector =>
- val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
+ val schema = dataType.asInstanceOf[StructType]
+ val wrappers = soi.getAllStructFieldRefs.zip(schema.fields).map { case (ref, field) =>
+ wrapperFor(ref.getFieldObjectInspector, field.dataType)
+ }
(o: Any) => {
if (o != null) {
val struct = soi.create()
@@ -395,27 +400,34 @@ private[hive] trait HiveInspectors {
}
case loi: ListObjectInspector =>
- val wrapper = wrapperFor(loi.getListElementObjectInspector)
+ val elementType = dataType.asInstanceOf[ArrayType].elementType
+ val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType)
(o: Any) => {
if (o != null) {
- seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper))
+ val array = o.asInstanceOf[ArrayData]
+ val values = new java.util.ArrayList[Any](array.numElements())
+ array.foreach(elementType, (_, e) => {
+ values.add(wrapper(e))
+ })
+ values
} else {
null
}
}
case moi: MapObjectInspector =>
- // The Predef.Map is scala.collection.immutable.Map.
- // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
- import scala.collection.Map
+ val mt = dataType.asInstanceOf[MapType]
+ val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType)
+ val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType)
- val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
- val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
(o: Any) => {
if (o != null) {
- mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
- keyWrapper(key) -> valueWrapper(value)
+ val map = o.asInstanceOf[MapData]
+ val jmap = new java.util.HashMap[Any, Any](map.numElements())
+ map.foreach(mt.keyType, mt.valueType, (k, v) => {
+ jmap.put(keyWrapper(k), valueWrapper(v))
})
+ jmap
} else {
null
}
@@ -531,18 +543,21 @@ private[hive] trait HiveInspectors {
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
val tpe = dataType.asInstanceOf[ArrayType].elementType
- a.asInstanceOf[ArrayData].toArray().foreach {
- v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
- }
+ a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => {
+ list.add(wrap(e, x.getListElementObjectInspector, tpe))
+ })
list
case x: MapObjectInspector =>
val keyType = dataType.asInstanceOf[MapType].keyType
val valueType = dataType.asInstanceOf[MapType].valueType
+ val map = a.asInstanceOf[MapData]
+
// Some UDFs seem to assume we pass in a HashMap.
- val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
- hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) =>
- wrap(k, x.getMapKeyObjectInspector, keyType) ->
- wrap(v, x.getMapValueObjectInspector, valueType)
+ val hashMap = new java.util.HashMap[Any, Any](map.numElements())
+
+ map.foreach(keyType, valueType, (k, v) => {
+ hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType),
+ wrap(v, x.getMapValueObjectInspector, valueType))
})
hashMap
@@ -645,8 +660,9 @@ private[hive] trait HiveInspectors {
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
} else {
val list = new java.util.ArrayList[Object]()
- value.asInstanceOf[ArrayData].toArray()
- .foreach(v => list.add(wrap(v, listObjectInspector, dt)))
+ value.asInstanceOf[ArrayData].foreach(dt, (_, e) => {
+ list.add(wrap(e, listObjectInspector, dt))
+ })
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
}
case Literal(value, MapType(keyType, valueType, _)) =>
@@ -655,11 +671,14 @@ private[hive] trait HiveInspectors {
if (value == null) {
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, null)
} else {
- val map = new java.util.HashMap[Object, Object]()
- value.asInstanceOf[Map[_, _]].foreach (entry => {
- map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType))
+ val map = value.asInstanceOf[MapData]
+ val jmap = new java.util.HashMap[Any, Any](map.numElements())
+
+ map.foreach(keyType, valueType, (k, v) => {
+ jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))
})
- ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
+
+ ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap)
}
// We will enumerate all of the possible constant expressions, throw exception if we missed
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index e4944caeff..40a6a32156 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -95,9 +95,9 @@ case class InsertIntoHiveTable(
.asInstanceOf[StructObjectInspector]
val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
- val wrappers = fieldOIs.map(wrapperFor)
- val outputData = new Array[Any](fieldOIs.length)
val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray
+ val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) }
+ val outputData = new Array[Any](fieldOIs.length)
writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4a13022edd..abe5c69003 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -428,10 +428,10 @@ private[hive] case class HiveWindowFunction(
// if pivotResult is false, we will get a single value for all rows in the frame.
outputBuffer
} else {
- // if pivotResult is true, we will get a Seq having the same size with the size
+ // if pivotResult is true, we will get an ArrayData having the same size as the size
// of the window frame. At here, we will return the result at the position of
// index in the output buffer.
- outputBuffer.asInstanceOf[ArrayData].get(index)
+ outputBuffer.asInstanceOf[ArrayData].get(index, dataType)
}
}
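A tiny sketch of the typed getter used here: ArrayData.get now takes the element DataType instead of exposing an untyped get(ordinal) (illustrative only):

  import org.apache.spark.sql.types.{GenericArrayData, LongType}

  val buffer = new GenericArrayData(Array[Any](10L, 20L))
  val second: Any = buffer.get(1, LongType)   // replaces the removed buffer.get(1)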
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 924f4d37ce..6fa5997348 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -95,9 +95,10 @@ private[orc] class OrcOutputWriter(
private val reusableOutputBuffer = new Array[Any](dataSchema.length)
// Used to convert Catalyst values into Hadoop `Writable`s.
- private val wrappers = structOI.getAllStructFieldRefs.map { ref =>
- wrapperFor(ref.getFieldObjectInspector)
- }.toArray
+ private val wrappers = structOI.getAllStructFieldRefs.zip(dataSchema.fields.map(_.dataType))
+ .map { case (ref, dt) =>
+ wrapperFor(ref.getFieldObjectInspector, dt)
+ }.toArray
// `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this
// flag to decide whether `OrcRecordWriter.close()` needs to be called.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index f719f2e06a..99e95fb921 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -147,6 +147,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
case (r1: Array[Byte], r2: Array[Byte])
if r1 != null && r2 != null && r1.length == r2.length =>
r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) }
+ // We don't support equality & ordering for map type, so skip it.
+ case (r1: MapData, r2: MapData) =>
case (r1, r2) => assert(r1 === r2)
}
}
@@ -230,7 +232,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
test("wrap / unwrap Map Type") {
val dt = MapType(dataTypes(0), dataTypes(1))
- val d = Map(row(0) -> row(1))
+ val d = ArrayBasedMapData(Array(row(0)), Array(row(1)))
checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
checkValue(d,