author    Wenchen Fan <cloud0fan@outlook.com>    2015-07-27 11:28:22 -0700
committer Reynold Xin <rxin@databricks.com>      2015-07-27 11:28:22 -0700
commit    75438422c2cd90dca53f84879cddecfc2ee0e957 (patch)
tree      f602ad0cd494bb3e662aad343654c214fba52d8b /sql
parent    dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6 (diff)
[SPARK-9369][SQL] Support IntervalType in UnsafeRow
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7688 from cloud-fan/interval and squashes the following commits:

5b36b17 [Wenchen Fan] fix codegen
a99ed50 [Wenchen Fan] address comment
9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                          | 23
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java                   | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala                                  |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala                   |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala                             |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala            |  7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala |  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala             |  2
8 files changed, 50 insertions(+), 14 deletions(-)
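The core of the change is the on-row layout for intervals: the writer copies the interval's months and microseconds into the row's variable-length region (16 bytes, two 8-byte words), and the field's fixed-length slot stores that region's offset in its upper 32 bits. A minimal, self-contained sketch of that layout, using a plain ByteBuffer in place of Spark's PlatformDependent.UNSAFE accessors (class and method names here are illustrative, not Spark's):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class IntervalLayoutSketch {

  // Mirrors IntervalWriter.write: put months and microseconds into the
  // variable-length region at `cursor`, return the fixed-slot value.
  static long writeInterval(ByteBuffer row, int cursor, int months, long micros) {
    row.putLong(cursor, months);       // months widened to a full 8-byte word
    row.putLong(cursor + 8, micros);   // microseconds in the next word
    return ((long) cursor) << 32;      // fixed slot: offset in the upper 32 bits
  }

  // Mirrors UnsafeRow.getInterval: recover the offset, then read both words.
  static long[] readInterval(ByteBuffer row, long fixedSlot) {
    int offset = (int) (fixedSlot >> 32);
    int months = (int) row.getLong(offset);
    long micros = row.getLong(offset + 8);
    return new long[] {months, micros};
  }

  public static void main(String[] args) {
    ByteBuffer row = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder());
    long slot = writeInterval(row, 16, 3, 123456L);
    long[] parts = readInterval(row, slot);
    System.out.println(parts[0] + " months, " + parts[1] + " microseconds");
  }
}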
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 0fb33dd5a1..fb084dd13b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.types.Interval;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
@@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow {
final Set<DataType> _readableFieldTypes = new HashSet<>(
Arrays.asList(new DataType[]{
StringType,
- BinaryType
+ BinaryType,
+ IntervalType
}));
_readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -333,11 +335,6 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public String getString(int ordinal) {
- return getUTF8String(ordinal).toString();
- }
-
- @Override
public byte[] getBinary(int ordinal) {
if (isNullAt(ordinal)) {
return null;
@@ -359,6 +356,20 @@ public final class UnsafeRow extends MutableRow {
}
@Override
+ public Interval getInterval(int ordinal) {
+ if (isNullAt(ordinal)) {
+ return null;
+ } else {
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+ final long microseconds =
+ PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
+ return new Interval(months, microseconds);
+ }
+ }
+
+ @Override
public UnsafeRow getStruct(int ordinal, int numFields) {
if (isNullAt(ordinal)) {
return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 87521d1f23..0ba31d3b9b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.Interval;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -54,7 +55,7 @@ public class UnsafeRowWriters {
}
}
- /** Writer for bianry (byte array) type. */
+ /** Writer for binary (byte array) type. */
public static class BinaryWriter {
public static int getSize(byte[] input) {
@@ -80,4 +81,20 @@ public class UnsafeRowWriters {
}
}
+ /** Writer for interval type. */
+ public static class IntervalWriter {
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) {
+ final long offset = target.getBaseOffset() + cursor;
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months);
+ PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, ((long) cursor) << 32);
+ return 16;
+ }
+ }
+
}
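Unlike strings and binary, whose fixed slots pack both an offset and a size into the single offsetAndSize long, the writer above stores only ((long) cursor) << 32: the low 32 bits can stay zero because an interval payload is always exactly 16 bytes. A tiny standalone demonstration of that packing (hypothetical code, not part of the patch):

public class FixedSlotSketch {
  public static void main(String[] args) {
    int cursor = 48;                     // start of the 16-byte payload
    long slot = ((long) cursor) << 32;   // what IntervalWriter stores via setLong
    int offset = (int) (slot >> 32);     // what getInterval recovers
    System.out.println("offset = " + offset + ", low 32 bits = " + (int) slot);
  }
}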
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index ad3977281d..9a11de3840 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{Interval, UTF8String}
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable {
def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
+ def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType)
+
// This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6b5c450e3f..41a877f214 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
case DoubleType => input.getDouble(ordinal)
case StringType => input.getUTF8String(ordinal)
case BinaryType => input.getBinary(ordinal)
+ case IntervalType => input.getInterval(ordinal)
case t: StructType => input.getStruct(ordinal, t.size)
case dataType => input.get(ordinal, dataType)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index e208262da9..bd8b0177eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
- s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());"
+ s"$evPrim = Interval.fromString($c.toString());"
}
private[this] def decimalToTimestampCode(d: String): String =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2a1e288cb8..2f02c90b1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -79,7 +79,6 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initCode))
}
- final val intervalType: String = classOf[Interval].getName
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -109,6 +108,7 @@ class CodeGenContext {
case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
case StringType => s"$row.getUTF8String($ordinal)"
case BinaryType => s"$row.getBinary($ordinal)"
+ case IntervalType => s"$row.getInterval($ordinal)"
case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
case _ => s"($jt)$row.get($ordinal)"
}
@@ -150,7 +150,7 @@ class CodeGenContext {
case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
- case IntervalType => intervalType
+ case IntervalType => "Interval"
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
@@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[InternalRow].getName,
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
- classOf[Decimal].getName
+ classOf[Decimal].getName,
+ classOf[Interval].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
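Registering classOf[Interval].getName with the evaluator's default imports is what lets generated code, such as the castToIntervalCode change above, refer to the class by its simple name. A hedged, standalone illustration of the same Janino mechanism, with java.util.ArrayList standing in for Interval (assumes the Janino ClassBodyEvaluator API Spark uses here):

import org.codehaus.janino.ClassBodyEvaluator;

public class DefaultImportsSketch {
  public static void main(String[] args) throws Exception {
    ClassBodyEvaluator evaluator = new ClassBodyEvaluator();
    evaluator.setDefaultImports(new String[] {"java.util.ArrayList"});
    // The generated body can now say `ArrayList`, not `java.util.ArrayList`.
    evaluator.cook("public static Object make() { return new ArrayList(); }");
    Object list = evaluator.getClazz().getMethod("make").invoke(null);
    System.out.println(list.getClass().getName()); // java.util.ArrayList
  }
}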
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index afd0d9cfa1..9d2161947b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
+ private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+ case _: IntervalType => true
case NullType => true
case _ => false
}
@@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
case BinaryType =>
s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
+ case IntervalType =>
+ s" + (${exprs(i).isNull} ? 0 : 16)"
case _ => ""
}
}.mkString("")
@@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case BinaryType =>
s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+ case IntervalType =>
+ s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
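The two hunks above follow the generated projection's two-phase pattern: first size the variable-length region (16 extra bytes per non-null interval), then write each field while advancing a cursor by the bytes consumed. A simplified, self-contained simulation of that accounting (hypothetical names, not the generated code itself):

public class ProjectionCursorSketch {
  static final int INTERVAL_SIZE = 16; // months word + microseconds word

  public static void main(String[] args) {
    boolean[] isNull = {false, true, false};       // three interval columns
    int fixedSize = 8 /* null bitset */ + 8 * isNull.length;

    int additionalSize = 0;                        // phase 1: sizing
    for (boolean n : isNull) {
      additionalSize += n ? 0 : INTERVAL_SIZE;     // mirrors `+ (isNull ? 0 : 16)`
    }

    int cursor = fixedSize;                        // phase 2: writing
    for (boolean n : isNull) {
      if (!n) {
        cursor += INTERVAL_SIZE;                   // IntervalWriter.write returns 16
      }
    }
    System.out.println("row size = " + (fixedSize + additionalSize)
        + ", final cursor = " + cursor);           // both 64 here
  }
}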
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 8b0f90cf3a..ab0cdc857c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -78,8 +78,6 @@ trait ExpressionEvalHelper {
generator
} catch {
case e: Throwable =>
- val ctx = new CodeGenContext
- val evaluated = expression.gen(ctx)
fail(
s"""
|Code generation of $expression failed: