author    Reynold Xin <rxin@databricks.com>  2015-10-08 17:25:14 -0700
committer Reynold Xin <rxin@databricks.com>  2015-10-08 17:25:14 -0700
commit    84ea287178247c163226e835490c9c70b17d8d3b (patch)
tree      1e01cf3ce4db65842b0685d55e954e089a8ddf68 /sql
parent    02149ff08eed3745086589a047adbce9a580389f (diff)
[SPARK-10914] UnsafeRow serialization breaks when two machines have different Oops size.
UnsafeRow carries three pieces of information when pointing to data in memory: a base object, a base offset, and a length. When the row is serialized with Java or Kryo serialization, the object layout in memory can change if the two machines have different pointer widths (Oops size in the JVM).

To reproduce, launch Spark using

    MASTER=local-cluster[2,1,1024] bin/spark-shell --conf "spark.executor.extraJavaOptions=-XX:-UseCompressedOops"

and then run:

    scala> sql("select 1 xx").collect()

Author: Reynold Xin <rxin@databricks.com>

Closes #9030 from rxin/SPARK-10914.
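To make the failure mode concrete: the fix stops serializing the (base object, base offset) pair and instead writes a compact, layout-independent record of the form [length: Int][numFields: Int][row bytes]. Below is a minimal standalone sketch of that wire format; it uses plain java.io streams rather than Spark's classes, and the names sizeInBytes and numFields are borrowed from the diff that follows, so treat it as an illustration, not the commit's code.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object WireFormatSketch {
      // Sender side: only the row's own bytes cross the wire, never the
      // sender's object layout, base offset, or pointer width.
      def write(rowBytes: Array[Byte], numFields: Int): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        out.writeInt(rowBytes.length)  // sizeInBytes
        out.writeInt(numFields)        // lets the receiver recompute the null-bitset width
        out.write(rowBytes)
        out.flush()
        bos.toByteArray
      }

      // Receiver side: allocate a fresh byte[] and rebuild the row against
      // this JVM's own byte-array base offset.
      def read(wire: Array[Byte]): (Array[Byte], Int) = {
        val in = new DataInputStream(new ByteArrayInputStream(wire))
        val sizeInBytes = in.readInt()
        val numFields = in.readInt()
        val bytes = new Array[Byte](sizeInBytes)
        in.readFully(bytes)
        (bytes, numFields)
      }
    }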
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  | 47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala                    | 29
2 files changed, 72 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e8ac2999c2..5af7ed5d6e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions;
-import java.io.IOException;
-import java.io.OutputStream;
+import java.io.*;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Arrays;
@@ -26,6 +25,11 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -35,6 +39,7 @@ import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -52,7 +57,7 @@ import static org.apache.spark.sql.types.DataTypes.*;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow extends MutableRow {
+public final class UnsafeRow extends MutableRow implements Externalizable, KryoSerializable {
//////////////////////////////////////////////////////////////////////////////
// Static methods
@@ -596,4 +601,40 @@ public final class UnsafeRow extends MutableRow {
public void writeToMemory(Object target, long targetOffset) {
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ byte[] bytes = getBytes();
+ out.writeInt(bytes.length);
+ out.writeInt(this.numFields);
+ out.write(bytes);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this.baseOffset = BYTE_ARRAY_OFFSET;
+ this.sizeInBytes = in.readInt();
+ this.numFields = in.readInt();
+ this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+ this.baseObject = new byte[sizeInBytes];
+ in.readFully((byte[]) baseObject);
+ }
+
+ @Override
+ public void write(Kryo kryo, Output out) {
+ byte[] bytes = getBytes();
+ out.writeInt(bytes.length);
+ out.writeInt(this.numFields);
+ out.write(bytes);
+ }
+
+ @Override
+ public void read(Kryo kryo, Input in) {
+ this.baseOffset = BYTE_ARRAY_OFFSET;
+ this.sizeInBytes = in.readInt();
+ this.numFields = in.readInt();
+ this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+ this.baseObject = new byte[sizeInBytes];
+ in.read((byte[]) baseObject);
+ }
}
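For context, a round trip through plain java.io object streams is what exercises the Externalizable hooks added above. This is a hedged sketch, assuming spark-catalyst of this vintage is on the classpath; UnsafeRow, pointTo, and setLong come from the diff and the test below, but the sketch itself is not part of the commit.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    object ExternalizableRoundTrip {
      def main(args: Array[String]): Unit = {
        val backing = new Array[Byte](1024)  // large buffer; the row uses only 16 bytes of it
        val row = new UnsafeRow
        row.pointTo(backing, 1, 16)          // 1 field, 16 bytes: 8-byte null bitset + one long
        row.setLong(0, 19285)

        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(row)                 // invokes writeExternal above
        oos.flush()

        val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
        val copy = ois.readObject().asInstanceOf[UnsafeRow]  // invokes readExternal
        assert(copy.getLong(0) == 19285)
        // Only the 16 relevant bytes crossed the wire, not the 1 KB backing array.
      }
    }

Writing numFields alongside the bytes is what lets readExternal recompute bitSetWidthInBytes locally instead of depending on anything from the sender's layout.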
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 944d4e1134..7d1ee39d4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import java.io.ByteArrayOutputStream
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{KryoSerializer, JavaSerializer}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.types._
@@ -29,6 +30,32 @@ import org.apache.spark.unsafe.types.UTF8String
class UnsafeRowSuite extends SparkFunSuite {
+ test("UnsafeRow Java serialization") {
+ // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
+ val data = new Array[Byte](1024)
+ val row = new UnsafeRow
+ row.pointTo(data, 1, 16)
+ row.setLong(0, 19285)
+
+ val ser = new JavaSerializer(new SparkConf).newInstance()
+ val row1 = ser.deserialize[UnsafeRow](ser.serialize(row))
+ assert(row1.getLong(0) == 19285)
+ assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16)
+ }
+
+ test("UnsafeRow Kryo serialization") {
+ // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
+ val data = new Array[Byte](1024)
+ val row = new UnsafeRow
+ row.pointTo(data, 1, 16)
+ row.setLong(0, 19285)
+
+ val ser = new KryoSerializer(new SparkConf).newInstance()
+ val row1 = ser.deserialize[UnsafeRow](ser.serialize(row))
+ assert(row1.getLong(0) == 19285)
+ assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16)
+ }
+
test("bitset width calculation") {
assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0)
assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8)
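The width calculation those asserts exercise reserves one null bit per field and rounds up to whole 8-byte words. A hedged sketch of that formula, consistent with the asserted values but an assumption rather than code shown in this diff:

    // One bit per field, rounded up to a multiple of 64 bits (8 bytes).
    def calculateBitSetWidthInBytes(numFields: Int): Int =
      ((numFields + 63) / 64) * 8

    assert(calculateBitSetWidthInBytes(0) == 0)    // no fields, no bitset
    assert(calculateBitSetWidthInBytes(1) == 8)    // 1..64 fields share one word
    assert(calculateBitSetWidthInBytes(65) == 16)  // a 65th field starts a second word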