diff options
5 files changed, 119 insertions, 68 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index f542f5cf40..5b9161551a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -199,34 +199,14 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) - override def toCatalystImpl(scalaValue: Any): MapData = scalaValue match { - case m: Map[_, _] => - val length = m.size - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - for ((key, value) <- m) { - convertedKeys(i) = keyConverter.toCatalyst(key) - convertedValues(i) = valueConverter.toCatalyst(value) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) - - case jmap: JavaMap[_, _] => - val length = jmap.size() - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - val iter = jmap.entrySet.iterator - while (iter.hasNext) { - val entry = iter.next() - convertedKeys(i) = keyConverter.toCatalyst(entry.getKey) - convertedValues(i) = valueConverter.toCatalyst(entry.getValue) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) + override def toCatalystImpl(scalaValue: Any): MapData = { + val keyFunction = (k: Any) => keyConverter.toCatalyst(k) + val valueFunction = (k: Any) => valueConverter.toCatalyst(k) + + scalaValue match { + case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) + case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + } } override def toScala(catalystValue: MapData): Map[Any, Any] = { @@ -433,18 +413,11 @@ object CatalystTypeConverters { case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) - case m: Map[_, _] => - val length = m.size - val convertedKeys = new Array[Any](length) - val convertedValues = new Array[Any](length) - - var i = 0 - for ((key, value) <- m) { - convertedKeys(i) = convertToCatalyst(key) - convertedValues(i) = convertToCatalyst(value) - i += 1 - } - ArrayBasedMapData(convertedKeys, convertedValues) + case map: Map[_, _] => + ArrayBasedMapData( + map, + (key: Any) => convertToCatalyst(key), + (value: Any) => convertToCatalyst(value)) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 09e22aaf3e..917aa08731 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -427,18 +427,28 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } } - override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = { - val array = str.asInstanceOf[UTF8String] - .split(delim1.asInstanceOf[UTF8String], -1) - .map { kv => - val arr = kv.split(delim2.asInstanceOf[UTF8String], 2) - if (arr.length < 2) { - Array(arr(0), null) - } else { - arr - } + override def nullSafeEval( + inputString: Any, + stringDelimiter: Any, + keyValueDelimiter: Any): Any = { + val keyValues = + inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1) + + val iterator = new Iterator[(UTF8String, UTF8String)] { + var index = 0 + val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] + + override def hasNext: Boolean = { + keyValues.length > index } - ArrayBasedMapData(array.map(_ (0)), array.map(_ (1))) + + override def next(): (UTF8String, UTF8String) = { + val keyValueArray = keyValues(index).split(keyValueDelimiterUTF8String, 2) + index += 1 + (keyValueArray(0), if (keyValueArray.length < 2) null else keyValueArray(1)) + } + } + ArrayBasedMapData(iterator, keyValues.size, identity, identity) } override def prettyName: String = "str_to_map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 4449da13c0..91b3139443 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.{Map => JavaMap} + class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -30,12 +32,83 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte } object ArrayBasedMapData { - def apply(map: Map[Any, Any]): ArrayBasedMapData = { - val array = map.toArray - ArrayBasedMapData(array.map(_._1), array.map(_._2)) + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair of the input [[java.util.Map]] + * + * @param javaMap Input map + * @param keyConverter This function is applied over all the keys of the input map to + * obtain the output map's keys + * @param valueConverter This function is applied over all the values of the input map to + * obtain the output map's values + */ + def apply( + javaMap: JavaMap[_, _], + keyConverter: (Any) => Any, + valueConverter: (Any) => Any): ArrayBasedMapData = { + import scala.language.existentials + + val keys: Array[Any] = new Array[Any](javaMap.size()) + val values: Array[Any] = new Array[Any](javaMap.size()) + + var i: Int = 0 + val iterator = javaMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + keys(i) = keyConverter(entry.getKey) + values(i) = valueConverter(entry.getValue) + i += 1 + } + ArrayBasedMapData(keys, values) + } + + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair of the input map + * + * @param map Input map + * @param keyConverter This function is applied over all the keys of the input map to + * obtain the output map's keys + * @param valueConverter This function is applied over all the values of the input map to + * obtain the output map's values + */ + def apply( + map: scala.collection.Map[_, _], + keyConverter: (Any) => Any = identity, + valueConverter: (Any) => Any = identity): ArrayBasedMapData = { + ArrayBasedMapData(map.iterator, map.size, keyConverter, valueConverter) + } + + /** + * Creates a [[ArrayBasedMapData]] by applying the given converters over + * each (key -> value) pair from the given iterator + * + * @param iterator Input iterator + * @param size Number of elements + * @param keyConverter This function is applied over all the keys extracted from the + * given iterator to obtain the output map's keys + * @param valueConverter This function is applied over all the values extracted from the + * given iterator to obtain the output map's values + */ + def apply( + iterator: Iterator[(_, _)], + size: Int, + keyConverter: (Any) => Any, + valueConverter: (Any) => Any): ArrayBasedMapData = { + + val keys: Array[Any] = new Array[Any](size) + val values: Array[Any] = new Array[Any](size) + + var i = 0 + for ((key, value) <- iterator) { + keys(i) = keyConverter(key) + values(i) = valueConverter(value) + i += 1 + } + ArrayBasedMapData(keys, values) } - def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { + def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 724025b464..46fd54e5c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -124,11 +124,11 @@ object EvaluatePython { case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keyValues = c.asScala.toSeq - val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray - val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray - ArrayBasedMapData(keys, values) + case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => + ArrayBasedMapData( + javaMap, + (key: Any) => fromJava(key, keyType), + (value: Any) => fromJava(value, valueType)) case (c, StructType(fields)) if c.getClass.isArray => val array = c.asInstanceOf[Array[_]] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 1625116803..e303065127 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -473,10 +473,8 @@ private[hive] trait HiveInspectors { case mi: StandardConstantMapObjectInspector => val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector) val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector) - val keyValues = mi.getWritableConstantValue.asScala.toSeq - val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray - val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray - val constant = ArrayBasedMapData(keys, values) + val keyValues = mi.getWritableConstantValue + val constant = ArrayBasedMapData(keyValues, keyUnwrapper, valueUnwrapper) _ => constant case li: StandardConstantListObjectInspector => val unwrapper = unwrapperFor(li.getListElementObjectInspector) @@ -655,10 +653,7 @@ private[hive] trait HiveInspectors { if (map == null) { null } else { - val keyValues = map.asScala.toSeq - val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray - val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray - ArrayBasedMapData(keys, values) + ArrayBasedMapData(map, keyUnwrapper, valueUnwrapper) } } else { null |