-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala | 12
-rw-r--r--  pom.xml | 10
-rw-r--r--  project/SparkBuild.scala | 11
-rw-r--r--  sql/catalyst/pom.xml | 16
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 101
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java | 68
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java | 190
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 797
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 87
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 146
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 44
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 316
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala | 5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 162
18 files changed, 1116 insertions(+), 888 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 1501111a06..64e7102e36 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -20,6 +20,8 @@ package org.apache.spark.util.collection
import scala.reflect._
import com.google.common.hash.Hashing
+import org.apache.spark.annotation.Private
+
/**
* A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
* removed.
@@ -37,7 +39,7 @@ import com.google.common.hash.Hashing
* It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
* to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
*/
-private[spark]
+@Private
class OpenHashSet[@specialized(Long, Int) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
@@ -110,6 +112,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
rehashIfNeeded(k, grow, move)
}
+ def union(other: OpenHashSet[T]): OpenHashSet[T] = {
+ val iterator = other.iterator
+ while (iterator.hasNext) {
+ add(iterator.next())
+ }
+ this
+ }
+
/**
* Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
* The caller is responsible for calling rehashIfNeeded.
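A minimal sketch of how the new `union` method might be used (illustrative values; this assumes OpenHashSet's existing single-argument capacity constructor):

    // Illustrative only: merge one specialized hash set into another in place.
    import org.apache.spark.util.collection.OpenHashSet

    val left = new OpenHashSet[Long](64)
    val right = new OpenHashSet[Long](64)
    left.add(1L); left.add(2L)
    right.add(2L); right.add(3L)
    val merged = left.union(right) // mutates and returns `left`; now holds 1, 2, 3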
diff --git a/pom.xml b/pom.xml
index d03d33bf02..bcb6ef96a1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -118,7 +118,6 @@
<akka.version>2.3.4-spark</akka.version>
<java.version>1.6</java.version>
<sbt.project.name>spark</sbt.project.name>
- <scala.macros.version>2.0.1</scala.macros.version>
<mesos.version>0.21.1</mesos.version>
<mesos.classifier>shaded-protobuf</mesos.classifier>
<slf4j.version>1.7.10</slf4j.version>
@@ -1217,15 +1216,6 @@
<javacArg>-target</javacArg>
<javacArg>${java.version}</javacArg>
</javacArgs>
- <!-- The following plugin is required to use quasiquotes in Scala 2.10 and is used
- by Spark SQL for code generation. -->
- <compilerPlugins>
- <compilerPlugin>
- <groupId>org.scalamacros</groupId>
- <artifactId>paradise_${scala.version}</artifactId>
- <version>${scala.macros.version}</version>
- </compilerPlugin>
- </compilerPlugins>
</configuration>
</plugin>
<plugin>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9a84963923..f65031fe25 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -178,9 +178,6 @@ object SparkBuild extends PomBuild {
/* Enable unidoc only for the root spark project */
enable(Unidoc.settings)(spark)
- /* Catalyst macro settings */
- enable(Catalyst.settings)(catalyst)
-
/* Spark SQL Core console settings */
enable(SQL.settings)(sql)
@@ -275,14 +272,6 @@ object OldDeps {
)
}
-object Catalyst {
- lazy val settings = Seq(
- addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full),
- // Quasiquotes break compiling scala doc...
- // TODO: Investigate fixing this.
- sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen")))
-}
-
object SQL {
lazy val settings = Seq(
initialCommands in console :=
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index bf0a7327a5..f4b1cc3a4f 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -38,10 +38,6 @@
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
- <artifactId>scala-compiler</artifactId>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
@@ -67,6 +63,11 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.codehaus.janino</groupId>
+ <artifactId>janino</artifactId>
+ <version>2.7.8</version>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -108,13 +109,6 @@
<activation>
<property><name>!scala-2.11</name></property>
</activation>
- <dependencies>
- <dependency>
- <groupId>org.scalamacros</groupId>
- <artifactId>quasiquotes_${scala.binary.version}</artifactId>
- <version>${scala.macros.version}</version>
- </dependency>
- </dependencies>
</profile>
</profiles>
</project>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb546b3086..ec97fe603c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,23 +17,25 @@
package org.apache.spark.sql.catalyst.expressions;
-import scala.collection.Map;
+import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
import scala.collection.Seq;
import scala.collection.mutable.ArraySeq;
-import javax.annotation.Nullable;
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.util.*;
-
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
-import static org.apache.spark.sql.types.DataTypes.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import static org.apache.spark.sql.types.DataTypes.*;
+
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -49,7 +51,7 @@ import org.apache.spark.unsafe.bitset.BitSetMethods;
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow implements MutableRow {
+public final class UnsafeRow extends BaseMutableRow {
private Object baseObject;
private long baseOffset;
@@ -228,21 +230,11 @@ public final class UnsafeRow implements MutableRow {
}
@Override
- public int length() {
- return size();
- }
-
- @Override
public StructType schema() {
return schema;
}
@Override
- public Object apply(int i) {
- return get(i);
- }
-
- @Override
public Object get(int i) {
assertIndexIsValid(i);
assert (schema != null) : "Schema must be defined when calling generic get() method";
@@ -339,60 +331,7 @@ public final class UnsafeRow implements MutableRow {
return getUTF8String(i).toString();
}
- @Override
- public BigDecimal getDecimal(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Date getDate(int i) {
- throw new UnsupportedOperationException();
- }
- @Override
- public <T> Seq<T> getSeq(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> List<T> getList(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> Map<K, V> getMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <K, V> java.util.Map<K, V> getJavaMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Row getStruct(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public <T> T getAs(String fieldName) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int fieldIndex(String name) {
- throw new UnsupportedOperationException();
- }
@Override
public Row copy() {
@@ -412,24 +351,4 @@ public final class UnsafeRow implements MutableRow {
}
return values;
}
-
- @Override
- public String toString() {
- return mkString("[", ",", "]");
- }
-
- @Override
- public String mkString() {
- return toSeq().mkString();
- }
-
- @Override
- public String mkString(String sep) {
- return toSeq().mkString(sep);
- }
-
- @Override
- public String mkString(String start, String sep, String end) {
- return toSeq().mkString(start, sep, end);
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java
new file mode 100644
index 0000000000..acec2bf452
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
+
+public abstract class BaseMutableRow extends BaseRow implements MutableRow {
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setString(int ordinal, String value) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
new file mode 100644
index 0000000000..d138b43a34
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.util.List;
+
+import scala.collection.Seq;
+import scala.collection.mutable.ArraySeq;
+
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.StructType;
+
+public abstract class BaseRow implements Row {
+
+ @Override
+ final public int length() {
+ return size();
+ }
+
+ @Override
+ public boolean anyNull() {
+ final int n = size();
+ for (int i = 0; i < n; i++) {
+ if (isNullAt(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public StructType schema() { throw new UnsupportedOperationException(); }
+
+ @Override
+ final public Object apply(int i) {
+ return get(i);
+ }
+
+ @Override
+ public int getInt(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long getLong(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getFloat(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getDouble(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte getByte(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public short getShort(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String getString(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BigDecimal getDecimal(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Date getDate(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> Seq<T> getSeq(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> List<T> getList(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> scala.collection.Map<K, V> getMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> java.util.Map<K, V> getJavaMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row getStruct(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(String fieldName) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int fieldIndex(String name) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row copy() {
+ final int n = size();
+ Object[] arr = new Object[n];
+ for (int i = 0; i < n; i++) {
+ arr[i] = get(i);
+ }
+ return new GenericRow(arr);
+ }
+
+ @Override
+ public Seq<Object> toSeq() {
+ final int n = size();
+ final ArraySeq<Object> values = new ArraySeq<Object>(n);
+ for (int i = 0; i < n; i++) {
+ values.update(i, get(i));
+ }
+ return values;
+ }
+
+ @Override
+ public String toString() {
+ return mkString("[", ",", "]");
+ }
+
+ @Override
+ public String mkString() {
+ return toSeq().mkString();
+ }
+
+ @Override
+ public String mkString(String sep) {
+ return toSeq().mkString(sep);
+ }
+
+ @Override
+ public String mkString(String start, String sep, String end) {
+ return toSeq().mkString(start, sep, end);
+ }
+}
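BaseRow exists so that generated Java classes (and UnsafeRow above) inherit safe defaults and only override the accessors they actually support. A hypothetical minimal subclass, purely for illustration (not part of this patch):

    // Hypothetical: a fixed two-int-column row. Only size/get/isNullAt need real
    // implementations; everything else falls back to BaseRow's defaults.
    class TwoIntRow(a: Int, b: Int) extends org.apache.spark.sql.BaseRow {
      override def size: Int = 2
      override def isNullAt(i: Int): Boolean = false
      override def get(i: Int): Any = if (i == 0) a else b
      override def getInt(i: Int): Int = if (i == 0) a else b
    }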
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 36964af68d..cd604121b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import com.google.common.cache.{CacheLoader, CacheBuilder}
-
+import scala.collection.mutable
import scala.language.existentials
+import com.google.common.cache.{CacheBuilder, CacheLoader}
+import org.codehaus.janino.ClassBodyEvaluator
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
@@ -36,23 +38,15 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* expressions.
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
-
- import scala.tools.reflect.ToolBox
-
- protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
- protected val rowType = typeOf[Row]
- protected val mutableRowType = typeOf[MutableRow]
- protected val genericRowType = typeOf[GenericRow]
- protected val genericMutableRowType = typeOf[GenericMutableRow]
-
- protected val projectionType = typeOf[Projection]
- protected val mutableProjectionType = typeOf[MutableProjection]
+ protected val rowType = classOf[Row].getName
+ protected val stringType = classOf[UTF8String].getName
+ protected val decimalType = classOf[Decimal].getName
+ protected val exprType = classOf[Expression].getName
+ protected val mutableRowType = classOf[MutableRow].getName
+ protected val genericMutableRowType = classOf[GenericMutableRow].getName
private val curId = new java.util.concurrent.atomic.AtomicInteger()
- private val javaSeparator = "$"
/**
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
@@ -75,6 +69,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
/**
+ * Compile the Java source code into a Java class, using Janino.
+ *
+ * It also tracks and logs the time taken to compile.
+ */
+ protected def compile(code: String): Class[_] = {
+ val startTime = System.nanoTime()
+ val clazz = new ClassBodyEvaluator(code).getClazz()
+ val endTime = System.nanoTime()
+ def timeMs: Double = (endTime - startTime).toDouble / 1000000
+ logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms")
+ clazz
+ }
+
+ /**
* A cache of generated classes.
*
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
@@ -87,7 +95,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
.maximumSize(1000)
.build(
new CacheLoader[InType, OutType]() {
- override def load(in: InType): OutType = globalLock.synchronized {
+ override def load(in: InType): OutType = {
val startTime = System.nanoTime()
val result = create(in)
val endTime = System.nanoTime()
@@ -110,8 +118,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
- protected def freshName(prefix: String): TermName = {
- newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+ protected def freshName(prefix: String): String = {
+ s"$prefix${curId.getAndIncrement}"
}
/**
@@ -125,32 +133,51 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
*/
protected case class EvaluatedExpression(
- code: Seq[Tree],
- nullTerm: TermName,
- primitiveTerm: TermName,
- objectTerm: TermName)
+ code: String,
+ nullTerm: String,
+ primitiveTerm: String,
+ objectTerm: String)
+
+ /**
+ * A context for codegen, used to track the expressions that are not supported by codegen so
+ * that they can be evaluated directly. Each unsupported expression is appended to the end of
+ * `references`; its position there is embedded in the generated code and used to access and evaluate it.
+ */
+ protected class CodeGenContext {
+ /**
+ * Holds all the expressions that do not support codegen; they will be evaluated directly.
+ */
+ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
+ }
+
+ /**
+ * Creates a new codegen context for an expression evaluator, used to store the
+ * expressions that don't support codegen.
+ */
+ def newCodeGenContext(): CodeGenContext = {
+ new CodeGenContext()
+ }
/**
* Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
* can be used to determine the result of evaluating the expression on an input row.
*/
- def expressionEvaluator(e: Expression): EvaluatedExpression = {
+ def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = {
val primitiveTerm = freshName("primitiveTerm")
val nullTerm = freshName("nullTerm")
val objectTerm = freshName("objectTerm")
implicit class Evaluate1(e: Expression) {
- def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(dataType)}
- else
- ${f(eval.primitiveTerm)}
- """.children
+ def castOrNull(f: String => String, dataType: DataType): String = {
+ val eval = expressionEvaluator(e, ctx)
+ eval.code +
+ s"""
+ boolean $nullTerm = ${eval.nullTerm};
+ ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
+ if (!$nullTerm) {
+ $primitiveTerm = ${f(eval.primitiveTerm)};
+ }
+ """
}
}
@@ -163,529 +190,505 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
*
* @param f a function from two primitive term names to a tree that evaluates them.
*/
- def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ def evaluate(f: (String, String) => String): String =
evaluateAs(expressions._1.dataType)(f)
- def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+ def evaluateAs(resultType: DataType)(f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (expressions._1.dataType != expressions._2.dataType) {
log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
}
- val eval1 = expressionEvaluator(expressions._1)
- val eval2 = expressionEvaluator(expressions._2)
+ val eval1 = expressionEvaluator(expressions._1, ctx)
+ val eval2 = expressionEvaluator(expressions._2, ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
- eval1.code ++ eval2.code ++
- q"""
- val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
- val $primitiveTerm: ${termForType(resultType)} =
- if($nullTerm) {
- ${defaultPrimitive(resultType)}
- } else {
- $resultCode.asInstanceOf[${termForType(resultType)}]
- }
- """.children : Seq[Tree]
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
+ ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)};
+ if(!$nullTerm) {
+ $primitiveTerm = (${primitiveForType(resultType)})($resultCode);
+ }
+ """
}
}
- val inputTuple = newTermName(s"i")
+ val inputTuple = "i"
// TODO: Skip generation of null handling code when expression are not nullable.
- val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ val primitiveEvaluation: PartialFunction[Expression, String] = {
case b @ BoundReference(ordinal, dataType, nullable) =>
- val nullValue = q"$inputTuple.isNullAt($ordinal)"
- q"""
- val $nullTerm: Boolean = $nullValue
- val $primitiveTerm: ${termForType(dataType)} =
- if($nullTerm)
- ${defaultPrimitive(dataType)}
- else
- ${getColumn(inputTuple, dataType, ordinal)}
- """.children
+ s"""
+ final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
+ final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
+ ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)});
+ """
case expressions.Literal(null, dataType) =>
- q"""
- val $nullTerm = true
- val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
- """.children
-
- case expressions.Literal(value: Boolean, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: UTF8String, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} =
- org.apache.spark.sql.types.UTF8String(${value.getBytes})
- """.children
-
- case expressions.Literal(value: Int, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: Long, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case Cast(e @ BinaryType(), StringType) =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
- org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
- """.children
+ s"""
+ final boolean $nullTerm = true;
+ ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
+ """
+
+ case expressions.Literal(value: UTF8String, StringType) =>
+ val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
+ s"""
+ final boolean $nullTerm = false;
+ ${stringType} $primitiveTerm =
+ new ${stringType}().set(${arr});
+ """
+
+ case expressions.Literal(value, FloatType) =>
+ s"""
+ final boolean $nullTerm = false;
+ float $primitiveTerm = ${value}f;
+ """
+
+ case expressions.Literal(value, dt @ DecimalType()) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value);
+ """
+
+ case expressions.Literal(value, dataType) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dataType)} $primitiveTerm = $value;
+ """
+
+ case Cast(child @ BinaryType(), StringType) =>
+ child.castOrNull(c =>
+ s"new ${stringType}().set($c)",
+ StringType)
case Cast(child @ DateType(), StringType) =>
child.castOrNull(c =>
- q"""org.apache.spark.sql.types.UTF8String(
+ s"""new ${stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
- case Cast(child @ NumericType(), IntegerType) =>
- child.castOrNull(c => q"$c.toInt", IntegerType)
+ case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
- case Cast(child @ NumericType(), LongType) =>
- child.castOrNull(c => q"$c.toLong", LongType)
+ case Cast(child @ DecimalType(), IntegerType) =>
+ child.castOrNull(c => s"($c).toInt()", IntegerType)
- case Cast(child @ NumericType(), DoubleType) =>
- child.castOrNull(c => q"$c.toDouble", DoubleType)
+ case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
- case Cast(child @ NumericType(), FloatType) =>
- child.castOrNull(c => q"$c.toFloat", FloatType)
+ case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
- org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
- """.children
+ e.castOrNull(c =>
+ s"new ${stringType}().set(String.valueOf($c))",
+ StringType)
case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
(e1, e2).evaluateAs (BooleanType) {
case (eval1, eval2) =>
- q"""
- java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
- $eval2.asInstanceOf[Array[Byte]])
- """
+ s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
}
case EqualTo(e1, e2) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
-
- /* TODO: Fix null semantics.
- case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
- val eval = expressionEvaluator(e1)
-
- val checks = list.map {
- case expressions.Literal(v: String, dataType) =>
- q"if(${eval.primitiveTerm} == $v) return true"
- case expressions.Literal(v: Int, dataType) =>
- q"if(${eval.primitiveTerm} == $v) return true"
- }
-
- val funcName = newTermName(s"isIn${curId.getAndIncrement()}")
-
- q"""
- def $funcName: Boolean = {
- ..${eval.code}
- if(${eval.nullTerm}) return false
- ..$checks
- return false
- }
- val $nullTerm = false
- val $primitiveTerm = $funcName
- """.children
- */
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" }
case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" }
case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" }
case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" }
case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" }
case And(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- q"""
- ..${eval1.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = false
-
- if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) {
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ s"""
+ ${eval1.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = false;
+
+ if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) {
} else {
- ..${eval2.code}
- if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) {
+ ${eval2.code}
+ if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) {
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else {
- $nullTerm = true
+ $nullTerm = true;
}
}
- """.children
+ """
case Or(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- q"""
- ..${eval1.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = false
+ s"""
+ ${eval1.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = false;
if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else {
- ..${eval2.code}
+ ${eval2.code}
if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
- $primitiveTerm = true
+ $primitiveTerm = true;
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = false
+ $primitiveTerm = false;
} else {
- $nullTerm = true
+ $nullTerm = true;
}
}
- """.children
+ """
case Not(child) =>
// Uh, bad function name...
- child.castOrNull(c => q"!$c", BooleanType)
-
- case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
- case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
- case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
+ child.castOrNull(c => s"!$c", BooleanType)
+
+ case Add(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" }
+ case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" }
+ case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" }
+ case Divide(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = null;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm}.$$div(${eval2.primitiveTerm});
+ }
+ """
+ case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm});
+ }
+ """
+
+ case Add(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" }
+ case Subtract(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" }
+ case Multiply(e1, e2) =>
+ (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" }
case Divide(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = 0
-
- if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
- $nullTerm = true
- } else if (${eval2.primitiveTerm} == 0)
- $nullTerm = true
- else {
- $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm};
}
- """.children
-
+ """
case Remainder(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
-
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = 0
-
- if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
- $nullTerm = true
- } else if (${eval2.primitiveTerm} == 0)
- $nullTerm = true
- else {
- $nullTerm = false
- $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
+ if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
+ $nullTerm = true;
+ } else {
+ $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm};
}
- """.children
+ """
case IsNotNull(e) =>
- val eval = expressionEvaluator(e)
- q"""
- ..${eval.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm}
- """.children
+ val eval = expressionEvaluator(e, ctx)
+ s"""
+ ${eval.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = !${eval.nullTerm};
+ """
case IsNull(e) =>
- val eval = expressionEvaluator(e)
- q"""
- ..${eval.code}
- var $nullTerm = false
- var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}
- """.children
-
- case c @ Coalesce(children) =>
- q"""
- var $nullTerm = true
- var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
- """.children ++
+ val eval = expressionEvaluator(e, ctx)
+ s"""
+ ${eval.code}
+ boolean $nullTerm = false;
+ boolean $primitiveTerm = ${eval.nullTerm};
+ """
+
+ case e @ Coalesce(children) =>
+ s"""
+ boolean $nullTerm = true;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ """ +
children.map { c =>
- val eval = expressionEvaluator(c)
- q"""
+ val eval = expressionEvaluator(c, ctx)
+ s"""
if($nullTerm) {
- ..${eval.code}
+ ${eval.code}
if(!${eval.nullTerm}) {
- $nullTerm = false
- $primitiveTerm = ${eval.primitiveTerm}
+ $nullTerm = false;
+ $primitiveTerm = ${eval.primitiveTerm};
}
}
"""
- }
+ }.mkString("\n")
- case i @ expressions.If(condition, trueValue, falseValue) =>
- val condEval = expressionEvaluator(condition)
- val trueEval = expressionEvaluator(trueValue)
- val falseEval = expressionEvaluator(falseValue)
+ case e @ expressions.If(condition, trueValue, falseValue) =>
+ val condEval = expressionEvaluator(condition, ctx)
+ val trueEval = expressionEvaluator(trueValue, ctx)
+ val falseEval = expressionEvaluator(falseValue, ctx)
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)}
- ..${condEval.code}
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ ${condEval.code}
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
- ..${trueEval.code}
- $nullTerm = ${trueEval.nullTerm}
- $primitiveTerm = ${trueEval.primitiveTerm}
+ ${trueEval.code}
+ $nullTerm = ${trueEval.nullTerm};
+ $primitiveTerm = ${trueEval.primitiveTerm};
} else {
- ..${falseEval.code}
- $nullTerm = ${falseEval.nullTerm}
- $primitiveTerm = ${falseEval.primitiveTerm}
+ ${falseEval.code}
+ $nullTerm = ${falseEval.nullTerm};
+ $primitiveTerm = ${falseEval.primitiveTerm};
}
- """.children
+ """
case NewSet(elementType) =>
- q"""
- val $nullTerm = false
- val $primitiveTerm = new ${hashSetForType(elementType)}()
- """.children
+ s"""
+ boolean $nullTerm = false;
+ ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}();
+ """
case AddItemToSet(item, set) =>
- val itemEval = expressionEvaluator(item)
- val setEval = expressionEvaluator(set)
+ val itemEval = expressionEvaluator(item, ctx)
+ val setEval = expressionEvaluator(set, ctx)
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val htype = hashSetForType(elementType)
- itemEval.code ++ setEval.code ++
- q"""
- if (!${itemEval.nullTerm}) {
- ${setEval.primitiveTerm}
- .asInstanceOf[${hashSetForType(elementType)}]
- .add(${itemEval.primitiveTerm})
+ itemEval.code + setEval.code +
+ s"""
+ if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
+ (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
-
- val $nullTerm = false
- val $primitiveTerm = ${setEval.primitiveTerm}
- """.children
+ boolean $nullTerm = false;
+ ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm};
+ """
case CombineSets(left, right) =>
- val leftEval = expressionEvaluator(left)
- val rightEval = expressionEvaluator(right)
+ val leftEval = expressionEvaluator(left, ctx)
+ val rightEval = expressionEvaluator(right, ctx)
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val htype = hashSetForType(elementType)
- leftEval.code ++ rightEval.code ++
- q"""
- val $nullTerm = false
- var $primitiveTerm: ${hashSetForType(elementType)} = null
-
- {
- val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
- val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
- val iterator = rightSet.iterator
- while (iterator.hasNext) {
- leftSet.add(iterator.next())
- }
- $primitiveTerm = leftSet
- }
- """.children
+ leftEval.code + rightEval.code +
+ s"""
+ boolean $nullTerm = false;
+ ${htype} $primitiveTerm =
+ (${htype})${leftEval.primitiveTerm};
+ $primitiveTerm.union((${htype})${rightEval.primitiveTerm});
+ """
- case MaxOf(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm}
- $primitiveTerm = ${eval2.primitiveTerm}
+ $nullTerm = ${eval2.nullTerm};
+ $primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm}
- $primitiveTerm = ${eval1.primitiveTerm}
+ $nullTerm = ${eval1.nullTerm};
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm}
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
- $primitiveTerm = ${eval2.primitiveTerm}
+ $primitiveTerm = ${eval2.primitiveTerm};
}
}
- """.children
+ """
- case MinOf(e1, e2) =>
- val eval1 = expressionEvaluator(e1)
- val eval2 = expressionEvaluator(e2)
+ case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
+ val eval1 = expressionEvaluator(e1, ctx)
+ val eval2 = expressionEvaluator(e2, ctx)
- eval1.code ++ eval2.code ++
- q"""
- var $nullTerm = false
- var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = false;
+ ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm}
- $primitiveTerm = ${eval2.primitiveTerm}
+ $nullTerm = ${eval2.nullTerm};
+ $primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm}
- $primitiveTerm = ${eval1.primitiveTerm}
+ $nullTerm = ${eval1.nullTerm};
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm}
+ $primitiveTerm = ${eval1.primitiveTerm};
} else {
- $primitiveTerm = ${eval2.primitiveTerm}
+ $primitiveTerm = ${eval2.primitiveTerm};
}
}
- """.children
+ """
case UnscaledValue(child) =>
- val childEval = expressionEvaluator(child)
-
- childEval.code ++
- q"""
- var $nullTerm = ${childEval.nullTerm}
- var $primitiveTerm: Long = if (!$nullTerm) {
- ${childEval.primitiveTerm}.toUnscaledLong
- } else {
- ${defaultPrimitive(LongType)}
- }
- """.children
+ val childEval = expressionEvaluator(child, ctx)
+
+ childEval.code +
+ s"""
+ boolean $nullTerm = ${childEval.nullTerm};
+ long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong();
+ """
case MakeDecimal(child, precision, scale) =>
- val childEval = expressionEvaluator(child)
+ val eval = expressionEvaluator(child, ctx)
- childEval.code ++
- q"""
- var $nullTerm = ${childEval.nullTerm}
- var $primitiveTerm: org.apache.spark.sql.types.Decimal =
- ${defaultPrimitive(DecimalType())}
+ eval.code +
+ s"""
+ boolean $nullTerm = ${eval.nullTerm};
+ org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())};
if (!$nullTerm) {
- $primitiveTerm = new org.apache.spark.sql.types.Decimal()
- $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
- $nullTerm = $primitiveTerm == null
+ $primitiveTerm = new org.apache.spark.sql.types.Decimal();
+ $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale);
+ $nullTerm = $primitiveTerm == null;
}
- """.children
+ """
}
// If there was no match in the partial function above, we fall back on calling the interpreted
// expression evaluator.
- val code: Seq[Tree] =
+ val code: String =
primitiveEvaluation.lift.apply(e).getOrElse {
- log.debug(s"No rules to generate $e")
- val tree = reify { e }
- q"""
- val $objectTerm = $tree.eval(i)
- val $nullTerm = $objectTerm == null
- val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
- """.children
- }
-
- // Only inject debugging code if debugging is turned on.
- val debugCode =
- if (debugLogging) {
- val localLogger = log
- val localLoggerTree = reify { localLogger }
- q"""
- $localLoggerTree.debug(
- ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString))
- """ :: Nil
- } else {
- Nil
+ logError(s"No rules to generate $e")
+ ctx.references += e
+ s"""
+ /* expression: ${e} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
+ boolean $nullTerm = $objectTerm == null;
+ ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
+ if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm;
+ """
}
- EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm)
+ EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
}
- protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
+ protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = {
dataType match {
- case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
- case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)"
- case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
+ case StringType => s"(${stringType})$inputRow.apply($ordinal)"
+ case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)"
+ case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)"
}
}
protected def setColumn(
- destinationRow: TermName,
+ destinationRow: String,
dataType: DataType,
ordinal: Int,
- value: TermName) = {
+ value: String): String = {
dataType match {
- case StringType => q"$destinationRow.update($ordinal, $value)"
+ case StringType => s"$destinationRow.update($ordinal, $value)"
case dt: DataType if isNativeType(dt) =>
- q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
- case _ => q"$destinationRow.update($ordinal, $value)"
+ s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case _ => s"$destinationRow.update($ordinal, $value)"
}
}
- protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
- protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+ protected def accessorForType(dt: DataType) = dt match {
+ case IntegerType => "getInt"
+ case other => s"get${termForType(dt)}"
+ }
+
+ protected def mutatorForType(dt: DataType) = dt match {
+ case IntegerType => "setInt"
+ case other => s"set${termForType(dt)}"
+ }
- protected def hashSetForType(dt: DataType) = dt match {
- case IntegerType => typeOf[IntegerHashSet]
- case LongType => typeOf[LongHashSet]
+ protected def hashSetForType(dt: DataType): String = dt match {
+ case IntegerType => classOf[IntegerHashSet].getName
+ case LongType => classOf[LongHashSet].getName
case unsupportedType =>
sys.error(s"Code generation not support for hashset of type $unsupportedType")
}
- protected def primitiveForType(dt: DataType) = dt match {
- case IntegerType => "Int"
+ protected def primitiveForType(dt: DataType): String = dt match {
+ case IntegerType => "int"
+ case LongType => "long"
+ case ShortType => "short"
+ case ByteType => "byte"
+ case DoubleType => "double"
+ case FloatType => "float"
+ case BooleanType => "boolean"
+ case dt: DecimalType => decimalType
+ case BinaryType => "byte[]"
+ case StringType => stringType
+ case DateType => "int"
+ case TimestampType => "java.sql.Timestamp"
+ case _ => "Object"
+ }
+
+ protected def defaultPrimitive(dt: DataType): String = dt match {
+ case BooleanType => "false"
+ case FloatType => "-1.0f"
+ case ShortType => "-1"
+ case LongType => "-1"
+ case ByteType => "-1"
+ case DoubleType => "-1.0"
+ case IntegerType => "-1"
+ case DateType => "-1"
+ case dt: DecimalType => "null"
+ case StringType => "null"
+ case _ => "null"
+ }
+
+ protected def termForType(dt: DataType): String = dt match {
+ case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
- case StringType => "org.apache.spark.sql.types.UTF8String"
- }
-
- protected def defaultPrimitive(dt: DataType) = dt match {
- case BooleanType => ru.Literal(Constant(false))
- case FloatType => ru.Literal(Constant(-1.0.toFloat))
- case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")"""
- case ShortType => ru.Literal(Constant(-1.toShort))
- case LongType => ru.Literal(Constant(-1L))
- case ByteType => ru.Literal(Constant(-1.toByte))
- case DoubleType => ru.Literal(Constant(-1.toDouble))
- case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)"
- case IntegerType => ru.Literal(Constant(-1))
- case DateType => ru.Literal(Constant(-1))
- case _ => ru.Literal(Constant(null))
- }
-
- protected def termForType(dt: DataType) = dt match {
- case n: AtomicType => n.tag
- case _ => typeTag[Any]
+ case dt: DecimalType => decimalType
+ case BinaryType => "byte[]"
+ case StringType => stringType
+ case DateType => "Integer"
+ case TimestampType => "java.sql.Timestamp"
+ case _ => "Object"
}
/**
* List of data types that have special accessors and setters in [[Row]].
*/
protected val nativeTypes =
- Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+ Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
/**
* Returns true if the data type has a special accessor and setter in [[Row]].
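The essential change in this file is replacing toolbox-evaluated Scala trees with plain Java source strings compiled by Janino. A standalone sketch of the ClassBodyEvaluator pattern that the new compile() helper relies on (the class body and method here are illustrative):

    // Illustrative: compile a Java class body at runtime with Janino 2.7.8,
    // mirroring the compile() method added above.
    import org.codehaus.janino.ClassBodyEvaluator

    val body = "public int addOne(int x) { return x + 1; }"
    val clazz = new ClassBodyEvaluator(body).getClazz()
    val method = clazz.getDeclaredMethods()(0) // the only declared method
    val result = method.invoke(clazz.newInstance(), Int.box(41)) // => 42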
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 840260703a..638b53fe0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+// MutableProjection is not accessible in Java
+abstract class BaseMutableProjection extends MutableProjection {}
+
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
* input [[Row]] for a fixed set of [[Expression Expressions]].
*/
object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
-
- val mutableRowName = newTermName("mutableRow")
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
in.map(BindReferences.bindReference(_, inputSchema))
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
- val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
- val evaluationCode = expressionEvaluator(e)
-
- evaluationCode.code :+
- q"""
- if(${evaluationCode.nullTerm})
- mutableRow.setNullAt($i)
- else
- ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)}
- """
- }
+ val ctx = newCodeGenContext()
+ val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
+ val evaluationCode = expressionEvaluator(e, ctx)
+ evaluationCode.code +
+ s"""
+ if(${evaluationCode.nullTerm})
+ mutableRow.setNullAt($i);
+ else
+ ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
+ """
+ }.mkString("\n")
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificProjection generate($exprType[] expr) {
+ return new SpecificProjection(expr);
+ }
+
+ class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
- val code =
- q"""
- () => { new $mutableProjectionType {
+ private $exprType[] expressions = null;
+ private $mutableRowType mutableRow = null;
- private[this] var $mutableRowName: $mutableRowType =
- new $genericMutableRowType(${expressions.size})
+ public SpecificProjection($exprType[] expr) {
+ expressions = expr;
+ mutableRow = new $genericMutableRowType(${expressions.size});
+ }
- def target(row: $mutableRowType): $mutableProjectionType = {
- $mutableRowName = row
- this
- }
+ public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) {
+ mutableRow = row;
+ return this;
+ }
- /* Provide immutable access to the last projected row. */
- def currentValue: $rowType = mutableRow
+ /* Provide immutable access to the last projected row. */
+ public Row currentValue() {
+ return mutableRow;
+ }
- def apply(i: $rowType): $rowType = {
- ..$projectionCode
- mutableRow
- }
- } }
- """
+ public Object apply(Object _i) {
+ Row i = (Row) _i;
+ $projectionCode
- log.debug(s"code for ${expressions.mkString(",")}:\n$code")
- toolBox.eval(code).asInstanceOf[() => MutableProjection]
+ return mutableRow;
+ }
+ }
+ """
+
+
+ logDebug(s"code for ${expressions.mkString(",")}:\n$code")
+
+ val c = compile(code)
+ // fetch the only declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ () => {
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection]
+ }
}
}
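Callers reach this through the generate method inherited from CodeGenerator, which canonicalizes, caches, and ultimately calls create above. A sketch under that assumption (expression and input values are illustrative):

    // Illustrative: a generated projection computing (input column 0) + 1.
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val expr = Add(BoundReference(0, IntegerType, nullable = true), Literal(1))
    val builder = GenerateMutableProjection.generate(expr :: Nil) // () => MutableProjection
    val projection = builder()
    val result = projection(new GenericRow(Array[Any](41))) // row containing 42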
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index b129c0d898..0ff840dab3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -18,18 +18,29 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.Logging
+import org.apache.spark.annotation.Private
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{BinaryType, StringType, NumericType}
+import org.apache.spark.sql.types.{BinaryType, NumericType}
+
+/**
+ * A concrete class inheriting the default implementations of `Ordering[Row]`, so that
+ * generated Java code can extend it.
+ */
+@Private
+class BaseOrdering extends Ordering[Row] {
+ def compare(a: Row, b: Row): Int = {
+ throw new UnsupportedOperationException
+ }
+}
/**
* Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
* [[Expression Expressions]].
*/
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
- import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._
- protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
+ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
@@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
val a = newTermName("a")
val b = newTermName("b")
- val comparisons = ordering.zipWithIndex.map { case (order, i) =>
- val evalA = expressionEvaluator(order.child)
- val evalB = expressionEvaluator(order.child)
+ val ctx = newCodeGenContext()
+ val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+ val evalA = expressionEvaluator(order.child, ctx)
+ val evalB = expressionEvaluator(order.child, ctx)
+ val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
- q"""
- val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm}
- val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm}
- var i = 0
- while (i < x.length && i < y.length) {
- val res = x(i).compareTo(y(i))
- if (res != 0) return res
- i = i+1
- }
- return x.length - y.length
- """
+ s"""
+ {
+ byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
+ byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
+ int j = 0;
+ while (j < x.length && j < y.length) {
+ if (x[j] != y[j]) return x[j] - y[j];
+ j = j + 1;
+ }
+ int d = x.length - y.length;
+ if (d != 0) {
+ return d;
+ }
+ }"""
case _: NumericType =>
- q"""
- val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
- if(comp != 0) {
- return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
- }
- """
- case StringType =>
- if (order.direction == Ascending) {
- q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+ s"""
+ if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
+ if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
+ return ${if (asc) "1" else "-1"};
+ } else {
+ return ${if (asc) "-1" else "1"};
+ }
+ }"""
+ case _ =>
+ s"""
+ int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
+ if (comp != 0) {
+ return ${if (asc) "comp" else "-comp"};
+ }"""
+ }
+
+ s"""
+ i = $a;
+ ${evalA.code}
+ i = $b;
+ ${evalB.code}
+ if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+        // Both sides are null, so this column is a tie; continue with the next one.
+ } else if (${evalA.nullTerm}) {
+ return ${if (order.direction == Ascending) "-1" else "1"};
+ } else if (${evalB.nullTerm}) {
+ return ${if (order.direction == Ascending) "1" else "-1"};
} else {
- q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+ $compare
}
+ """
+ }.mkString("\n")
+
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificOrdering generate($exprType[] expr) {
+ return new SpecificOrdering(expr);
}
- q"""
- i = $a
- ..${evalA.code}
- i = $b
- ..${evalB.code}
- if (${evalA.nullTerm} && ${evalB.nullTerm}) {
- // Nothing
- } else if (${evalA.nullTerm}) {
- return ${if (order.direction == Ascending) q"-1" else q"1"}
- } else if (${evalB.nullTerm}) {
- return ${if (order.direction == Ascending) q"1" else q"-1"}
- } else {
- $compare
+ class SpecificOrdering extends ${typeOf[BaseOrdering]} {
+
+ private $exprType[] expressions = null;
+
+ public SpecificOrdering($exprType[] expr) {
+ expressions = expr;
}
- """
- }
- val q"class $orderingName extends $orderingType { ..$body }" = reify {
- class SpecificOrdering extends Ordering[Row] {
- val o = ordering
- }
- }.tree.children.head
-
- val code = q"""
- class $orderingName extends $orderingType {
- ..$body
- def compare(a: $rowType, b: $rowType): Int = {
- var i: $rowType = null // Holds current row being evaluated.
- ..$comparisons
- return 0
+ @Override
+ public int compare(Row a, Row b) {
+ Row i = null; // Holds current row being evaluated.
+ $comparisons
+ return 0;
}
- }
- new $orderingName()
- """
+ }"""
+
logDebug(s"Generated Ordering: $code")
- toolBox.eval(code).asInstanceOf[Ordering[Row]]
+
+ val c = compile(code)
+    // fetch the only declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
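
Since `create` still returns an `Ordering[Row]`, call sites are untouched. A hedged usage sketch, assuming the inherited `CodeGenerator.generate` entry point (which canonicalizes, caches, and ultimately calls `create` above):

    // boundSortOrders: Seq[SortOrder] already bound to the input schema (assumption).
    val rowOrdering: Ordering[Row] = GenerateOrdering.generate(boundSortOrders)
    val sorted: Seq[Row] = rows.sorted(rowOrdering)
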
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 40e1630243..fb18769f00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -20,11 +20,16 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
/**
+ * Interface implemented by generated predicates.
+ */
+abstract class Predicate {
+ def eval(r: Row): Boolean
+}
+
+/**
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
*/
object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
- import scala.reflect.runtime.{universe => ru}
- import scala.reflect.runtime.universe._
protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
@@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
BindReferences.bindReference(in, inputSchema)
protected def create(predicate: Expression): ((Row) => Boolean) = {
- val cEval = expressionEvaluator(predicate)
+ val ctx = newCodeGenContext()
+ val eval = expressionEvaluator(predicate, ctx)
+ val code = s"""
+ import org.apache.spark.sql.Row;
- val code =
- q"""
- (i: $rowType) => {
- ..${cEval.code}
- if (${cEval.nullTerm}) false else ${cEval.primitiveTerm}
+ public SpecificPredicate generate($exprType[] expr) {
+ return new SpecificPredicate(expr);
+ }
+
+ class SpecificPredicate extends ${classOf[Predicate].getName} {
+ private final $exprType[] expressions;
+ public SpecificPredicate($exprType[] expr) {
+ expressions = expr;
+ }
+
+ @Override
+ public boolean eval(Row i) {
+ ${eval.code}
+ return !${eval.nullTerm} && ${eval.primitiveTerm};
}
- """
+ }"""
+
+ logDebug(s"Generated predicate '$predicate':\n$code")
- log.debug(s"Generated predicate '$predicate':\n$code")
- toolBox.eval(code).asInstanceOf[Row => Boolean]
+ val c = compile(code)
+    // fetch the only declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate]
+ (r: Row) => p.eval(r)
}
}
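
The public contract is unchanged: callers still receive a plain `Row => Boolean`, with the compiled `SpecificPredicate` hidden behind the closure built on the last line. A hedged sketch of a call site:

    // boundPredicate: an Expression already bound to the input schema (assumption).
    val pred: Row => Boolean = GeneratePredicate.generate(boundPredicate)
    val matching: Iterator[Row] = rowIterator.filter(pred)
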
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 31c63a79eb..d5be1fc12e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -17,9 +17,14 @@
package org.apache.spark.sql.catalyst.expressions.codegen
+import org.apache.spark.sql.BaseMutableRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+/**
+ * Java can not access Projection (in package object)
+ */
+abstract class BaseProject extends Projection {}
/**
* Generates bytecode that produces a new [[Row]] object based on a fixed set of input
@@ -27,7 +32,6 @@ import org.apache.spark.sql.types._
* generated based on the output types of the [[Expression]] to avoid boxing of primitive values.
*/
object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
- import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
@@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
// Make Mutability optional...
protected def create(expressions: Seq[Expression]): Projection = {
- val tupleLength = ru.Literal(Constant(expressions.length))
- val lengthDef = q"final val length = $tupleLength"
-
- /* TODO: Configurable...
- val nullFunctions =
- q"""
- private final val nullSet = new org.apache.spark.util.collection.BitSet(length)
- final def setNullAt(i: Int) = nullSet.set(i)
- final def isNullAt(i: Int) = nullSet.get(i)
- """
- */
-
- val nullFunctions =
- q"""
- private[this] var nullBits = new Array[Boolean](${expressions.size})
- override def setNullAt(i: Int) = { nullBits(i) = true }
- override def isNullAt(i: Int) = nullBits(i)
- """.children
-
- val tupleElements = expressions.zipWithIndex.flatMap {
+ val ctx = newCodeGenContext()
+ val columns = expressions.zipWithIndex.map {
case (e, i) =>
- val elementName = newTermName(s"c$i")
- val evaluatedExpression = expressionEvaluator(e)
- val iLit = ru.Literal(Constant(i))
+ s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n"
+ }.mkString("\n ")
- q"""
- var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _
+ val initColumns = expressions.zipWithIndex.map {
+ case (e, i) =>
+ val eval = expressionEvaluator(e, ctx)
+ s"""
{
- ..${evaluatedExpression.code}
- if(${evaluatedExpression.nullTerm})
- setNullAt($iLit)
- else {
- nullBits($iLit) = false
- $elementName = ${evaluatedExpression.primitiveTerm}
+ // column$i
+ ${eval.code}
+ nullBits[$i] = ${eval.nullTerm};
+        if (!${eval.nullTerm}) {
+ c$i = ${eval.primitiveTerm};
}
}
- """.children : Seq[Tree]
- }
+ """
+ }.mkString("\n")
- val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
- val applyFunction = {
- val cases = (0 until expressions.size).map { i =>
- val ordinal = ru.Literal(Constant(i))
- val elementName = newTermName(s"c$i")
- val iLit = ru.Literal(Constant(i))
+ val getCases = (0 until expressions.size).map { i =>
+ s"case $i: return c$i;"
+ }.mkString("\n ")
- q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }"
- }
- q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }"
- }
-
- val updateFunction = {
- val cases = expressions.zipWithIndex.map {case (e, i) =>
- val ordinal = ru.Literal(Constant(i))
- val elementName = newTermName(s"c$i")
- val iLit = ru.Literal(Constant(i))
-
- q"""
- if(i == $ordinal) {
- if(value == null) {
- setNullAt(i)
- } else {
- nullBits(i) = false
- $elementName = value.asInstanceOf[${termForType(e.dataType)}]
- }
- return
- }"""
- }
- q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
- }
+ val updateCases = expressions.zipWithIndex.map { case (e, i) =>
+ s"case $i: { c$i = (${termForType(e.dataType)})value; return;}"
+ }.mkString("\n ")
val specificAccessorFunctions = nativeTypes.map { dataType =>
- val ifStatements = expressions.zipWithIndex.flatMap {
- // getString() is not used by expressions
- case (e, i) if e.dataType == dataType && dataType != StringType =>
- val elementName = newTermName(s"c$i")
- // TODO: The string of ifs gets pretty inefficient as the row grows in size.
- // TODO: Optional null checks?
- q"if(i == $i) return $elementName" :: Nil
- case _ => Nil
- }
- dataType match {
- // Row() need this interface to compile
- case StringType =>
- q"""
- override def getString(i: Int): String = {
- $accessorFailure
- }"""
- case other =>
- q"""
- override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
- ..$ifStatements;
- $accessorFailure
- }"""
+ val cases = expressions.zipWithIndex.map {
+ case (e, i) if e.dataType == dataType =>
+ s"case $i: return c$i;"
+ case _ => ""
+ }.mkString("\n ")
+ if (cases.count(_ != '\n') > 0) {
+ s"""
+ @Override
+ public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) {
+ if (isNullAt(i)) {
+ return ${defaultPrimitive(dataType)};
+ }
+ switch (i) {
+ $cases
+ }
+ return ${defaultPrimitive(dataType)};
+ }"""
+ } else {
+ ""
}
- }
+ }.mkString("\n")
val specificMutatorFunctions = nativeTypes.map { dataType =>
- val ifStatements = expressions.zipWithIndex.flatMap {
- // setString() is not used by expressions
- case (e, i) if e.dataType == dataType && dataType != StringType =>
- val elementName = newTermName(s"c$i")
- // TODO: The string of ifs gets pretty inefficient as the row grows in size.
- // TODO: Optional null checks?
- q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
- case _ => Nil
- }
- dataType match {
- case StringType =>
- // MutableRow() need this interface to compile
- q"""
- override def setString(i: Int, value: String) {
- $accessorFailure
- }"""
- case other =>
- q"""
- override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
- ..$ifStatements;
- $accessorFailure
- }"""
+ val cases = expressions.zipWithIndex.map {
+ case (e, i) if e.dataType == dataType =>
+ s"case $i: { c$i = value; return; }"
+ case _ => ""
+ }.mkString("\n")
+ if (cases.count(_ != '\n') > 0) {
+ s"""
+ @Override
+ public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) {
+ nullBits[i] = false;
+ switch (i) {
+ $cases
+ }
+ }"""
+ } else {
+ ""
}
- }
+ }.mkString("\n")
val hashValues = expressions.zipWithIndex.map { case (e, i) =>
- val elementName = newTermName(s"c$i")
+ val col = newTermName(s"c$i")
val nonNull = e.dataType match {
- case BooleanType => q"if ($elementName) 0 else 1"
- case ByteType | ShortType | IntegerType => q"$elementName.toInt"
- case LongType => q"($elementName ^ ($elementName >>> 32)).toInt"
- case FloatType => q"java.lang.Float.floatToIntBits($elementName)"
+ case BooleanType => s"$col ? 0 : 1"
+ case ByteType | ShortType | IntegerType | DateType => s"$col"
+ case LongType => s"$col ^ ($col >>> 32)"
+ case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
- q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }"
- case _ => q"$elementName.hashCode"
+ s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
+ case _ => s"$col.hashCode()"
}
- q"if (isNullAt($i)) 0 else $nonNull"
+ s"isNullAt($i) ? 0 : ($nonNull)"
}
- val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree)
+    val hashUpdates: String = hashValues.map { v =>
+      s"""
+      result *= 37; result += $v;"""
+    }.mkString("\n")
- val hashCodeFunction =
- q"""
- override def hashCode(): Int = {
- var result: Int = 37
- ..$hashUpdates
- result
- }
+ val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
+ s"""
+        if (isNullAt($i) != row.isNullAt($i) || (!isNullAt($i) && !get($i).equals(row.get($i)))) {
+ return false;
+ }
"""
+ }.mkString("\n")
- val columnChecks = (0 until expressions.size).map { i =>
- val elementName = newTermName(s"c$i")
- q"if (this.$elementName != specificType.$elementName) return false"
+ val code = s"""
+ import org.apache.spark.sql.Row;
+
+ public SpecificProjection generate($exprType[] expr) {
+ return new SpecificProjection(expr);
}
- val equalsFunction =
- q"""
- override def equals(other: Any): Boolean = other match {
- case specificType: SpecificRow =>
- ..$columnChecks
- return true
- case other => super.equals(other)
- }
- """
+ class SpecificProjection extends ${typeOf[BaseProject]} {
+ private $exprType[] expressions = null;
+
+ public SpecificProjection($exprType[] expr) {
+ expressions = expr;
+ }
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ @Override
+ public Object apply(Object r) {
+ return new SpecificRow(expressions, (Row) r);
+ }
}
- val copyFunction =
- q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
-
- val toSeqFunction =
- q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
-
- val classBody =
- nullFunctions ++ (
- lengthDef +:
- applyFunction +:
- updateFunction +:
- equalsFunction +:
- hashCodeFunction +:
- copyFunction +:
- toSeqFunction +:
- (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
-
- val code = q"""
- final class SpecificRow(i: $rowType) extends $mutableRowType {
- ..$classBody
+ final class SpecificRow extends ${typeOf[BaseMutableRow]} {
+
+ $columns
+
+ public SpecificRow($exprType[] expressions, Row i) {
+ $initColumns
+ }
+
+      public int size() { return ${expressions.length}; }
+ private boolean[] nullBits = new boolean[${expressions.length}];
+ public void setNullAt(int i) { nullBits[i] = true; }
+ public boolean isNullAt(int i) { return nullBits[i]; }
+
+ public Object get(int i) {
+ if (isNullAt(i)) return null;
+ switch (i) {
+ $getCases
+ }
+ return null;
+ }
+ public void update(int i, Object value) {
+ if (value == null) {
+ setNullAt(i);
+ return;
+ }
+ nullBits[i] = false;
+ switch (i) {
+ $updateCases
+ }
+ }
+ $specificAccessorFunctions
+ $specificMutatorFunctions
+
+ @Override
+ public int hashCode() {
+ int result = 37;
+ $hashUpdates
+ return result;
}
- new $projectionType { def apply(r: $rowType) = new SpecificRow(r) }
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof Row) {
+ Row row = (Row) other;
+ if (row.length() != size()) return false;
+ $columnChecks
+ return true;
+ }
+ return super.equals(other);
+ }
+ }
"""
- log.debug(
- s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}")
- toolBox.eval(code).asInstanceOf[Projection]
+ logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}")
+
+ val c = compile(code)
+    // fetch the only declared method, `generate(Expression[])`
+ val m = c.getDeclaredMethods()(0)
+ m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection]
}
}
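
The generated `hashCode` is a standard 37-based accumulation over per-column hashes, with null columns contributing 0 through the `isNullAt($i) ? 0 : ($nonNull)` terms built above. A minimal Scala rendering of the same computation (illustrative only, not part of the patch):

    // columnHashes(i) mirrors `isNullAt(i) ? 0 : (nonNull)` in the generated Java.
    def specificRowHash(columnHashes: Seq[Int]): Int =
      columnHashes.foldLeft(37) { (result, h) => result * 37 + h }
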
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
index 528e38a50a..7f1b12cdd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -27,12 +27,6 @@ import org.apache.spark.util.Utils
*/
package object codegen {
- /**
- * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala
- * 2.10.
- */
- protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock
-
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =
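
The deleted lock existed only because the Scala 2.10 runtime compiler is not thread safe; with codegen now emitting Java source compiled at runtime, concurrent generation no longer has to be serialized process-wide. The retired pattern was essentially the following (a hedged, self-contained reconstruction, not a quote of the old code):

    // Every toolbox evaluation previously funneled through one process-wide lock.
    object GlobalCompileLock
    def evalSerialized[T](eval: () => T): T =
      GlobalCompileLock.synchronized { eval() }
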
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index b6927485f4..5df528770c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation("abdef" cast TimestampType, null)
checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65))
- checkEvaluation(Literal(1) cast LongType, 1)
+ checkEvaluation(Literal(1) cast LongType, 1.toLong)
checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong)
checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong)
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
@@ -363,13 +363,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
+ Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType),
+ 5.toLong)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
- ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0)
+ ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType),
+ 0.toShort)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null)
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
- DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
+ DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType),
+ 0.toShort)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Literal(true) cast StringType, "true")
@@ -509,9 +512,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val seconds = millis * 1000 + 2
val ts = new Timestamp(millis)
val tss = new Timestamp(seconds)
- checkEvaluation(Cast(ts, ShortType), 15)
+ checkEvaluation(Cast(ts, ShortType), 15.toShort)
checkEvaluation(Cast(ts, IntegerType), 15)
- checkEvaluation(Cast(ts, LongType), 15)
+ checkEvaluation(Cast(ts, LongType), 15.toLong)
checkEvaluation(Cast(ts, FloatType), 15.002f)
checkEvaluation(Cast(ts, DoubleType), 15.002)
checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
index d7c437095e..8cfd853afa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -32,11 +32,12 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
- val evaluated = GenerateProjection.expressionEvaluator(expression)
+ val ctx = GenerateProjection.newCodeGenContext()
+ val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
fail(
s"""
|Code generation of $expression failed:
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
|$e
""".stripMargin)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index a40324b008..9ab1f7d7ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -28,7 +28,8 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
- lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+ val ctx = GenerateProjection.newCodeGenContext()
+ lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
val plan = try {
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
@@ -37,7 +38,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
fail(
s"""
|Code generation of $expression failed:
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
|$e
""".stripMargin)
}
@@ -49,7 +50,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
s"""
|Mismatched hashCodes for values: $actual, $expectedRow
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
- |${evaluated.code.mkString("\n")}
+ |${evaluated.code}
""".stripMargin)
}
if (actual != expectedRow) {
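
These test fixes track an API shift: `expressionEvaluator` now takes an explicit context from `newCodeGenContext()`, which accumulates the `references` handed to the generated class's constructor, and `evaluated.code` is a plain Java source String rather than a `Seq[Tree]`. A hedged sketch of the new call shape:

    // `expression` is assumed to be a bound Expression, as in the tests above.
    val ctx = GenerateProjection.newCodeGenContext()
    val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
    println(evaluated.code) // a String now, so no mkString("\n") is needed
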
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9aaec2b064..b41b1b77d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -451,10 +451,13 @@ class DataFrameSuite extends QueryTest {
test("SPARK-6899") {
val originalValue = TestSQLContext.conf.codegenEnabled
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
- checkAnswer(
- decimalData.agg(avg('a)),
- Row(new java.math.BigDecimal(2.0)))
- TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+    try {
+ checkAnswer(
+ decimalData.agg(avg('a)),
+ Row(new java.math.BigDecimal(2.0)))
+ } finally {
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("SPARK-7133: Implement struct, array, and map field accessor") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 63f7d314fb..55b68d8e22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -184,77 +184,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(df, expectedResults)
}
- // Just to group rows.
- testCodeGen(
- "SELECT key FROM testData3x GROUP BY key",
- (1 to 100).map(Row(_)))
- // COUNT
- testCodeGen(
- "SELECT key, count(value) FROM testData3x GROUP BY key",
- (1 to 100).map(i => Row(i, 3)))
- testCodeGen(
- "SELECT count(key) FROM testData3x",
- Row(300) :: Nil)
- // COUNT DISTINCT ON int
- testCodeGen(
- "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 1)))
- testCodeGen(
- "SELECT count(distinct key) FROM testData3x",
- Row(100) :: Nil)
- // SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
- "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
- Row(5050 * 3, 5050 * 3.0) :: Nil)
- // AVERAGE
- testCodeGen(
- "SELECT value, avg(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT avg(key) FROM testData3x",
- Row(50.5) :: Nil)
- // MAX
- testCodeGen(
- "SELECT value, max(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT max(key) FROM testData3x",
- Row(100) :: Nil)
- // MIN
- testCodeGen(
- "SELECT value, min(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, i)))
- testCodeGen(
- "SELECT min(key) FROM testData3x",
- Row(1) :: Nil)
- // Some combinations.
- testCodeGen(
- """
- |SELECT
- | value,
- | sum(key),
- | max(key),
- | min(key),
- | avg(key),
- | count(key),
- | count(distinct key)
- |FROM testData3x
- |GROUP BY value
- """.stripMargin,
- (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
- testCodeGen(
- "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
- Row(100, 1, 50.5, 300, 100) :: Nil)
- // Aggregate with Code generation handling all null values
- testCodeGen(
- "SELECT sum('a'), avg('a'), count(null) FROM testData",
- Row(0, null, 0) :: Nil)
-
- dropTempTable("testData3x")
- setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ try {
+ // Just to group rows.
+ testCodeGen(
+ "SELECT key FROM testData3x GROUP BY key",
+ (1 to 100).map(Row(_)))
+ // COUNT
+ testCodeGen(
+ "SELECT key, count(value) FROM testData3x GROUP BY key",
+ (1 to 100).map(i => Row(i, 3)))
+ testCodeGen(
+ "SELECT count(key) FROM testData3x",
+ Row(300) :: Nil)
+ // COUNT DISTINCT ON int
+ testCodeGen(
+ "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 1)))
+ testCodeGen(
+ "SELECT count(distinct key) FROM testData3x",
+ Row(100) :: Nil)
+ // SUM
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
+ "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
+ Row(5050 * 3, 5050 * 3.0) :: Nil)
+ // AVERAGE
+ testCodeGen(
+ "SELECT value, avg(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT avg(key) FROM testData3x",
+ Row(50.5) :: Nil)
+ // MAX
+ testCodeGen(
+ "SELECT value, max(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT max(key) FROM testData3x",
+ Row(100) :: Nil)
+ // MIN
+ testCodeGen(
+ "SELECT value, min(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT min(key) FROM testData3x",
+ Row(1) :: Nil)
+ // Some combinations.
+ testCodeGen(
+ """
+ |SELECT
+ | value,
+ | sum(key),
+ | max(key),
+ | min(key),
+ | avg(key),
+ | count(key),
+ | count(distinct key)
+ |FROM testData3x
+ |GROUP BY value
+ """.stripMargin,
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
+ testCodeGen(
+ "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+ Row(100, 1, 50.5, 300, 100) :: Nil)
+ // Aggregate with Code generation handling all null values
+ testCodeGen(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(0, null, 0) :: Nil)
+ } finally {
+ dropTempTable("testData3x")
+ setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
test("Add Parser of SQL COALESCE()") {
@@ -463,9 +465,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val codegenbefore = conf.codegenEnabled
setConf(SQLConf.EXTERNAL_SORT, "false")
setConf(SQLConf.CODEGEN_ENABLED, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+    try {
+ sortTest()
+ } finally {
+ setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("SPARK-6927 external sorting with codegen on") {
@@ -473,9 +478,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val codegenbefore = conf.codegenEnabled
setConf(SQLConf.CODEGEN_ENABLED, "true")
setConf(SQLConf.EXTERNAL_SORT, "true")
- sortTest()
- setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ try {
+ sortTest()
+ } finally {
+ setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
+ setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ }
}
test("limit") {