author     Davies Liu <davies@databricks.com>  2015-06-04 10:28:59 -0700
committer  Davies Liu <davies@databricks.com>  2015-06-04 10:28:59 -0700
commit     c8709dcfd1237ffa19ee9286e99ddf2718a616d8 (patch)
tree       633db0167d2c5aec29525400e22d879ad2564f34 /sql
parent     10ba1880878d0babcdc5c9b688df5458ea131531 (diff)
download   spark-c8709dcfd1237ffa19ee9286e99ddf2718a616d8.tar.gz
           spark-c8709dcfd1237ffa19ee9286e99ddf2718a616d8.tar.bz2
           spark-c8709dcfd1237ffa19ee9286e99ddf2718a616d8.zip
[SPARK-7956] [SQL] Use Janino to compile SQL expressions into bytecode
In order to reduce the overhead of codegen, this PR switches to Janino to compile SQL expressions into bytecode. After this change, the time needed to compile a SQL expression drops from about 100ms to 5ms, which makes it feasible to turn on codegen for general workloads, as well as for tests.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #6479 from davies/janino and squashes the following commits:

cc689f5 [Davies Liu] remove globalLock
262d848 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino
eec3a33 [Davies Liu] address comments from Josh
f37c8c3 [Davies Liu] fix DecimalType and cast to String
202298b [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino
a21e968 [Davies Liu] fix style
0ed3dc6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino
551a851 [Davies Liu] fix tests
c3bdffa [Davies Liu] remove print
6089ce5 [Davies Liu] change logging level
7e46ac3 [Davies Liu] fix style
d8f0f6c [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino
da4926a [Davies Liu] fix tests
03660f3 [Davies Liu] WIP: use Janino to compile Java source
f2629cd [Davies Liu] Merge branch 'master' of github.com:apache/spark into janino
f7d66cf [Davies Liu] use template based string for codegen
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/pom.xml | 16
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 101
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java | 68
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java | 190
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 797
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 87
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 146
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 44
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 316
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala | 5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 162
15 files changed, 1105 insertions(+), 866 deletions(-)
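
The central technique of this commit is compiling generated Java source into bytecode at runtime with Janino instead of the Scala reflection toolbox. A minimal, self-contained sketch of that technique follows (the JaninoDemo wrapper and addOne method are illustrative, but ClassBodyEvaluator.getClazz() is the same entry point the new compile() helper in CodeGenerator.scala uses):

    import org.codehaus.janino.ClassBodyEvaluator

    object JaninoDemo {
      def main(args: Array[String]): Unit = {
        // Janino takes a Java *class body* (members only, no class declaration).
        val body = "public int addOne(int x) { return x + 1; }"
        val start = System.nanoTime()
        val clazz = new ClassBodyEvaluator(body).getClazz() // compile to bytecode
        val millis = (System.nanoTime() - start).toDouble / 1e6
        println(s"compiled ${clazz.getName} in $millis ms") // typically a few ms
        val m = clazz.getDeclaredMethod("addOne", classOf[Int])
        println(m.invoke(clazz.newInstance(), Int.box(41))) // prints 42
      }
    }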
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index bf0a7327a5..f4b1cc3a4f 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -38,10 +38,6 @@
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
- <artifactId>scala-compiler</artifactId>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
@@ -67,6 +63,11 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.codehaus.janino</groupId>
+ <artifactId>janino</artifactId>
+ <version>2.7.8</version>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -108,13 +109,6 @@
<activation>
<property><name>!scala-2.11</name></property>
</activation>
- <dependencies>
- <dependency>
- <groupId>org.scalamacros</groupId>
- <artifactId>quasiquotes_${scala.binary.version}</artifactId>
- <version>${scala.macros.version}</version>
- </dependency>
- </dependencies>
</profile>
</profiles>
</project>
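
For readers building outside Maven, the added dependency restated in sbt coordinates (an equivalent of the pom change above, not part of the commit):

    // build.sbt equivalent of the janino dependency added to sql/catalyst/pom.xml
    libraryDependencies += "org.codehaus.janino" % "janino" % "2.7.8"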
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb546b3086..ec97fe603c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,23 +17,25 @@
package org.apache.spark.sql.catalyst.expressions;
-import scala.collection.Map;
+import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
import scala.collection.Seq;
import scala.collection.mutable.ArraySeq;
-import javax.annotation.Nullable;
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.util.*;
-
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
-import static org.apache.spark.sql.types.DataTypes.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import static org.apache.spark.sql.types.DataTypes.*;
+
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -49,7 +51,7 @@ import org.apache.spark.unsafe.bitset.BitSetMethods;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow implements MutableRow {
+public final class UnsafeRow extends BaseMutableRow {
private Object baseObject;
private long baseOffset;
@@ -228,21 +230,11 @@ public final class UnsafeRow implements MutableRow {
}
@Override
- public int length() {
- return size();
- }
-
- @Override
public StructType schema() {
return schema;
}
@Override
- public Object apply(int i) {
- return get(i);
- }
-
- @Override
public Object get(int i) {
assertIndexIsValid(i);
assert (schema != null) : "Schema must be defined when calling generic get() method";
@@ -339,60 +331,7 @@ public final class UnsafeRow implements MutableRow {
return getUTF8String(i).toString();
}
- @Override
- public BigDecimal getDecimal(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Date getDate(int i) {
- throw new UnsupportedOperationException();
- }
- @Override
- public <T> Seq<T> getSeq(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> List<T> getList(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> Map<K, V> getMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> java.util.Map<K, V> getJavaMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Row getStruct(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(String fieldName) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int fieldIndex(String name) {
- throw new UnsupportedOperationException();
- }
@Override
public Row copy() {
@@ -412,24 +351,4 @@ public final class UnsafeRow implements MutableRow {
}
return values;
}
-
- @Override
- public String toString() {
- return mkString("[", ",", "]");
- }
-
- @Override
- public String mkString() {
- return toSeq().mkString();
- }
-
- @Override
- public String mkString(String sep) {
- return toSeq().mkString(sep);
- }
-
- @Override
- public String mkString(String start, String sep, String end) {
- return toSeq().mkString(start, sep, end);
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java
new file mode 100644
index 0000000000..acec2bf452
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
+
+public abstract class BaseMutableRow extends BaseRow implements MutableRow {
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setString(int ordinal, String value) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
new file mode 100644
index 0000000000..d138b43a34
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.util.List;
+
+import scala.collection.Seq;
+import scala.collection.mutable.ArraySeq;
+
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.StructType;
+
+public abstract class BaseRow implements Row {
+
+ @Override
+ final public int length() {
+ return size();
+ }
+
+ @Override
+ public boolean anyNull() {
+ final int n = size();
+ for (int i=0; i < n; i++) {
+ if (isNullAt(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public StructType schema() { throw new UnsupportedOperationException(); }
+
+ @Override
+ final public Object apply(int i) {
+ return get(i);
+ }
+
+ @Override
+ public int getInt(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long getLong(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getFloat(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getDouble(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte getByte(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public short getShort(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String getString(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BigDecimal getDecimal(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Date getDate(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> Seq<T> getSeq(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> List<T> getList(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> scala.collection.Map<K, V> getMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> java.util.Map<K, V> getJavaMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row getStruct(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(String fieldName) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int fieldIndex(String name) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row copy() {
+ final int n = size();
+ Object[] arr = new Object[n];
+ for (int i = 0; i < n; i++) {
+ arr[i] = get(i);
+ }
+ return new GenericRow(arr);
+ }
+
+ @Override
+ public Seq<Object> toSeq() {
+ final int n = size();
+ final ArraySeq<Object> values = new ArraySeq<Object>(n);
+ for (int i = 0; i < n; i++) {
+ values.update(i, get(i));
+ }
+ return values;
+ }
+
+ @Override
+ public String toString() {
+ return mkString("[", ",", "]");
+ }
+
+ @Override
+ public String mkString() {
+ return toSeq().mkString();
+ }
+
+ @Override
+ public String mkString(String sep) {
+ return toSeq().mkString(sep);
+ }
+
+ @Override
+ public String mkString(String start, String sep, String end) {
+ return toSeq().mkString(start, sep, end);
+ }
+}
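
BaseRow and BaseMutableRow exist so that Janino-generated Java classes can inherit all of Row's boilerplate and override only what they need; anything left untouched falls back to the UnsupportedOperationException defaults above. A hedged sketch of a row written in the same style the generator emits (SingleIntRow is hypothetical, not part of the commit):

    import org.apache.spark.sql.BaseMutableRow

    // A one-column mutable row; unimplemented accessors inherit the
    // UnsupportedOperationException defaults from BaseRow/BaseMutableRow.
    class SingleIntRow extends BaseMutableRow {
      private var value: Int = 0
      private var nullFlag: Boolean = true

      override def size: Int = 1
      override def isNullAt(i: Int): Boolean = nullFlag
      override def get(i: Int): Any = if (nullFlag) null else value
      override def getInt(i: Int): Int = value
      override def setInt(i: Int, v: Int): Unit = { value = v; nullFlag = false }
      override def setNullAt(i: Int): Unit = { nullFlag = true }
    }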
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 36964af68d..cd604121b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import com.google.common.cache.{CacheLoader, CacheBuilder}
-
+import scala.collection.mutable
import scala.language.existentials
+import com.google.common.cache.{CacheBuilder, CacheLoader}
+import org.codehaus.janino.ClassBodyEvaluator
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
@@ -36,23 +38,15 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* expressions.
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
-
- import scala.tools.reflect.ToolBox
-
- protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
- protected val rowType = typeOf[Row]
- protected val mutableRowType = typeOf[MutableRow]
- protected val genericRowType = typeOf[GenericRow]
- protected val genericMutableRowType = typeOf[GenericMutableRow]
-
- protected val projectionType = typeOf[Projection]
- protected val mutableProjectionType = typeOf[MutableProjection]
+ protected val rowType = classOf[Row].getName
+ protected val stringType = classOf[UTF8String].getName
+ protected val decimalType = classOf[Decimal].getName
+ protected val exprType = classOf[Expression].getName
+ protected val mutableRowType = classOf[MutableRow].getName
+ protected val genericMutableRowType = classOf[GenericMutableRow].getName
private val curId = new java.util.concurrent.atomic.AtomicInteger()
- private val javaSeparator = "$"
/**
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
@@ -75,6 +69,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
/**
+ * Compile the Java source code into a Java class, using Janino.
+ *
+ * It also tracks and logs the time taken to compile.
+ */
+ protected def compile(code: String): Class[_] = {
+ val startTime = System.nanoTime()
+ val clazz = new ClassBodyEvaluator(code).getClazz()
+ val endTime = System.nanoTime()
+ def timeMs: Double = (endTime - startTime).toDouble / 1000000
+ logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms")
+ clazz
+ }
+
+ /**
* A cache of generated classes.
*
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
@@ -87,7 +95,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
.maximumSize(1000)
.build(
new CacheLoader[InType, OutType]() {
- override def load(in: InType): OutType = globalLock.synchronized {
+ override def load(in: InType): OutType = {
val startTime = System.nanoTime()
val result = create(in)
val endTime = System.nanoTime()
@@ -110,8 +118,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
- protected def freshName(prefix: String): TermName = {
- newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+ protected def freshName(prefix: String): String = {
+ s"$prefix${curId.getAndIncrement}"
}
/**
@@ -125,32 +133,51 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
*/
protected case class EvaluatedExpression(
- code: Seq[Tree],
- nullTerm: TermName,
- primitiveTerm: TermName,
- objectTerm: TermName)
+ code: String,
+ nullTerm: String,
+ primitiveTerm: String,
+ objectTerm: String)
+
+ /**
+ * A context for codegen, used to track expressions that are not supported by codegen
+ * and therefore must be evaluated directly. Each such expression is appended to the
+ * end of `references`; its position is embedded in the generated code, which uses it
+ * to access and evaluate the expression at runtime.
+ */
+ protected class CodeGenContext {
+ /**
+ * Holds all the expressions that do not support codegen; they will be evaluated directly.
+ */
+ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
+ }
+
+ /**
+ * Creates a new codegen context for an expression evaluator, used to store
+ * expressions that don't support codegen.
+ */
+ def newCodeGenContext(): CodeGenContext = {
+ new CodeGenContext()
+ }
/**
* Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
* can be used to determine the result of evaluating the expression on an input row.
*/
- def expressionEvaluator(e: Expression): EvaluatedExpression = {
+ def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = {
val primitiveTerm = freshName("primitiveTerm")
val nullTerm = freshName("nullTerm")
val objectTerm = freshName("objectTerm")
implicit class Evaluate1(e: Expression) {
- def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(dataType)}
- else
- ${f(eval.primitiveTerm)}
- """.children
+ def castOrNull(f: String => String, dataType: DataType): String = {
+ val eval = expressionEvaluator(e, ctx)
+ eval.code +
+ s"""
+ boolean $nullTerm = ${eval.nullTerm};
+ ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
+ if (!$nullTerm) {
+ $primitiveTerm = ${f(eval.primitiveTerm)};
+ }
+ """
}
}
@@ -163,529 +190,505 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
*
* @param f a function from two primitive term names to a tree that evaluates them.
*/
- def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ def evaluate(f: (String, String) => String): String =
evaluateAs(expressions._1.dataType)(f)
- def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+ def evaluateAs(resultType: DataType)(f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (expressions._1.dataType != expressions._2.dataType) {
log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
}
- val eval1 = expressionEvaluator(expressions._1)
- val eval2 = expressionEvaluator(expressions._2)
+ val eval1 = expressionEvaluator(expressions._1, ctx)
+ val eval2 = expressionEvaluator(expressions._2, ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
- eval1.code ++ eval2.code ++
- q"""
- val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
- val $primitiveTerm: ${termForType(resultType)} =
- if($nullTerm) {
- ${defaultPrimitive(resultType)}
- } else {
- $resultCode.asInstanceOf[${termForType(resultType)}]
- }
- """.children : Seq[Tree]
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
+ ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)};
+ if(!$nullTerm) {
+ $primitiveTerm = (${primitiveForType(resultType)})($resultCode);
+ }
+ """
}
}
- val inputTuple = newTermName(s"i")
+ val inputTuple = "i"
// TODO: Skip generation of null handling code when expression are not nullable.
- val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ val primitiveEvaluation: PartialFunction[Expression, String] = {
case b @ BoundReference(ordinal, dataType, nullable) =>
- val nullValue = q"$inputTuple.isNullAt($ordinal)"
- q"""
- val $nullTerm: Boolean = $nullValue
- val $primitiveTerm: ${termForType(dataType)} =
- if($nullTerm)
- ${defaultPrimitive(dataType)}
- else
- ${getColumn(inputTuple, dataType, ordinal)}
- """.children
+ s"""
+ final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
+ final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
+ ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)});
+ """
case expressions.Literal(null, dataType) =>
- q"""
- val $nullTerm = true
- val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
- """.children
-
- case expressions.Literal(value: Boolean, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: UTF8String, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} =
- org.apache.spark.sql.types.UTF8String(${value.getBytes})
- """.children
-
- case expressions.Literal(value: Int, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: Long, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case Cast(e @ BinaryType(), StringType) =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
- org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
- """.children
+ s"""
+ final boolean $nullTerm = true;
+ ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
+ """
+
+ case expressions.Literal(value: UTF8String, StringType) =>
+ val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
+ s"""
+ final boolean $nullTerm = false;
+ ${stringType} $primitiveTerm =
+ new ${stringType}().set(${arr});
+ """
+
+ case expressions.Literal(value, FloatType) =>
+ s"""
+ final boolean $nullTerm = false;
+ float $primitiveTerm = ${value}f;
+ """
+
+ case expressions.Literal(value, dt @ DecimalType()) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value);
+ """
+
+ case expressions.Literal(value, dataType) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dataType)} $primitiveTerm = $value;
+ """
+
+ case Cast(child @ BinaryType(), StringType) =>
+ child.castOrNull(c =>
+ s"new ${stringType}().set($c)",
+ StringType)
case Cast(child @ DateType(), StringType) =>
child.castOrNull(c =>
- q"""org.apache.spark.sql.types.UTF8String(
+ s"""new ${stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
- case Cast(child @ NumericType(), IntegerType) =>
- child.castOrNull(c => q"$c.toInt", IntegerType)
+ case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
- case Cast(child @ NumericType(), LongType) =>
- child.castOrNull(c => q"$c.toLong", LongType)
+ case Cast(child @ DecimalType(), IntegerType) =>
+ child.castOrNull(c => s"($c).toInt()", IntegerType)
- case Cast(child @ NumericType(), DoubleType) =>
- child.castOrNull(c => q"$c.toDouble", DoubleType)
+ case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
- case Cast(child @ NumericType(), FloatType) =>
- child.castOrNull(c => q"$c.toFloat", FloatType)
+ case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
- org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
- """.children
+ e.castOrNull(c =>
+ s"new ${stringType}().set(String.valueOf($c))",
+ StringType)
case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
(e1, e2).evaluateAs (BooleanType) {
case (eval1, eval2) =>
- q"""
- java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
- $eval2.asInstanceOf[Array[Byte]])
- """
+ s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
}
case EqualTo(e1, e2) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
-
- /* TODO: Fix null semantics.
- case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
- val eval = expressionEvaluator(e1)
-
- val checks = list.map {
- case expressions.Literal(v: String, dataType) =>
- q"if(${eval.primitiveTerm} == $v) return true"
- case expressions.Literal(v: Int, dataType) =>
- q"if(${eval.primitiveTerm} == $v) return true"
- }
-
- val funcName = newTermName(s"isIn${curId.getAndIncrement()}")
-
- q"""
- def $funcName: Boolean = {
- ..${eval.code}
- if(${eval.nullTerm}) return false
- ..$checks
- return false
- }
- val $nullTerm = false
- val $primitiveTerm = $funcName
- """.children
- */
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" }
case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" }
case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" }
case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" }
case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" }
case And(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- q"""
- ..${eval1.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = false
-
- if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) {
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ s"""
+ ${eval1.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = false;
+
+ if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) {
} else {
- ..${eval2.code}
- if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) {
+ ${eval2.code}
+ if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) {
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else {
- $nullTerm = true
+ $nullTerm = true;
}
}
- """.children
+ """
case Or(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- q"""
- ..${eval1.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = false
+ s"""
+ ${eval1.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = false;
if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else {
- ..${eval2.code}
+ ${eval2.code}
if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = false
+ $primitiveTerm = false;
} else {
- $nullTerm = true
+ $nullTerm = true;
}
}
- """.children
+ """
case Not(child) =>
// Uh, bad function name...
- child.castOrNull(c => q"!$c", BooleanType)
-
- case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
- case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
- case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
+ child.castOrNull(c => s"!$c", BooleanType)
+
+ case Add(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" }
+ case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" }
+ case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" }
+ case Divide(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = null;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
+ $nullTerm = true;
+ } else {
+          $primitiveTerm = ${eval1.primitiveTerm}.$$div(${eval2.primitiveTerm});
+ }
+ """
+ case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm});
+ }
+ """
+
+ case Add(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" }
+ case Subtract(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" }
+ case Multiply(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" }
case Divide(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = 0
-
- if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
- $nullTerm = true
- } else if (${eval2.primitiveTerm} == 0)
- $nullTerm = true
- else {
- $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm};
}
- """.children
-
+ """
case Remainder(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = 0
-
- if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
- $nullTerm = true
- } else if (${eval2.primitiveTerm} == 0)
- $nullTerm = true
- else {
- $nullTerm = false
- $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm};
}
- """.children
+ """
case IsNotNull(e) =>
- val eval = expressionEvaluator(e)
- q"""
- ..${eval.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm}
- """.children
+ val eval = expressionEvaluator(e, ctx)
+ s"""
+ ${eval.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = !${eval.nullTerm};
+ """
case IsNull(e) =>
- val eval = expressionEvaluator(e)
- q"""
- ..${eval.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}
- """.children
-
- case c @ Coalesce(children) =>
- q"""
- var $nullTerm = true
- var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
- """.children ++
+ val eval = expressionEvaluator(e, ctx)
+ s"""
+ ${eval.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = ${eval.nullTerm};
+ """
+
+ case e @ Coalesce(children) =>
+ s"""
+ boolean $nullTerm = true;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ """ +
children.map { c =>
- val eval = expressionEvaluator(c)
- q"""
+ val eval = expressionEvaluator(c, ctx)
+ s"""
if($nullTerm) {
- ..${eval.code}
+ ${eval.code}
if(!${eval.nullTerm}) {
- $nullTerm = false
- $primitiveTerm = ${eval.primitiveTerm}
+ $nullTerm = false;
+ $primitiveTerm = ${eval.primitiveTerm};
}
}
"""
- }
+ }.mkString("\n")
- case i @ expressions.If(condition, trueValue, falseValue) =>
- val condEval = expressionEvaluator(condition)
- val trueEval = expressionEvaluator(trueValue)
- val falseEval = expressionEvaluator(falseValue)
+ case e @ expressions.If(condition, trueValue, falseValue) =>
+ val condEval = expressionEvaluator(condition, ctx)
+ val trueEval = expressionEvaluator(trueValue, ctx)
+ val falseEval = expressionEvaluator(falseValue, ctx)
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)}
- ..${condEval.code}
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ ${condEval.code}
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
- ..${trueEval.code}
- $nullTerm = ${trueEval.nullTerm}
- $primitiveTerm = ${trueEval.primitiveTerm}
+ ${trueEval.code}
+ $nullTerm = ${trueEval.nullTerm};
+ $primitiveTerm = ${trueEval.primitiveTerm};
} else {
- ..${falseEval.code}
- $nullTerm = ${falseEval.nullTerm}
- $primitiveTerm = ${falseEval.primitiveTerm}
+ ${falseEval.code}
+ $nullTerm = ${falseEval.nullTerm};
+ $primitiveTerm = ${falseEval.primitiveTerm};
}
- """.children
+ """
case NewSet(elementType) =>
- q"""
- val $nullTerm = false
- val $primitiveTerm = new ${hashSetForType(elementType)}()
- """.children
+ s"""
+ boolean $nullTerm = false;
+ ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}();
+ """
case AddItemToSet(item, set) =>
- val itemEval = expressionEvaluator(item)
- val setEval = expressionEvaluator(set)
+ val itemEval = expressionEvaluator(item, ctx)
+ val setEval = expressionEvaluator(set, ctx)
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val htype = hashSetForType(elementType)
- itemEval.code ++ setEval.code ++
- q"""
- if (!${itemEval.nullTerm}) {
- ${setEval.primitiveTerm}
- .asInstanceOf[${hashSetForType(elementType)}]
- .add(${itemEval.primitiveTerm})
+ itemEval.code + setEval.code +
+ s"""
+ if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
+ (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
-
- val $nullTerm = false
- val $primitiveTerm = ${setEval.primitiveTerm}
- """.children
+ boolean $nullTerm = false;
+ ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm};
+ """
case CombineSets(left, right) =>
- val leftEval = expressionEvaluator(left)
- val rightEval = expressionEvaluator(right)
+ val leftEval = expressionEvaluator(left, ctx)
+ val rightEval = expressionEvaluator(right, ctx)
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val htype = hashSetForType(elementType)
- leftEval.code ++ rightEval.code ++
- q"""
- val $nullTerm = false
- var $primitiveTerm: ${hashSetForType(elementType)} = null
-
- {
- val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
- val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
- val iterator = rightSet.iterator
- while (iterator.hasNext) {
- leftSet.add(iterator.next())
- }
- $primitiveTerm = leftSet
- }
- """.children
+ leftEval.code + rightEval.code +
+ s"""
+ boolean $nullTerm = false;
+ ${htype} $primitiveTerm =
+ (${htype})${leftEval.primitiveTerm};
+ $primitiveTerm.union((${htype})${rightEval.primitiveTerm});
+ """
- case MaxOf(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm}
- $primitiveTerm = ${eval2.primitiveTerm}
+ $nullTerm = ${eval2.nullTerm};
+ $primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm}
- $primitiveTerm = ${eval1.primitiveTerm}
+ $nullTerm = ${eval1.nullTerm};
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm}
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
- $primitiveTerm = ${eval2.primitiveTerm}
+ $primitiveTerm = ${eval2.primitiveTerm};
}
}
- """.children
+ """
- case MinOf(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm}
- $primitiveTerm = ${eval2.primitiveTerm}
+ $nullTerm = ${eval2.nullTerm};
+ $primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm}
- $primitiveTerm = ${eval1.primitiveTerm}
+ $nullTerm = ${eval1.nullTerm};
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm}
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
- $primitiveTerm = ${eval2.primitiveTerm}
+ $primitiveTerm = ${eval2.primitiveTerm};
}
}
- """.children
+ """
case UnscaledValue(child) =>
- val childEval = expressionEvaluator(child)
-
- childEval.code ++
- q"""
- var $nullTerm = ${childEval.nullTerm}
- var $primitiveTerm: Long = if (!$nullTerm) {
- ${childEval.primitiveTerm}.toUnscaledLong
- } else {
- ${defaultPrimitive(LongType)}
- }
- """.children
+ val childEval = expressionEvaluator(child, ctx)
+
+ childEval.code +
+ s"""
+ boolean $nullTerm = ${childEval.nullTerm};
+ long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong();
+ """
case MakeDecimal(child, precision, scale) =>
- val childEval = expressionEvaluator(child)
+ val eval = expressionEvaluator(child, ctx)
- childEval.code ++
- q"""
- var $nullTerm = ${childEval.nullTerm}
- var $primitiveTerm: org.apache.spark.sql.types.Decimal =
- ${defaultPrimitive(DecimalType())}
+ eval.code +
+ s"""
+ boolean $nullTerm = ${eval.nullTerm};
+ org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())};
if (!$nullTerm) {
- $primitiveTerm = new org.apache.spark.sql.types.Decimal()
- $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
- $nullTerm = $primitiveTerm == null
+ $primitiveTerm = new org.apache.spark.sql.types.Decimal();
+ $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale);
+ $nullTerm = $primitiveTerm == null;
}
- """.children
+ """
}
// If there was no match in the partial function above, we fall back on calling the interpreted
// expression evaluator.
- val code: Seq[Tree] =
+ val code: String =
primitiveEvaluation.lift.apply(e).getOrElse {
- log.debug(s"No rules to generate $e")
- val tree = reify { e }
- q"""
- val $objectTerm = $tree.eval(i)
- val $nullTerm = $objectTerm == null
- val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
- """.children
- }
-
- // Only inject debugging code if debugging is turned on.
- val debugCode =
- if (debugLogging) {
- val localLogger = log
- val localLoggerTree = reify { localLogger }
- q"""
- $localLoggerTree.debug(
- ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
- """ :: Nil
- } else {
- Nil
+ logError(s"No rules to generate $e")
+ ctx.references += e
+ s"""
+ /* expression: ${e} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
+ boolean $nullTerm = $objectTerm == null;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm;
+ """
}
- EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm)
+ EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
}
- protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
+ protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = {
dataType match {
- case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
- case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)"
- case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
+ case StringType => s"(${stringType})$inputRow.apply($ordinal)"
+ case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)"
+ case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)"
}
}
protected def setColumn(
- destinationRow: TermName,
+ destinationRow: String,
dataType: DataType,
ordinal: Int,
- value: TermName) = {
+ value: String): String = {
dataType match {
- case StringType => q"$destinationRow.update($ordinal, $value)"
+ case StringType => s"$destinationRow.update($ordinal, $value)"
case dt: DataType if isNativeType(dt) =>
- q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
- case _ => q"$destinationRow.update($ordinal, $value)"
+ s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case _ => s"$destinationRow.update($ordinal, $value)"
}
}
- protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
- protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+ protected def accessorForType(dt: DataType) = dt match {
+ case IntegerType => "getInt"
+ case other => s"get${termForType(dt)}"
+ }
+
+ protected def mutatorForType(dt: DataType) = dt match {
+ case IntegerType => "setInt"
+ case other => s"set${termForType(dt)}"
+ }
- protected def hashSetForType(dt: DataType) = dt match {
- case IntegerType => typeOf[IntegerHashSet]
- case LongType => typeOf[LongHashSet]
+ protected def hashSetForType(dt: DataType): String = dt match {
+ case IntegerType => classOf[IntegerHashSet].getName
+ case LongType => classOf[LongHashSet].getName
case unsupportedType =>
sys.error(s"Code generation not support for hashset of type $unsupportedType")
}
- protected def primitiveForType(dt: DataType) = dt match {
- case IntegerType => "Int"
+ protected def primitiveForType(dt: DataType): String = dt match {
+ case IntegerType => "int"
+ case LongType => "long"
+ case ShortType => "short"
+ case ByteType => "byte"
+ case DoubleType => "double"
+ case FloatType => "float"
+ case BooleanType => "boolean"
+ case dt: DecimalType => decimalType
+ case BinaryType => "byte[]"
+ case StringType => stringType
+ case DateType => "int"
+ case TimestampType => "java.sql.Timestamp"
+ case _ => "Object"
+ }
+
+ protected def defaultPrimitive(dt: DataType): String = dt match {
+ case BooleanType => "false"
+ case FloatType => "-1.0f"
+ case ShortType => "-1"
+ case LongType => "-1"
+ case ByteType => "-1"
+ case DoubleType => "-1.0"
+ case IntegerType => "-1"
+ case DateType => "-1"
+ case dt: DecimalType => "null"
+ case StringType => "null"
+ case _ => "null"
+ }
+
+ protected def termForType(dt: DataType): String = dt match {
+ case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "org.apache.spark.sql.types.UTF8String"
- }
-
- protected def defaultPrimitive(dt: DataType) = dt match {
- case BooleanType => ru.Literal(Constant(false))
- case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
- case ShortType => ru.Literal(Constant(-1.toShort))
- case LongType => ru.Literal(Constant(-1L))
- case ByteType => ru.Literal(Constant(-1.toByte))
- case DoubleType => ru.Literal(Constant(-1.toDouble))
- case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)"
- case IntegerType => ru.Literal(Constant(-1))
- case DateType => ru.Literal(Constant(-1))
- case _ => ru.Literal(Constant(null))
- }
-
- protected def termForType(dt: DataType) = dt match {
- case n: AtomicType => n.tag
- case _ => typeTag[Any]
+ case dt: DecimalType => decimalType
+ case BinaryType => "byte[]"
+ case StringType => stringType
+ case DateType => "Integer"
+ case TimestampType => "java.sql.Timestamp"
+ case _ => "Object"
}
/**
* List of data types that have special accessors and setters in [[Row]].
*/
protected val nativeTypes =
- Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+ Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
/**
* Returns true if the data type has a special accessor and setter in [[Row]].
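
The structural heart of the rewrite above is the CodeGenContext fallback: when the pattern match has no rule for an expression, the expression object itself is appended to ctx.references, and the emitted Java indexes back into that array at runtime via the generated class's expressions field. A stripped-down sketch of that bookkeeping, with stand-in types rather than the commit's real Expression:

    import scala.collection.mutable

    // Stand-in for catalyst's Expression; illustrative only.
    trait Expr { def eval(row: AnyRef): AnyRef }

    class Ctx {
      // Expressions with no codegen rule; evaluated interpreted at runtime.
      val references = mutable.ArrayBuffer.empty[Expr]
    }

    def fallbackCode(ctx: Ctx, e: Expr, objectTerm: String): String = {
      ctx.references += e
      // ctx.references is later passed to the generated class as its
      // `expressions` array, so this index remains valid at runtime.
      s"Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);"
    }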
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 840260703a..638b53fe0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+// MutableProjection is not accessible in Java
+abstract class BaseMutableProjection extends MutableProjection {}
+
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
* input [[Row]] for a fixed set of [[Expression Expressions]].
*/
object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
-
- val mutableRowName = newTermName("mutableRow")
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
in.map(BindReferences.bindReference(_, inputSchema))
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
- val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
- val evaluationCode = expressionEvaluator(e)
-
- evaluationCode.code :+
- q"""
- if(${evaluationCode.nullTerm})
- mutableRow.setNullAt($i)
- else
- ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)}
- """
- }
+ val ctx = newCodeGenContext()
+ val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
+ val evaluationCode = expressionEvaluator(e, ctx)
+ evaluationCode.code +
+ s"""
+ if(${evaluationCode.nullTerm})
+ mutableRow.setNullAt($i);
+ else
+ ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
+ """
+ }.mkString("\n")
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificProjection generate($exprType[] expr) {
+ return new SpecificProjection(expr);
+ }
+
+ class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
- val code =
- q"""
- () => { new $mutableProjectionType {
+ private $exprType[] expressions = null;
+ private $mutableRowType mutableRow = null;
- private[this] var $mutableRowName: $mutableRowType =
- new $genericMutableRowType(${expressions.size})
+ public SpecificProjection($exprType[] expr) {
+ expressions = expr;
+ mutableRow = new $genericMutableRowType(${expressions.size});
+ }
- def target(row: $mutableRowType): $mutableProjectionType = {
- $mutableRowName = row
- this
- }
+ public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) {
+ mutableRow = row;
+ return this;
+ }
- /* Provide immutable access to the last projected row. */
- def currentValue: $rowType = mutableRow
+ /* Provide immutable access to the last projected row. */
+ public Row currentValue() {
+ return mutableRow;
+ }
- def apply(i: $rowType): $rowType = {
- ..$projectionCode
- mutableRow
- }
- } }
- """
+ public Object apply(Object _i) {
+ Row i = (Row) _i;
+ $projectionCode
- log.debug(s"code for ${expressions.mkString(",")}:\n$code")
- toolBox.eval(code).asInstanceOf[() => MutableProjection]
+ return mutableRow;
+ }
+ }
+ """
+
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+
+ val c = compile(code)
+    // fetch the generated class's only declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ () => {
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection]
+ }
}
}
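
All the generators now share the same instantiation pattern: the compiled class body exposes a single generate(Expression[]) factory method, which is fetched reflectively and handed the fallback expressions collected in the context. A self-contained sketch of that shape (the generate body here is a trivial stand-in for a real SpecificProjection):

    import org.codehaus.janino.ClassBodyEvaluator

    object GenerateDemo {
      def main(args: Array[String]): Unit = {
        // A stand-in for generated source: one factory method, as in the diff.
        val body =
          "public Object generate(Object[] expr) {" +
          "  return java.util.Arrays.toString(expr);" +
          "}"
        val clazz = new ClassBodyEvaluator(body).getClazz()
        val generate = clazz.getDeclaredMethods()(0) // the only declared method
        val refs = Array[AnyRef]("a", "b")           // plays ctx.references.toArray
        println(generate.invoke(clazz.newInstance(), refs)) // prints [a, b]
      }
    }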
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index b129c0d898..0ff840dab3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -18,18 +18,29 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.Logging
+import org.apache.spark.annotation.Private
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{BinaryType, StringType, NumericType}
+import org.apache.spark.sql.types.{BinaryType, NumericType}
+
+/**
+ * Inherits some default implementations for Java from `Ordering[Row]`
+ */
+@Private
+class BaseOrdering extends Ordering[Row] {
+ def compare(a: Row, b: Row): Int = {
+ throw new UnsupportedOperationException
+ }
+}
/**
* Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
* [[Expression Expressions]].
*/
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
- import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._
- protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
+ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
@@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
val a = newTermName("a")
val b = newTermName("b")
- val comparisons = ordering.zipWithIndex.map { case (order, i) =>
- val evalA = expressionEvaluator(order.child)
- val evalB = expressionEvaluator(order.child)
+ val ctx = newCodeGenContext()
+ val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+ val evalA = expressionEvaluator(order.child, ctx)
+ val evalB = expressionEvaluator(order.child, ctx)
+ val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
- q"""
- val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm}
- val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm}
- var i = 0
- while (i < x.length && i < y.length) {
- val res = x(i).compareTo(y(i))
- if (res != 0) return res
- i = i+1
- }
- return x.length - y.length
- """
+ s"""
+ {
+ byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
+ byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
+ int j = 0;
+ while (j < x.length && j < y.length) {
+ if (x[j] != y[j]) return x[j] - y[j];
+ j = j + 1;
+ }
+ int d = x.length - y.length;
+ if (d != 0) {
+ return d;
+ }
+ }"""
case _: NumericType =>
- q"""
- val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
- if(comp != 0) {
- return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
- }
- """
- case StringType =>
- if (order.direction == Ascending) {
- q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+ s"""
+ if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
+ if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
+ return ${if (asc) "1" else "-1"};
+ } else {
+ return ${if (asc) "-1" else "1"};
+ }
+ }"""
+ case _ =>
+ s"""
+ int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
+ if (comp != 0) {
+ return ${if (asc) "comp" else "-comp"};
+ }"""
+ }
+
+ s"""
+ i = $a;
+ ${evalA.code}
+ i = $b;
+ ${evalB.code}
+ if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+        // Both null: treat as equal and fall through to the next column.
+ } else if (${evalA.nullTerm}) {
+ return ${if (order.direction == Ascending) "-1" else "1"};
+ } else if (${evalB.nullTerm}) {
+ return ${if (order.direction == Ascending) "1" else "-1"};
} else {
- q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+ $compare
}
+ """
+ }.mkString("\n")
+
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificOrdering generate($exprType[] expr) {
+ return new SpecificOrdering(expr);
}
- q"""
- i = $a
- ..${evalA.code}
- i = $b
- ..${evalB.code}
- if (${evalA.nullTerm} && ${evalB.nullTerm}) {
- // Nothing
- } else if (${evalA.nullTerm}) {
- return ${if (order.direction == Ascending) q"-1" else q"1"}
- } else if (${evalB.nullTerm}) {
- return ${if (order.direction == Ascending) q"1" else q"-1"}
- } else {
- $compare
+ class SpecificOrdering extends ${typeOf[BaseOrdering]} {
+
+ private $exprType[] expressions = null;
+
+ public SpecificOrdering($exprType[] expr) {
+ expressions = expr;
}
- """
- }
- val q"class $orderingName extends $orderingType { ..$body }" = reify {
- class SpecificOrdering extends Ordering[Row] {
- val o = ordering
- }
- }.tree.children.head
-
- val code = q"""
- class $orderingName extends $orderingType {
- ..$body
- def compare(a: $rowType, b: $rowType): Int = {
- var i: $rowType = null // Holds current row being evaluated.
- ..$comparisons
- return 0
+ @Override
+ public int compare(Row a, Row b) {
+ Row i = null; // Holds current row being evaluated.
+ $comparisons
+ return 0;
}
- }
- new $orderingName()
- """
+ }"""
+
logDebug(s"Generated Ordering: $code")
- toolBox.eval(code).asInstanceOf[Ordering[Row]]
+
+ val c = compile(code)
+    // fetch the generated class's single declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
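
The compile-and-invoke pattern above is shared by every generator in this patch: the emitted Java source declares a single factory method, `generate(Expression[])`, plus the Specific* class it instantiates; Janino compiles the source into a Class; and reflection invokes the factory with the expressions collected in ctx.references. A minimal sketch of that round trip, assuming Janino's ClassBodyEvaluator (the helper names below are illustrative, not the commit's actual compile()):

    import org.codehaus.janino.ClassBodyEvaluator

    // Compile a Java class body (like the source built by GenerateOrdering
    // above, import declarations included) into a Class object.
    def compileSketch(code: String): Class[_] = {
      val evaluator = new ClassBodyEvaluator()
      evaluator.cook(code)   // parse and compile the class body
      evaluator.getClazz     // the freshly defined class
    }

    // Instantiate the class and invoke its only declared method. Passing the
    // array without `: _*` hands reflection one argument, which is exactly
    // what the single-parameter generate(Expression[]) expects.
    def instantiateSketch(clazz: Class[_], references: Array[AnyRef]): AnyRef = {
      val generate = clazz.getDeclaredMethods()(0)
      generate.invoke(clazz.newInstance().asInstanceOf[AnyRef], references)
    }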
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 40e1630243..fb18769f00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -20,11 +20,16 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
/**
+ * Interface implemented by generated predicate classes.
+ */
+abstract class Predicate {
+ def eval(r: Row): Boolean
+}
+
+/**
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
*/
object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
@@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
BindReferences.bindReference(in, inputSchema)
protected def create(predicate: Expression): ((Row) => Boolean) = {
- val cEval = expressionEvaluator(predicate)
+ val ctx = newCodeGenContext()
+ val eval = expressionEvaluator(predicate, ctx)
+ val code = s"""
+ import org.apache.spark.sql.Row;
- val code =
- q"""
- (i: $rowType) => {
- ..${cEval.code}
- if (${cEval.nullTerm}) false else ${cEval.primitiveTerm}
+ public SpecificPredicate generate($exprType[] expr) {
+ return new SpecificPredicate(expr);
+ }
+
+ class SpecificPredicate extends ${classOf[Predicate].getName} {
+ private final $exprType[] expressions;
+ public SpecificPredicate($exprType[] expr) {
+ expressions = expr;
+ }
+
+ @Override
+ public boolean eval(Row i) {
+ ${eval.code}
+ return !${eval.nullTerm} && ${eval.primitiveTerm};
}
- """
+ }"""
+
+ logDebug(s"Generated predicate '$predicate':\n$code")
- log.debug(s"Generated predicate '$predicate':\n$code")
- toolBox.eval(code).asInstanceOf[Row => Boolean]
+ val c = compile(code)
+    // fetch the generated class's single declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate]
+ (r: Row) => p.eval(r)
}
}
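
For callers nothing changes here: GeneratePredicate still returns a plain Row => Boolean, it just closes over a Janino-compiled Predicate instead of a toolbox-evaluated function. A usage sketch, assuming the one-argument generate() on the CodeGenerator trait that the test diffs below also call; the schema and filter are made up for illustration:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
    import org.apache.spark.sql.types.IntegerType

    val schema = Seq(AttributeReference("a", IntegerType)())
    val filter: Expression = GreaterThan(schema.head, Literal(0))

    // Bind column references to ordinals, then compile the predicate.
    val bound = BindReferences.bindReference(filter, schema)
    val predicate: Row => Boolean = GeneratePredicate.generate(bound)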
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 31c63a79eb..d5be1fc12e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,9 +17,14 @@
package org.apache.spark.sql.catalyst.expressions.codegen
+import org.apache.spark.sql.BaseMutableRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+/**
+ * Java cannot access Projection because it is defined inside a package object.
+ */
+abstract class BaseProject extends Projection {}
/**
* Generates bytecode that produces a new [[Row]] object based on a fixed set of input
@@ -27,7 +32,6 @@ import org.apache.spark.sql.types._
* generated based on the output types of the [[Expression]] to avoid boxing of primitive values.
*/
object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
- import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
@@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
  // Make mutability optional...
protected def create(expressions: Seq[Expression]): Projection = {
- val tupleLength = ru.Literal(Constant(expressions.length))
- val lengthDef = q"final val length = $tupleLength"
-
- /* TODO: Configurable...
- val nullFunctions =
- q"""
- private final val nullSet = new org.apache.spark.util.collection.BitSet(length)
- final def setNullAt(i: Int) = nullSet.set(i)
- final def isNullAt(i: Int) = nullSet.get(i)
- """
- */
-
- val nullFunctions =
- q"""
- private[this] var nullBits = new Array[Boolean](${expressions.size})
- override def setNullAt(i: Int) = { nullBits(i) = true }
- override def isNullAt(i: Int) = nullBits(i)
- """.children
-
- val tupleElements = expressions.zipWithIndex.flatMap {
+ val ctx = newCodeGenContext()
+ val columns = expressions.zipWithIndex.map {
case (e, i) =>
- val elementName = newTermName(s"c$i")
- val evaluatedExpression = expressionEvaluator(e)
- val iLit = ru.Literal(Constant(i))
+ s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n"
+ }.mkString("\n ")
- q"""
- var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _
+ val initColumns = expressions.zipWithIndex.map {
+ case (e, i) =>
+ val eval = expressionEvaluator(e, ctx)
+ s"""
{
- ..${evaluatedExpression.code}
- if(${evaluatedExpression.nullTerm})
- setNullAt($iLit)
- else {
- nullBits($iLit) = false
- $elementName = ${evaluatedExpression.primitiveTerm}
+ // column$i
+ ${eval.code}
+ nullBits[$i] = ${eval.nullTerm};
+ if(!${eval.nullTerm}) {
+ c$i = ${eval.primitiveTerm};
}
}
- """.children : Seq[Tree]
- }
+ """
+ }.mkString("\n")
- val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
- val applyFunction = {
- val cases = (0 until expressions.size).map { i =>
- val ordinal = ru.Literal(Constant(i))
- val elementName = newTermName(s"c$i")
- val iLit = ru.Literal(Constant(i))
+ val getCases = (0 until expressions.size).map { i =>
+ s"case $i: return c$i;"
+ }.mkString("\n ")
- q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }"
- }
- q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }"
- }
-
- val updateFunction = {
- val cases = expressions.zipWithIndex.map {case (e, i) =>
- val ordinal = ru.Literal(Constant(i))
- val elementName = newTermName(s"c$i")
- val iLit = ru.Literal(Constant(i))
-
- q"""
- if(i == $ordinal) {
- if(value == null) {
- setNullAt(i)
- } else {
- nullBits(i) = false
- $elementName = value.asInstanceOf[${termForType(e.dataType)}]
- }
- return
- }"""
- }
- q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
- }
+ val updateCases = expressions.zipWithIndex.map { case (e, i) =>
+ s"case $i: { c$i = (${termForType(e.dataType)})value; return;}"
+ }.mkString("\n ")
val specificAccessorFunctions = nativeTypes.map { dataType =>
- val ifStatements = expressions.zipWithIndex.flatMap {
- // getString() is not used by expressions
- case (e, i) if e.dataType == dataType && dataType != StringType =>
- val elementName = newTermName(s"c$i")
- // TODO: The string of ifs gets pretty inefficient as the row grows in size.
- // TODO: Optional null checks?
- q"if(i == $i) return $elementName" :: Nil
- case _ => Nil
- }
- dataType match {
- // Row() need this interface to compile
- case StringType =>
- q"""
- override def getString(i: Int): String = {
- $accessorFailure
- }"""
- case other =>
- q"""
- override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ val cases = expressions.zipWithIndex.map {
+ case (e, i) if e.dataType == dataType =>
+ s"case $i: return c$i;"
+ case _ => ""
+ }.mkString("\n ")
+ if (cases.count(_ != '\n') > 0) {
+ s"""
+ @Override
+ public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) {
+ if (isNullAt(i)) {
+ return ${defaultPrimitive(dataType)};
+ }
+ switch (i) {
+ $cases
+ }
+ return ${defaultPrimitive(dataType)};
+ }"""
+ } else {
+ ""
}
- }
+ }.mkString("\n")
val specificMutatorFunctions = nativeTypes.map { dataType =>
- val ifStatements = expressions.zipWithIndex.flatMap {
- // setString() is not used by expressions
- case (e, i) if e.dataType == dataType && dataType != StringType =>
- val elementName = newTermName(s"c$i")
- // TODO: The string of ifs gets pretty inefficient as the row grows in size.
- // TODO: Optional null checks?
- q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
- case _ => Nil
- }
- dataType match {
- case StringType =>
- // MutableRow() need this interface to compile
- q"""
- override def setString(i: Int, value: String) {
- $accessorFailure
- }"""
- case other =>
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
- ..$ifStatements;
- $accessorFailure
- }"""
+ val cases = expressions.zipWithIndex.map {
+ case (e, i) if e.dataType == dataType =>
+ s"case $i: { c$i = value; return; }"
+ case _ => ""
+ }.mkString("\n")
+ if (cases.count(_ != '\n') > 0) {
+ s"""
+ @Override
+ public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) {
+ nullBits[i] = false;
+ switch (i) {
+ $cases
+ }
+ }"""
+ } else {
+ ""
}
- }
+ }.mkString("\n")
val hashValues = expressions.zipWithIndex.map { case (e, i) =>
- val elementName = newTermName(s"c$i")
+ val col = newTermName(s"c$i")
val nonNull = e.dataType match {
- case BooleanType => q"if ($elementName) 0 else 1"
- case ByteType | ShortType | IntegerType => q"$elementName.toInt"
- case LongType => q"($elementName ^ ($elementName >>> 32)).toInt"
- case FloatType => q"java.lang.Float.floatToIntBits($elementName)"
+ case BooleanType => s"$col ? 0 : 1"
+ case ByteType | ShortType | IntegerType | DateType => s"$col"
+ case LongType => s"$col ^ ($col >>> 32)"
+ case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
- q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }"
- case _ => q"$elementName.hashCode"
+ s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
+ case _ => s"$col.hashCode()"
}
- q"if (isNullAt($i)) 0 else $nonNull"
+ s"isNullAt($i) ? 0 : ($nonNull)"
}
- val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree)
+    val hashUpdates: String = hashValues.map(v =>
+ s"""
+ result *= 37; result += $v;"""
+ ).mkString("\n")
- val hashCodeFunction =
- q"""
- override def hashCode(): Int = {
- var result: Int = 37
- ..$hashUpdates
- result
- }
+ val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
+ s"""
+ if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) {
+ return false;
+ }
"""
+ }.mkString("\n")
- val columnChecks = (0 until expressions.size).map { i =>
- val elementName = newTermName(s"c$i")
- q"if (this.$elementName != specificType.$elementName) return false"
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificProjection generate($exprType[] expr) {
+ return new SpecificProjection(expr);
}
- val equalsFunction =
- q"""
- override def equals(other: Any): Boolean = other match {
- case specificType: SpecificRow =>
- ..$columnChecks
- return true
- case other => super.equals(other)
- }
- """
+ class SpecificProjection extends ${typeOf[BaseProject]} {
+ private $exprType[] expressions = null;
+
+ public SpecificProjection($exprType[] expr) {
+ expressions = expr;
+ }
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ @Override
+ public Object apply(Object r) {
+ return new SpecificRow(expressions, (Row) r);
+ }
}
- val copyFunction =
- q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
-
- val toSeqFunction =
- q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
-
- val classBody =
- nullFunctions ++ (
- lengthDef +:
- applyFunction +:
- updateFunction +:
- equalsFunction +:
- hashCodeFunction +:
- copyFunction +:
- toSeqFunction +:
- (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
-
- val code = q"""
- final class SpecificRow(i: $rowType) extends $mutableRowType {
- ..$classBody
+ final class SpecificRow extends ${typeOf[BaseMutableRow]} {
+
+ $columns
+
+ public SpecificRow($exprType[] expressions, Row i) {
+ $initColumns
+ }
+
+    public int size() { return ${expressions.length}; }
+ private boolean[] nullBits = new boolean[${expressions.length}];
+ public void setNullAt(int i) { nullBits[i] = true; }
+ public boolean isNullAt(int i) { return nullBits[i]; }
+
+ public Object get(int i) {
+ if (isNullAt(i)) return null;
+ switch (i) {
+ $getCases
+ }
+ return null;
+ }
+ public void update(int i, Object value) {
+ if (value == null) {
+ setNullAt(i);
+ return;
+ }
+ nullBits[i] = false;
+ switch (i) {
+ $updateCases
+ }
+ }
+ $specificAccessorFunctions
+ $specificMutatorFunctions
+
+ @Override
+ public int hashCode() {
+ int result = 37;
+ $hashUpdates
+ return result;
}
- new $projectionType { def apply(r: $rowType) = new SpecificRow(r) }
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof Row) {
+ Row row = (Row) other;
+ if (row.length() != size()) return false;
+ $columnChecks
+ return true;
+ }
+ return super.equals(other);
+ }
+ }
"""
- log.debug(
- s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}")
- toolBox.eval(code).asInstanceOf[Projection]
+ logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}")
+
+ val c = compile(code)
+    // fetch the generated class's single declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection]
}
}
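
The generated hashCode seeds `result` with 37 and then folds every column in with `result *= 37; result += v`, where a null column contributes 0 and long or double columns are first collapsed to 32 bits via x ^ (x >>> 32) (after doubleToLongBits for doubles). A Scala sketch of the equivalent fold:

    // Equivalent of the generated hashCode body: start from the seed 37 and
    // fold in each column's 32-bit hash value (0 for null columns).
    def rowHash(columnHashes: Seq[Int]): Int =
      columnHashes.foldLeft(37)((result, h) => result * 37 + h)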
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
index 528e38a50a..7f1b12cdd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -27,12 +27,6 @@ import org.apache.spark.util.Utils
*/
package object codegen {
- /**
- * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala
- * 2.10.
- */
- protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock
-
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =
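
Removing globalLock is the payoff of the Janino switch: each compilation is self-contained, so concurrent codegen no longer serializes on a JVM-wide lock the way the thread-unsafe Scala 2.10 toolbox did. A sketch of what is now safe, again assuming the one-argument generate() used elsewhere in this patch; .par only simulates concurrent callers:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

    // Previously these calls would have contended on globalLock; with
    // Janino each one compiles its class body independently.
    (1 to 4).par.foreach(_ => GeneratePredicate.generate(Literal(true)))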
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index b6927485f4..5df528770c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation("abdef" cast TimestampType, null)
checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65))
- checkEvaluation(Literal(1) cast LongType, 1)
+ checkEvaluation(Literal(1) cast LongType, 1.toLong)
checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong)
checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong)
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
@@ -363,13 +363,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
+ Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType),
+ 5.toLong)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
- ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0)
+ ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType),
+ 0.toShort)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
- DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
+ DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType),
+ 0.toShort)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Literal(true) cast StringType, "true")
@@ -509,9 +512,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val seconds = millis * 1000 + 2
val ts = new Timestamp(millis)
val tss = new Timestamp(seconds)
- checkEvaluation(Cast(ts, ShortType), 15)
+ checkEvaluation(Cast(ts, ShortType), 15.toShort)
checkEvaluation(Cast(ts, IntegerType), 15)
- checkEvaluation(Cast(ts, LongType), 15)
+ checkEvaluation(Cast(ts, LongType), 15.toLong)
checkEvaluation(Cast(ts, FloatType), 15.002f)
checkEvaluation(Cast(ts, DoubleType), 15.002)
checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts)
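
The new .toLong and .toShort expectations are not cosmetic. The generated SpecificRow compares columns with Java's type-strict equals (see the columnChecks above, which emit get(i).equals(...)), so an expected Int no longer matches a produced Long, even though Scala's numeric == would accept it:

    // Boxed Java equality is type-strict; Scala's numeric == is cooperative.
    java.lang.Long.valueOf(1L).equals(java.lang.Integer.valueOf(1))  // false
    1L == 1                                                          // true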
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
index d7c437095e..8cfd853afa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -32,11 +32,12 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
- val evaluated = GenerateProjection.expressionEvaluator(expression)
+ val ctx = GenerateProjection.newCodeGenContext()
+ val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
fail(
s"""
|Code generation of $expression failed:
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
|$e
""".stripMargin)
}
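
The context-passing protocol in this fix is the generators' new public face: newCodeGenContext() creates a context that supplies fresh variable names and accumulates the Expression objects the compiled class needs, and expressionEvaluator now returns Java source as one String rather than a Seq of Scala trees, which is why the mkString("\n") calls disappear. A sketch of the pieces, with a stand-in expression:

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection

    val expression: Expression = Literal(1)
    val ctx = GenerateProjection.newCodeGenContext()
    val eval = GenerateProjection.expressionEvaluator(expression, ctx)
    eval.code           // Java source fragment, now a single String
    eval.nullTerm       // name of the boolean null flag declared in that source
    eval.primitiveTerm  // name of the variable holding the computed value
    ctx.references      // expressions handed to the compiled class's constructor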
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index a40324b008..9ab1f7d7ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -28,7 +28,8 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
- lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+ val ctx = GenerateProjection.newCodeGenContext()
+ lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
val plan = try {
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
@@ -37,7 +38,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
fail(
s"""
|Code generation of $expression failed:
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
|$e
""".stripMargin)
}
@@ -49,7 +50,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
s"""
|Mismatched hashCodes for values: $actual, $expectedRow
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
""".stripMargin)
}
if (actual != expectedRow) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9aaec2b064..b41b1b77d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -451,10 +451,13 @@ class DataFrameSuite extends QueryTest {
test("SPARK-6899") {
val originalValue = TestSQLContext.conf.codegenEnabled
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
- checkAnswer(
- decimalData.agg(avg('a)),
- Row(new java.math.BigDecimal(2.0)))
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+    try {
+ checkAnswer(
+ decimalData.agg(avg('a)),
+ Row(new java.math.BigDecimal(2.0)))
+ } finally {
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("SPARK-7133: Implement struct, array, and map field accessor") {
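
The restore-in-finally fix above recurs in the SQLQuerySuite hunks below. A small helper in the same spirit, hypothetical and not part of this commit, shows the intent these suites are converging on:

    // Hypothetical helper for these suites: run `body` with codegen forced
    // on, always restoring the previous setting even if an assertion fails.
    def withCodegenEnabled[T](body: => T): T = {
      val originalValue = TestSQLContext.conf.codegenEnabled
      TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
      try body
      finally TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
    }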
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 63f7d314fb..55b68d8e22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -184,77 +184,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(df, expectedResults)
}
- // Just to group rows.
- testCodeGen(
- "SELECT key FROM testData3x GROUP BY key",
- (1 to 100).map(Row(_)))
- // COUNT
- testCodeGen(
- "SELECT key, count(value) FROM testData3x GROUP BY key",
- (1 to 100).map(i => Row(i, 3)))
- testCodeGen(
- "SELECT count(key) FROM testData3x",
- Row(300) :: Nil)
- // COUNT DISTINCT ON int
- testCodeGen(
- "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 1)))
- testCodeGen(
- "SELECT count(distinct key) FROM testData3x",
- Row(100) :: Nil)
- // SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
- "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
- Row(5050 * 3, 5050 * 3.0) :: Nil)
- // AVERAGE
- testCodeGen(
- "SELECT value, avg(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT avg(key) FROM testData3x",
- Row(50.5) :: Nil)
- // MAX
- testCodeGen(
- "SELECT value, max(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT max(key) FROM testData3x",
- Row(100) :: Nil)
- // MIN
- testCodeGen(
- "SELECT value, min(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT min(key) FROM testData3x",
- Row(1) :: Nil)
- // Some combinations.
- testCodeGen(
- """
- |SELECT
- | value,
- | sum(key),
- | max(key),
- | min(key),
- | avg(key),
- | count(key),
- | count(distinct key)
- |FROM testData3x
- |GROUP BY value
- """.stripMargin,
- (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
- testCodeGen(
- "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
- Row(100, 1, 50.5, 300, 100) :: Nil)
- // Aggregate with Code generation handling all null values
- testCodeGen(
- "SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(0, null, 0) :: Nil)
-
- dropTempTable("testData3x")
- setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ try {
+ // Just to group rows.
+ testCodeGen(
+ "SELECT key FROM testData3x GROUP BY key",
+ (1 to 100).map(Row(_)))
+ // COUNT
+ testCodeGen(
+ "SELECT key, count(value) FROM testData3x GROUP BY key",
+ (1 to 100).map(i => Row(i, 3)))
+ testCodeGen(
+ "SELECT count(key) FROM testData3x",
+ Row(300) :: Nil)
+ // COUNT DISTINCT ON int
+ testCodeGen(
+ "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 1)))
+ testCodeGen(
+ "SELECT count(distinct key) FROM testData3x",
+ Row(100) :: Nil)
+ // SUM
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
+ "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
+ Row(5050 * 3, 5050 * 3.0) :: Nil)
+ // AVERAGE
+ testCodeGen(
+ "SELECT value, avg(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT avg(key) FROM testData3x",
+ Row(50.5) :: Nil)
+ // MAX
+ testCodeGen(
+ "SELECT value, max(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT max(key) FROM testData3x",
+ Row(100) :: Nil)
+ // MIN
+ testCodeGen(
+ "SELECT value, min(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT min(key) FROM testData3x",
+ Row(1) :: Nil)
+ // Some combinations.
+ testCodeGen(
+ """
+ |SELECT
+ | value,
+ | sum(key),
+ | max(key),
+ | min(key),
+ | avg(key),
+ | count(key),
+ | count(distinct key)
+ |FROM testData3x
+ |GROUP BY value
+ """.stripMargin,
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
+ testCodeGen(
+ "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+ Row(100, 1, 50.5, 300, 100) :: Nil)
+ // Aggregate with Code generation handling all null values
+ testCodeGen(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(0, null, 0) :: Nil)
+ } finally {
+ dropTempTable("testData3x")
+ setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("Add Parser of SQL COALESCE()") {
@@ -463,9 +465,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val codegenbefore = conf.codegenEnabled
setConf(SQLConf.EXTERNAL_SORT, "false")
setConf(SQLConf.CODEGEN_ENABLED, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+    try {
+ sortTest()
+ } finally {
+ setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("SPARK-6927 external sorting with codegen on") {
@@ -473,9 +478,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val codegenbefore = conf.codegenEnabled
setConf(SQLConf.CODEGEN_ENABLED, "true")
setConf(SQLConf.EXTERNAL_SORT, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ try {
+ sortTest()
+ } finally {
+ setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("limit") {