aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala30
4 files changed, 39 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index aec88c9241..c4b7f8490a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -103,7 +103,9 @@ public final class UnsafeRow extends BaseMutableRow {
IntegerType,
LongType,
FloatType,
- DoubleType
+ DoubleType,
+ DateType,
+ TimestampType
})));
// We support get() on a superset of the types for which we support set():
@@ -331,8 +333,6 @@ public final class UnsafeRow extends BaseMutableRow {
return getUTF8String(i).toString();
}
-
-
@Override
public InternalRow copy() {
throw new UnsupportedOperationException();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 5c92f41c63..72f740ecae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -120,6 +122,8 @@ private object UnsafeColumnWriter {
case FloatType => FloatUnsafeColumnWriter
case DoubleType => DoubleUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
+ case DateType => IntUnsafeColumnWriter
+ case TimestampType => LongUnsafeColumnWriter
case t =>
throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 534dac1f92..1098962ddc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -197,9 +197,10 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String) {
+ override def setString(ordinal: Int, value: String): Unit = {
values(ordinal) = UTF8String.fromString(value)
}
+
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 577c7a0de0..721ef8a226 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.{Date, Timestamp}
import java.util.Arrays
import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -74,6 +76,34 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.getString(2) should be ("World")
}
+ test("basic conversion with primitive, string, date and timestamp types") {
+ val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.setString(1, "Hello")
+ row.update(2, DateUtils.fromJavaDate(Date.valueOf("1970-01-01")))
+ row.update(3, DateUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
+
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ sizeRequired should be (8 + (8 * 4) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.getLong(0) should be (0)
+ unsafeRow.getString(1) should be ("Hello")
+ // Date is represented as Int in unsafeRow
+ DateUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
+ // Timestamp is represented as Long in unsafeRow
+ DateUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
+ (Timestamp.valueOf("2015-05-08 08:10:25"))
+ }
+
test("null handling") {
val fieldTypes: Array[DataType] = Array(
NullType,