author    Wenchen Fan <cloud0fan@outlook.com>  2015-07-30 10:04:30 -0700
committer Reynold Xin <rxin@databricks.com>    2015-07-30 10:04:30 -0700
commit    c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 (patch)
tree      582bad5631cde3bac3b5c69e1f22b3c4098de684 /sql/catalyst
parent    7492a33fdd074446c30c657d771a69932a00246d (diff)
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7724 from cloud-fan/array-data and squashes the following commits:

d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
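
At its core, the patch swaps the engine-internal representation of array values from plain Scala Seq[Any] to a dedicated ArrayData wrapper, with GenericArrayData as the Object[]-backed implementation. A minimal sketch of the before/after shape (variable names hypothetical, types from this patch):

    // Before: array columns flowed through Catalyst as plain Seq[Any].
    val before: Seq[Any] = Seq(1, 2, null)

    // After: they are wrapped in ArrayData; GenericArrayData is the
    // Object[]-backed implementation added by this patch.
    import org.apache.spark.sql.types.GenericArrayData
    val after = new GenericArrayData(Array[Any](1, 2, null))
    after.numElements()  // 3
    after.isNullAt(2)    // true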
Diffstat (limited to 'sql/catalyst')
 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java                |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala                      |  29
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala                                 |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala                  |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala                            |  39
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala           |  28
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala |   4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala            |  10
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala              |  20
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala           |  59
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala                      |   4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala                |  12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                         |   3
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala                                      | 121
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala                               |  59
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala                       |  21
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala                |   2
 17 files changed, 320 insertions(+), 97 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index bc345dcd00..f7cea13688 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.ArrayData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -50,4 +51,5 @@ public interface SpecializedGetters {
InternalRow getStruct(int ordinal, int numFields);
+ ArrayData getArray(int ordinal);
}
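
With getArray on SpecializedGetters, rows (and, below, arrays themselves) expose array columns through a typed accessor instead of the generic get. A hedged sketch of a caller, with a hypothetical helper name:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.ArrayData

    // Hypothetical helper: count the elements of an array column in a row.
    def arraySize(row: InternalRow, ordinal: Int): Int = {
      val arr: ArrayData = row.getArray(ordinal)  // new specialized getter
      arr.numElements()
    }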
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d1d89a1f48..22452c0f20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -55,7 +55,6 @@ object CatalystTypeConverters {
private def isWholePrimitive(dt: DataType): Boolean = dt match {
case dt if isPrimitive(dt) => true
- case ArrayType(elementType, _) => isWholePrimitive(elementType)
case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
case _ => false
}
@@ -154,39 +153,41 @@ object CatalystTypeConverters {
/** Converter for arrays, sequences, and Java iterables. */
private case class ArrayConverter(
- elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
+ elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
private[this] val elementConverter = getConverterForType(elementType)
private[this] val isNoChange = isWholePrimitive(elementType)
- override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
+ override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
- case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
- case s: Seq[_] => s.map(elementConverter.toCatalyst)
+ case a: Array[_] =>
+ new GenericArrayData(a.map(elementConverter.toCatalyst))
+ case s: Seq[_] =>
+ new GenericArrayData(s.map(elementConverter.toCatalyst).toArray)
case i: JavaIterable[_] =>
val iter = i.iterator
- var convertedIterable: List[Any] = List()
+ val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any]
while (iter.hasNext) {
val item = iter.next()
- convertedIterable :+= elementConverter.toCatalyst(item)
+ convertedIterable += elementConverter.toCatalyst(item)
}
- convertedIterable
+ new GenericArrayData(convertedIterable.toArray)
}
}
- override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
+ override def toScala(catalystValue: ArrayData): Seq[Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
- catalystValue
+ catalystValue.toArray()
} else {
- catalystValue.map(elementConverter.toScala)
+ catalystValue.toArray().map(elementConverter.toScala)
}
}
override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
- toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
+ toScala(row.getArray(column))
}
private case class MapConverter(
@@ -402,9 +403,9 @@ object CatalystTypeConverters {
case t: Timestamp => TimestampConverter.toCatalyst(t)
case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
- case seq: Seq[Any] => seq.map(convertToCatalyst)
+ case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
- case arr: Array[Any] => arr.map(convertToCatalyst)
+ case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
case m: Map[_, _] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
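
After this hunk, converting a Scala collection to its Catalyst form yields ArrayData rather than a mapped Seq. A small sketch, assuming the convertToCatalyst entry point shown above:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.ArrayData

    // Seqs and Arrays now both land in GenericArrayData.
    val catalyst = CatalystTypeConverters.convertToCatalyst(Seq(1, 2, 3))
    assert(catalyst.isInstanceOf[ArrayData])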
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index a5999e64ec..486ba03654 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
getAs[InternalRow](ordinal, null)
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)
+
override def toString: String = s"[${this.mkString(",")}]"
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 371681b5d4..45709c1c8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -65,7 +65,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
- val value = ctx.getColumn("i", dataType, ordinal)
+ val value = ctx.getValue("i", dataType, ordinal.toString)
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
$javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8c01c13c9c..43be11c48a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
val elementCast = cast(from.elementType, to.elementType)
- buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
+ // TODO: Could be faster?
+ buildCast[ArrayData](_, array => {
+ val length = array.numElements()
+ val values = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = elementCast(array.get(i))
+ }
+ i += 1
+ }
+ new GenericArrayData(values)
+ })
}
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
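
The interpreted array cast above now walks the ArrayData with an explicit while loop, preserving per-element nulls. The same pattern, restated as a standalone helper (elementCast stands in for the per-element cast function):

    import org.apache.spark.sql.types.{ArrayData, GenericArrayData}

    // Null-preserving element-wise map over ArrayData, as in castArray above.
    def mapArray(array: ArrayData, elementCast: Any => Any): ArrayData = {
      val length = array.numElements()
      val values = new Array[Any](length)
      var i = 0
      while (i < length) {
        values(i) = if (array.isNullAt(i)) null else elementCast(array.get(i))
        i += 1
      }
      new GenericArrayData(values)
    }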
@@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArrayCode(
from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
-
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
val fromElementNull = ctx.freshName("feNull")
val fromElementPrim = ctx.freshName("fePrim")
val toElementNull = ctx.freshName("teNull")
val toElementPrim = ctx.freshName("tePrim")
val size = ctx.freshName("n")
val j = ctx.freshName("j")
- val result = ctx.freshName("result")
+ val values = ctx.freshName("values")
(c, evPrim, evNull) =>
s"""
- final int $size = $c.size();
- final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
+ final int $size = $c.numElements();
+ final Object[] $values = new Object[$size];
for (int $j = 0; $j < $size; $j ++) {
- if ($c.apply($j) == null) {
- $result.update($j, null);
+ if ($c.isNullAt($j)) {
+ $values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(from.elementType)} $fromElementPrim =
- (${ctx.boxedType(from.elementType)}) $c.apply($j);
+ ${ctx.getValue(c, from.elementType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
if ($toElementNull) {
- $result.update($j, null);
+ $values[$j] = null;
} else {
- $result.update($j, $toElementPrim);
+ $values[$j] = $toElementPrim;
}
}
}
- $evPrim = $result;
+ $evPrim = new $arrayClass($values);
"""
}
@@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType)
$result.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+ ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 092f4c9fb0..c39e0df6fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -100,17 +100,18 @@ class CodeGenContext {
}
/**
- * Returns the code to access a column in Row for a given DataType.
+ * Returns the code to access a value in `SpecializedGetters` for a given DataType.
*/
- def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
+ def getValue(getter: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
- case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
- case StringType => s"$row.getUTF8String($ordinal)"
- case BinaryType => s"$row.getBinary($ordinal)"
- case CalendarIntervalType => s"$row.getInterval($ordinal)"
- case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
- case _ => s"($jt)$row.get($ordinal)"
+ case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+ case StringType => s"$getter.getUTF8String($ordinal)"
+ case BinaryType => s"$getter.getBinary($ordinal)"
+ case CalendarIntervalType => s"$getter.getInterval($ordinal)"
+ case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
+ case a: ArrayType => s"$getter.getArray($ordinal)"
+ case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
}
}
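
getValue (renamed from getColumn) now emits an accessor against any SpecializedGetters, keyed by data type, and takes the ordinal as a string so generated loop variables can be used as indexes. Roughly, per the cases above (a sketch of the emitted strings, not verified output):

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
    import org.apache.spark.sql.types._

    val ctx = new CodeGenContext
    ctx.getValue("row", IntegerType, "0")             // "row.getInt(0)"
    ctx.getValue("row", StringType, "1")              // "row.getUTF8String(1)"
    ctx.getValue("arr", ArrayType(IntegerType), "j")  // "arr.getArray(j)"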
@@ -152,8 +153,8 @@ class CodeGenContext {
case StringType => "UTF8String"
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
- case _: ArrayType => s"scala.collection.Seq"
- case _: MapType => s"scala.collection.Map"
+ case _: ArrayType => "ArrayData"
+ case _: MapType => "scala.collection.Map"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
@@ -214,7 +215,9 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
- case other => s"$c1.compare($c2)"
+ case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+ case _ => throw new IllegalArgumentException(
+ "cannot generate compare code for un-comparable type")
}
/**
@@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
classOf[Decimal].getName,
- classOf[CalendarInterval].getName
+ classOf[CalendarInterval].getName,
+ classOf[ArrayData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7be60114ce..a662357fb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val nestedStructEv = GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
createCodeForStruct(ctx, nestedStructEv, st)
case _ =>
GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
- primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
+ primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2d92dcf23a..1a00dbc254 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
override def nullSafeEval(value: Any): Int = child.dataType match {
- case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
- case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+ case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+ case _: MapType => value.asInstanceOf[Map[Any, Any]].size
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+ val sizeCall = child.dataType match {
+ case _: ArrayType => "numElements()"
+ case _: MapType => "size()"
+ }
+ nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}
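
Size now dispatches on the child's type: numElements() for arrays, size() for maps. The array branch reduces to counting ArrayData slots:

    import org.apache.spark.sql.types.GenericArrayData

    // What Size.nullSafeEval does for an ArrayType child:
    val arr = new GenericArrayData(Array[Any]("a", "b", "c"))
    assert(arr.numElements() == 3)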
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0517050a45..a145dfb4bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,12 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.unsafe.types.UTF8String
-
-import scala.collection.mutable
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -46,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
- children.map(_.eval(input))
+ new GenericArrayData(children.map(_.eval(input)).toArray)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+ val arrayClass = classOf[GenericArrayData].getName
s"""
- boolean ${ev.isNull} = false;
- $arraySeqClass<Object> ${ev.primitive} = new $arraySeqClass<Object>(${children.size});
+ final boolean ${ev.isNull} = false;
+ final Object[] values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
- ${ev.primitive}.update($i, null);
+ values[$i] = null;
} else {
- ${ev.primitive}.update($i, ${eval.primitive});
+ values[$i] = ${eval.primitive};
}
"""
- }.mkString("\n")
+ }.mkString("\n") +
+ s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
}
override def prettyName: String = "array"
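
CreateArray.eval likewise returns the new wrapper. A hedged usage sketch (assuming, as elsewhere in Catalyst of this era, that foldable expressions can be evaluated with a null input row):

    import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal}
    import org.apache.spark.sql.types.GenericArrayData

    val result = CreateArray(Seq(Literal(1), Literal(2))).eval(null)
    assert(result.isInstanceOf[GenericArrayData])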
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6331a9eb60..99393c9c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -57,7 +57,8 @@ object ExtractValue {
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
- GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
+ ordinal, fields.length, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
@@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)};
+ ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)};
}
"""
})
@@ -134,6 +135,7 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
+ numFields: Int,
containsNull: Boolean) extends UnaryExpression {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
@@ -141,26 +143,45 @@ case class GetArrayStructFields(
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any = {
- input.asInstanceOf[Seq[InternalRow]].map { row =>
- if (row == null) null else row.get(ordinal, field.dataType)
+ val array = input.asInstanceOf[ArrayData]
+ val length = array.numElements()
+ val result = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (array.isNullAt(i)) {
+ result(i) = null
+ } else {
+ val row = array.getStruct(i, numFields)
+ if (row.isNullAt(ordinal)) {
+ result(i) = null
+ } else {
+ result(i) = row.get(ordinal, field.dataType)
+ }
+ }
+ i += 1
}
+ new GenericArrayData(result)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arraySeqClass = "scala.collection.mutable.ArraySeq"
- // TODO: consider using Array[_] for ArrayType child to avoid
- // boxing of primitives
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
s"""
- final int n = $eval.size();
- final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n);
+ final int n = $eval.numElements();
+ final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
- InternalRow row = (InternalRow) $eval.apply(j);
- if (row != null && !row.isNullAt($ordinal)) {
- values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)});
+ if ($eval.isNullAt(j)) {
+ values[j] = null;
+ } else {
+ final InternalRow row = $eval.getStruct(j, $numFields);
+ if (row.isNullAt($ordinal)) {
+ values[j] = null;
+ } else {
+ values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+ }
}
}
- ${ev.primitive} = (${ctx.javaType(dataType)}) values;
+ ${ev.primitive} = new $arrayClass(values);
"""
})
}
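
The new numFields parameter lets both eval and the generated code read each element with array.getStruct(i, numFields) instead of a raw cast. For a hypothetical schema array<struct<b:int, c:int>>, resolving the field access `a.b` per the ExtractValue hunk above would build roughly:

    import org.apache.spark.sql.types._

    val fields = Seq(StructField("b", IntegerType), StructField("c", IntegerType))
    // GetArrayStructFields(child, fields(0), ordinal = 0,
    //   numFields = fields.length, containsNull = true)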
@@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
- val baseValue = value.asInstanceOf[Seq[_]]
+ val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
- if (index >= baseValue.size || index < 0) {
+ if (index >= baseValue.numElements() || index < 0) {
null
} else {
- baseValue(index)
+ baseValue.get(index)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
- final int index = (int)$eval2;
- if (index >= $eval1.size() || index < 0) {
+ final int index = (int) $eval2;
+ if (index >= $eval1.numElements() || index < 0) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index);
+ ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")};
}
"""
})
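
GetArrayItem keeps its out-of-range-yields-null contract; only the element access changes, from Seq#apply to the specialized getter. A hedged sketch (again assuming eval(null) works for foldable input):

    import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, Literal}
    import org.apache.spark.sql.types._

    val arr = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
    GetArrayItem(arr, Literal(1)).eval(null)  // 2
    GetArrayItem(arr, Literal(5)).eval(null)  // null: index out of range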
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 2dbcf2830f..8064235c64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
child.dataType match {
case ArrayType(_, _) =>
- val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
- if (inputArray == null) Nil else inputArray.map(v => InternalRow(v))
+ val inputArray = child.eval(input).asInstanceOf[ArrayData]
+ if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v))
case MapType(_, _, _) =>
val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]]
if (inputMap == null) Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5b3a64a096..79c0ca56a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression])
val flatInputs = children.flatMap { child =>
child.eval(input) match {
case s: UTF8String => Iterator(s)
- case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+ case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String])
case null => Iterator(null.asInstanceOf[UTF8String])
}
}
@@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression])
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+ s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
@@ -665,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def nullSafeEval(string: Any, regex: Any): Any = {
- string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq
+ val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
+ new GenericArrayData(strings.asInstanceOf[Array[Any]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, pattern) =>
- s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer(
- java.util.Arrays.asList($str.split($pattern, -1)));""")
+ // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+ s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""")
}
override def prettyName: String = "split"
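
Because Java arrays are covariant, the UTF8String[] produced by split can back a GenericArrayData directly, with no per-element copy. A hedged usage sketch:

    import org.apache.spark.sql.catalyst.expressions.{Literal, StringSplit}
    import org.apache.spark.sql.types.ArrayData

    val parts = StringSplit(Literal("a,b,c"), Literal(",")).eval(null)
    assert(parts.asInstanceOf[ArrayData].numElements() == 3)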
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 813c620096..29d706dcb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
- case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+ case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
+ Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
new file mode 100644
index 0000000000..14a7285877
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+
+abstract class ArrayData extends SpecializedGetters with Serializable {
+ // todo: remove this after we handle all types.(map type need special getter)
+ def get(ordinal: Int): Any
+
+ def numElements(): Int
+
+ // todo: need a more efficient way to iterate array type.
+ def toArray(): Array[Any] = {
+ val n = numElements()
+ val values = new Array[Any](n)
+ var i = 0
+ while (i < n) {
+ if (isNullAt(i)) {
+ values(i) = null
+ } else {
+ values(i) = get(i)
+ }
+ i += 1
+ }
+ values
+ }
+
+ override def toString(): String = toArray.mkString("[", ",", "]")
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[ArrayData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[ArrayData]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numElements()
+ if (len != other.numElements()) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = get(i)
+ val o2 = other.get(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numElements()
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ get(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+}
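
equals and hashCode are element-wise, with byte arrays compared by content and NaN treated as equal to itself, so two independently built arrays compare equal:

    import org.apache.spark.sql.types.GenericArrayData

    val a = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN))
    val b = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN))
    assert(a == b && a.hashCode == b.hashCode)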
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
new file mode 100644
index 0000000000..7992ba947c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+
+class GenericArrayData(array: Array[Any]) extends ArrayData {
+ private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+
+ override def toArray(): Array[Any] = array
+
+ override def get(ordinal: Int): Any = array(ordinal)
+
+ override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+
+ override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+
+ override def getByte(ordinal: Int): Byte = getAs(ordinal)
+
+ override def getShort(ordinal: Int): Short = getAs(ordinal)
+
+ override def getInt(ordinal: Int): Int = getAs(ordinal)
+
+ override def getLong(ordinal: Int): Long = getAs(ordinal)
+
+ override def getFloat(ordinal: Int): Float = getAs(ordinal)
+
+ override def getDouble(ordinal: Int): Double = getAs(ordinal)
+
+ override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
+
+ override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+
+ override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+
+ override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
+ override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+
+ override def numElements(): Int = array.length
+}
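
The generic implementation simply wraps an Object[]: the typed getters are unchecked casts, isNullAt is a null check on the slot, and toArray hands back the backing array itself:

    import org.apache.spark.sql.types.GenericArrayData

    val data = new GenericArrayData(Array[Any](42, null))
    assert(data.getInt(0) == 42)              // unboxes the stored java.lang.Integer
    assert(data.isNullAt(1))
    assert(data.toArray() eq data.toArray())  // same backing array, no copy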
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a517da9872..4f35b653d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date}
import java.util.{TimeZone, Calendar}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -730,13 +731,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
- InternalRow(
- Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")),
- Map(
- UTF8String.fromString("a") -> UTF8String.fromString("123"),
- UTF8String.fromString("b") -> UTF8String.fromString("abc"),
- UTF8String.fromString("c") -> UTF8String.fromString("")),
- InternalRow(0)),
+ Row(
+ Seq("123", "abc", ""),
+ Map("a" ->"123", "b" -> "abc", "c" -> ""),
+ Row(0)),
StructType(Seq(
StructField("a",
ArrayType(StringType, containsNull = false), nullable = true),
@@ -756,13 +754,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("l", LongType, nullable = true)))))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(
+ checkEvaluation(ret, Row(
Seq(123, null, null),
- Map(
- UTF8String.fromString("a") -> true,
- UTF8String.fromString("b") -> true,
- UTF8String.fromString("c") -> false),
- InternalRow(0L)))
+ Map("a" -> true, "b" -> true, "c" -> false),
+ Row(0L)))
}
test("case between string and interval") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5de5ddce97..3fa246b69d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
expr.dataType match {
case ArrayType(StructType(fields), containsNull) =>
val field = fields.find(_.name == fieldName).get
- GetArrayStructFields(expr, field, fields.indexOf(field), containsNull)
+ GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
}
}