-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala          53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala  32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala          81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala              10
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala                          11
5 files changed, 119 insertions(+), 68 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index f542f5cf40..5b9161551a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -199,34 +199,14 @@ object CatalystTypeConverters {
private[this] val keyConverter = getConverterForType(keyType)
private[this] val valueConverter = getConverterForType(valueType)
- override def toCatalystImpl(scalaValue: Any): MapData = scalaValue match {
- case m: Map[_, _] =>
- val length = m.size
- val convertedKeys = new Array[Any](length)
- val convertedValues = new Array[Any](length)
-
- var i = 0
- for ((key, value) <- m) {
- convertedKeys(i) = keyConverter.toCatalyst(key)
- convertedValues(i) = valueConverter.toCatalyst(value)
- i += 1
- }
- ArrayBasedMapData(convertedKeys, convertedValues)
-
- case jmap: JavaMap[_, _] =>
- val length = jmap.size()
- val convertedKeys = new Array[Any](length)
- val convertedValues = new Array[Any](length)
-
- var i = 0
- val iter = jmap.entrySet.iterator
- while (iter.hasNext) {
- val entry = iter.next()
- convertedKeys(i) = keyConverter.toCatalyst(entry.getKey)
- convertedValues(i) = valueConverter.toCatalyst(entry.getValue)
- i += 1
- }
- ArrayBasedMapData(convertedKeys, convertedValues)
+ override def toCatalystImpl(scalaValue: Any): MapData = {
+ val keyFunction = (k: Any) => keyConverter.toCatalyst(k)
+ val valueFunction = (v: Any) => valueConverter.toCatalyst(v)
+
+ scalaValue match {
+ case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction)
+ case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction)
+ }
}
override def toScala(catalystValue: MapData): Map[Any, Any] = {
@@ -433,18 +413,11 @@ object CatalystTypeConverters {
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
- case m: Map[_, _] =>
- val length = m.size
- val convertedKeys = new Array[Any](length)
- val convertedValues = new Array[Any](length)
-
- var i = 0
- for ((key, value) <- m) {
- convertedKeys(i) = convertToCatalyst(key)
- convertedValues(i) = convertToCatalyst(value)
- i += 1
- }
- ArrayBasedMapData(convertedKeys, convertedValues)
+ case map: Map[_, _] =>
+ ArrayBasedMapData(
+ map,
+ (key: Any) => convertToCatalyst(key),
+ (value: Any) => convertToCatalyst(value))
case other => other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 09e22aaf3e..917aa08731 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -427,18 +427,28 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}
}
- override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = {
- val array = str.asInstanceOf[UTF8String]
- .split(delim1.asInstanceOf[UTF8String], -1)
- .map { kv =>
- val arr = kv.split(delim2.asInstanceOf[UTF8String], 2)
- if (arr.length < 2) {
- Array(arr(0), null)
- } else {
- arr
- }
+ override def nullSafeEval(
+ inputString: Any,
+ stringDelimiter: Any,
+ keyValueDelimiter: Any): Any = {
+ val keyValues =
+ inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1)
+
+ val iterator = new Iterator[(UTF8String, UTF8String)] {
+ var index = 0
+ val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String]
+
+ override def hasNext: Boolean = {
+ keyValues.length > index
}
- ArrayBasedMapData(array.map(_ (0)), array.map(_ (1)))
+
+ override def next(): (UTF8String, UTF8String) = {
+ val keyValueArray = keyValues(index).split(keyValueDelimiterUTF8String, 2)
+ index += 1
+ (keyValueArray(0), if (keyValueArray.length < 2) null else keyValueArray(1))
+ }
+ }
+ ArrayBasedMapData(iterator, keyValues.size, identity, identity)
}
override def prettyName: String = "str_to_map"
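
Reviewer note: behavior is unchanged; a pair with no key/value delimiter keeps
its key and gets a null value (the `keyValueArray.length < 2` branch above). A
hedged usage sketch, assuming an active SparkSession named `spark`:

    // str_to_map splits on the pair delimiter, then on the key/value delimiter.
    spark.sql("SELECT str_to_map('a:1,b:2,c', ',', ':')").collect()
    // expected single row: Map(a -> 1, b -> 2, c -> null)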
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index 4449da13c0..91b3139443 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.util
+import java.util.{Map => JavaMap}
+
class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
require(keyArray.numElements() == valueArray.numElements())
@@ -30,12 +32,83 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
}
object ArrayBasedMapData {
- def apply(map: Map[Any, Any]): ArrayBasedMapData = {
- val array = map.toArray
- ArrayBasedMapData(array.map(_._1), array.map(_._2))
+ /**
+ * Creates an [[ArrayBasedMapData]] by applying the given converters over
+ * each (key -> value) pair of the input [[java.util.Map]].
+ *
+ * @param javaMap Input map
+ * @param keyConverter This function is applied over all the keys of the input map to
+ * obtain the output map's keys
+ * @param valueConverter This function is applied over all the values of the input map to
+ * obtain the output map's values
+ */
+ def apply(
+ javaMap: JavaMap[_, _],
+ keyConverter: (Any) => Any,
+ valueConverter: (Any) => Any): ArrayBasedMapData = {
+ import scala.language.existentials
+
+ val keys: Array[Any] = new Array[Any](javaMap.size())
+ val values: Array[Any] = new Array[Any](javaMap.size())
+
+ var i: Int = 0
+ val iterator = javaMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ keys(i) = keyConverter(entry.getKey)
+ values(i) = valueConverter(entry.getValue)
+ i += 1
+ }
+ ArrayBasedMapData(keys, values)
+ }
+
+ /**
+ * Creates an [[ArrayBasedMapData]] by applying the given converters over
+ * each (key -> value) pair of the input map.
+ *
+ * @param map Input map
+ * @param keyConverter This function is applied over all the keys of the input map to
+ * obtain the output map's keys
+ * @param valueConverter This function is applied over all the values of the input map to
+ * obtain the output map's values
+ */
+ def apply(
+ map: scala.collection.Map[_, _],
+ keyConverter: (Any) => Any = identity,
+ valueConverter: (Any) => Any = identity): ArrayBasedMapData = {
+ ArrayBasedMapData(map.iterator, map.size, keyConverter, valueConverter)
+ }
+
+ /**
+ * Creates an [[ArrayBasedMapData]] by applying the given converters over
+ * each (key -> value) pair from the given iterator.
+ *
+ * @param iterator Input iterator
+ * @param size Number of elements
+ * @param keyConverter This function is applied over all the keys extracted from the
+ * given iterator to obtain the output map's keys
+ * @param valueConverter This function is applied over all the values extracted from the
+ * given iterator to obtain the output map's values
+ */
+ def apply(
+ iterator: Iterator[(_, _)],
+ size: Int,
+ keyConverter: (Any) => Any,
+ valueConverter: (Any) => Any): ArrayBasedMapData = {
+
+ val keys: Array[Any] = new Array[Any](size)
+ val values: Array[Any] = new Array[Any](size)
+
+ var i = 0
+ for ((key, value) <- iterator) {
+ keys(i) = keyConverter(key)
+ values(i) = valueConverter(value)
+ i += 1
+ }
+ ArrayBasedMapData(keys, values)
}
- def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
+ def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = {
new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
}
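
Reviewer note: a short sketch of the three new entry points (values are
illustrative, not from the patch):

    import java.util.{HashMap => JHashMap}
    import org.apache.spark.sql.catalyst.util.ArrayBasedMapData

    // From a Scala map; both converters default to identity.
    val fromScala = ArrayBasedMapData(Map(1 -> "a", 2 -> "b"))

    // From a java.util.Map, with explicit converters.
    val jmap = new JHashMap[Int, String]()
    jmap.put(1, "a")
    val fromJava = ArrayBasedMapData(jmap, identity, (v: Any) => v.toString.toUpperCase)

    // From an iterator plus a known size, as StringToMap does above.
    val fromIterator = ArrayBasedMapData(Map(1 -> "a").iterator, 1, identity, identity)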
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 724025b464..46fd54e5c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -124,11 +124,11 @@ object EvaluatePython {
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
- val keyValues = c.asScala.toSeq
- val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray
- val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray
- ArrayBasedMapData(keys, values)
+ case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+ ArrayBasedMapData(
+ javaMap,
+ (key: Any) => fromJava(key, keyType),
+ (value: Any) => fromJava(value, valueType))
case (c, StructType(fields)) if c.getClass.isArray =>
val array = c.asInstanceOf[Array[_]]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 1625116803..e303065127 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -473,10 +473,8 @@ private[hive] trait HiveInspectors {
case mi: StandardConstantMapObjectInspector =>
val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
- val keyValues = mi.getWritableConstantValue.asScala.toSeq
- val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
- val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
- val constant = ArrayBasedMapData(keys, values)
+ val keyValues = mi.getWritableConstantValue
+ val constant = ArrayBasedMapData(keyValues, keyUnwrapper, valueUnwrapper)
_ => constant
case li: StandardConstantListObjectInspector =>
val unwrapper = unwrapperFor(li.getListElementObjectInspector)
@@ -655,10 +653,7 @@ private[hive] trait HiveInspectors {
if (map == null) {
null
} else {
- val keyValues = map.asScala.toSeq
- val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
- val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
- ArrayBasedMapData(keys, values)
+ ArrayBasedMapData(map, keyUnwrapper, valueUnwrapper)
}
} else {
null