aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorRyan Blue <blue@apache.org>2016-11-11 13:52:10 -0800
committerReynold Xin <rxin@databricks.com>2016-11-11 13:52:10 -0800
commit6e95325fc3726d260054bd6e7c0717b3c139917e (patch)
tree6e1248728b75908abfd0c9c56d68c40b5c4a5f67 /sql/catalyst/src
parentd42bb7cc4e32c173769bd7da5b9b5eafb510860c (diff)
downloadspark-6e95325fc3726d260054bd6e7c0717b3c139917e.tar.gz
spark-6e95325fc3726d260054bd6e7c0717b3c139917e.tar.bz2
spark-6e95325fc3726d260054bd6e7c0717b3c139917e.zip
[SPARK-18387][SQL] Add serialization to checkEvaluation.
## What changes were proposed in this pull request?

This removes the serialization test from RegexpExpressionsSuite and replaces it by serializing all expressions in checkEvaluation. This also fixes math constant expressions by making LeafMathExpression Serializable and fixes NumberFormat values that are null or invalid after serialization.

## How was this patch tested?

This patch is to tests.

Author: Ryan Blue <blue@apache.org>

Closes #15847 from rdblue/SPARK-18387-fix-serializable-expressions.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala | 44
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 15
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala | 16
4 files changed, 36 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index a60494a5bb..65273a77b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -36,7 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String
* @param name The short name of the function
*/
abstract class LeafMathExpression(c: Double, name: String)
- extends LeafExpression with CodegenFallback {
+ extends LeafExpression with CodegenFallback with Serializable {
override def dataType: DataType = DoubleType
override def foldable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5f533fecf8..e74ef9a087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1431,18 +1431,20 @@ case class FormatNumber(x: Expression, d: Expression)
// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
+ // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
+ // serialization (numberFormat has not been updated for dValue = 0).
@transient
- private var lastDValue: Int = -100
+ private var lastDValue: Option[Int] = None
// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
@transient
- private val pattern: StringBuffer = new StringBuffer()
+ private lazy val pattern: StringBuffer = new StringBuffer()
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
@transient
- private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
+ private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
@@ -1450,24 +1452,28 @@ case class FormatNumber(x: Expression, d: Expression)
return null
}
- if (dValue != lastDValue) {
- // construct a new DecimalFormat only if a new dValue
- pattern.delete(0, pattern.length)
- pattern.append("#,###,###,###,###,###,##0")
-
- // decimal place
- if (dValue > 0) {
- pattern.append(".")
-
- var i = 0
- while (i < dValue) {
- i += 1
- pattern.append("0")
+ lastDValue match {
+ case Some(last) if last == dValue =>
+ // use the current pattern
+ case _ =>
+ // construct a new DecimalFormat only if a new dValue
+ pattern.delete(0, pattern.length)
+ pattern.append("#,###,###,###,###,###,##0")
+
+ // decimal place
+ if (dValue > 0) {
+ pattern.append(".")
+
+ var i = 0
+ while (i < dValue) {
+ i += 1
+ pattern.append("0")
+ }
}
- }
- lastDValue = dValue
- numberFormat.applyLocalizedPattern(pattern.toString)
+ lastDValue = Some(dValue)
+
+ numberFormat.applyLocalizedPattern(pattern.toString)
}
x.dataType match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 9ceb709185..f83650424a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -22,7 +22,8 @@ import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.GeneratorDrivenPropertyChecks
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
@@ -43,13 +44,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
protected def checkEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+ val serializer = new JavaSerializer(new SparkConf()).newInstance
+ val expr: Expression = serializer.deserialize(serializer.serialize(expression))
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
- checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
- checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
- if (GenerateUnsafeProjection.canSupport(expression.dataType)) {
- checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow)
+ checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
+ checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
+ if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+ checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
}
- checkEvaluationWithOptimization(expression, catalystValue, inputRow)
+ checkEvaluationWithOptimization(expr, catalystValue, inputRow)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index d0d1aaa9d2..5299549e7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.StringType
@@ -192,17 +191,4 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringSplit(s1, s2), null, row3)
}
- test("RegExpReplace serialization") {
- val serializer = new JavaSerializer(new SparkConf()).newInstance
-
- val row = create_row("abc", "b", "")
-
- val s = 's.string.at(0)
- val p = 'p.string.at(1)
- val r = 'r.string.at(2)
-
- val expr: RegExpReplace = serializer.deserialize(serializer.serialize(RegExpReplace(s, p, r)))
- checkEvaluation(expr, "ac", row)
- }
-
}