aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTarek Auel <tarek.auel@googlemail.com>2015-07-20 22:08:12 -0700
committerReynold Xin <rxin@databricks.com>2015-07-20 22:08:12 -0700
commita3c7a3ce32697ad293b8bcaf29f9384c8255b37f (patch)
tree23745a78e41dc23fe2afa9bc6cdfb5d48dd1abef
parent1cbdd8991898912a8471a7070c472a0edb92487c (diff)
downloadspark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.tar.gz
spark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.tar.bz2
spark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.zip
[SPARK-9132][SPARK-9163][SQL] codegen conv
Jira: https://issues.apache.org/jira/browse/SPARK-9132 and https://issues.apache.org/jira/browse/SPARK-9163. @rxin: as you proposed in the Jira ticket, I just moved the logic to a separate object; I haven't changed any of the logic in `NumberConverter`. Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7552 from tarekauel/SPARK-9163 and squashes the following commits: 40dcde9 [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] style fix; fa985bd [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] codegen conv
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala204
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala176
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala40
4 files changed, 263 insertions, 161 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 7a9be02ba4..68cca0ad3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -164,7 +165,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
* @param toBaseExpr to which base
*/
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends Expression with ImplicitCastInputTypes with CodegenFallback {
+ extends Expression with ImplicitCastInputTypes {
override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
@@ -179,169 +180,54 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
/** Returns the result of evaluating this expression on a given input Row */
override def eval(input: InternalRow): Any = {
val num = numExpr.eval(input)
- val fromBase = fromBaseExpr.eval(input)
- val toBase = toBaseExpr.eval(input)
- if (num == null || fromBase == null || toBase == null) {
- null
- } else {
- conv(
- num.asInstanceOf[UTF8String].getBytes,
- fromBase.asInstanceOf[Int],
- toBase.asInstanceOf[Int])
- }
- }
-
- private val value = new Array[Byte](64)
-
- /**
- * Divide x by m as if x is an unsigned 64-bit integer. Examples:
- * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
- * unsignedLongDiv(0, 5) == 0
- *
- * @param x is treated as unsigned
- * @param m is treated as signed
- */
- private def unsignedLongDiv(x: Long, m: Int): Long = {
- if (x >= 0) {
- x / m
- } else {
- // Let uval be the value of the unsigned long with the same bits as x
- // Two's complement => x = uval - 2*MAX - 2
- // => uval = x + 2*MAX + 2
- // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
- x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
- }
- }
-
- /**
- * Decode v into value[].
- *
- * @param v is treated as an unsigned 64-bit integer
- * @param radix must be between MIN_RADIX and MAX_RADIX
- */
- private def decode(v: Long, radix: Int): Unit = {
- var tmpV = v
- java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
- var i = value.length - 1
- while (tmpV != 0) {
- val q = unsignedLongDiv(tmpV, radix)
- value(i) = (tmpV - q * radix).asInstanceOf[Byte]
- tmpV = q
- i -= 1
- }
- }
-
- /**
- * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a
- * negative digit is found, ignore the suffix starting there.
- *
- * @param radix must be between MIN_RADIX and MAX_RADIX
- * @param fromPos is the first element that should be conisdered
- * @return the result should be treated as an unsigned 64-bit integer.
- */
- private def encode(radix: Int, fromPos: Int): Long = {
- var v: Long = 0L
- val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
- // val
- // exceeds this value
- var i = fromPos
- while (i < value.length && value(i) >= 0) {
- if (v >= bound) {
- // Check for overflow
- if (unsignedLongDiv(-1 - value(i), radix) < v) {
- return -1
+ if (num != null) {
+ val fromBase = fromBaseExpr.eval(input)
+ if (fromBase != null) {
+ val toBase = toBaseExpr.eval(input)
+ if (toBase != null) {
+ NumberConverter.convert(
+ num.asInstanceOf[UTF8String].getBytes,
+ fromBase.asInstanceOf[Int],
+ toBase.asInstanceOf[Int])
+ } else {
+ null
}
- }
- v = v * radix + value(i)
- i += 1
- }
- v
- }
-
- /**
- * Convert the bytes in value[] to the corresponding chars.
- *
- * @param radix must be between MIN_RADIX and MAX_RADIX
- * @param fromPos is the first nonzero element
- */
- private def byte2char(radix: Int, fromPos: Int): Unit = {
- var i = fromPos
- while (i < value.length) {
- value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
- i += 1
- }
- }
-
- /**
- * Convert the chars in value[] to the corresponding integers. Convert invalid
- * characters to -1.
- *
- * @param radix must be between MIN_RADIX and MAX_RADIX
- * @param fromPos is the first nonzero element
- */
- private def char2byte(radix: Int, fromPos: Int): Unit = {
- var i = fromPos
- while ( i < value.length) {
- value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
- i += 1
- }
- }
-
- /**
- * Convert numbers between different number bases. If toBase>0 the result is
- * unsigned, otherwise it is signed.
- * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv
- */
- private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = {
- if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
- || Math.abs(toBase) < Character.MIN_RADIX
- || Math.abs(toBase) > Character.MAX_RADIX) {
- return null
- }
-
- if (n.length == 0) {
- return null
- }
-
- var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
-
- // Copy the digits in the right side of the array
- var i = 1
- while (i <= n.length - first) {
- value(value.length - i) = n(n.length - i)
- i += 1
- }
- char2byte(fromBase, value.length - n.length + first)
-
- // Do the conversion by going through a 64 bit integer
- var v = encode(fromBase, value.length - n.length + first)
- if (negative && toBase > 0) {
- if (v < 0) {
- v = -1
} else {
- v = -v
+ null
}
+ } else {
+ null
}
- if (toBase < 0 && v < 0) {
- v = -v
- negative = true
- }
- decode(v, Math.abs(toBase))
-
- // Find the first non-zero digit or the last digits if all are zero.
- val firstNonZeroPos = {
- val firstNonZero = value.indexWhere( _ != 0)
- if (firstNonZero != -1) firstNonZero else value.length - 1
- }
-
- byte2char(Math.abs(toBase), firstNonZeroPos)
+ }
- var resultStartPos = firstNonZeroPos
- if (negative && toBase < 0) {
- resultStartPos = firstNonZeroPos - 1
- value(resultStartPos) = '-'
- }
- UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
+ /**
+ * Generates the Java source that evaluates this expression.
+ *
+ * The emitted code mirrors the nesting of `eval`: the code for `numExpr` always runs, the code
+ * for `fromBaseExpr` runs only when the number is non-null, and the code for `toBaseExpr` runs
+ * only when the source base is non-null. As soon as one operand is null the result is flagged
+ * null and `NumberConverter.convert` is not called. A null returned by `convert` also flags
+ * the result as null.
+ *
+ * @param ctx code generation context, used here for the Java type and default value of the
+ * result
+ * @param ev supplies the names of the result variable (`primitive`) and its null flag
+ * (`isNull`)
+ * @return the Java code snippet as a string
+ */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val numGen = numExpr.gen(ctx)
+ val from = fromBaseExpr.gen(ctx)
+ val to = toBaseExpr.gen(ctx)
+
+ // The runtime class name of a Scala `object` ends with '$'; the suffix is stripped so the
+ // generated Java code can refer to the converter as `<name>.convert(...)`.
+ val numconv = NumberConverter.getClass.getName.stripSuffix("$")
+ s"""
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ ${numGen.code}
+ boolean ${ev.isNull} = ${numGen.isNull};
+ if (!${ev.isNull}) {
+ ${from.code}
+ if (!${from.isNull}) {
+ ${to.code}
+ if (!${to.isNull}) {
+ ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(),
+ ${from.primitive}, ${to.primitive});
+ if (${ev.primitive} == null) {
+ ${ev.isNull} = true;
+ }
+ } else {
+ ${ev.isNull} = true;
+ }
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
new file mode 100644
index 0000000000..9fefc5656a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Converts the textual representation of an integer from one radix to another, going through
+ * an unsigned 64-bit intermediate value.
+ *
+ * The object holds no mutable state: every call to [[convert]] works on its own scratch buffer,
+ * so the converter can be called from several threads at once (it is reached through a static
+ * call from generated code as well as from interpreted evaluation).
+ */
+object NumberConverter {
+
+  /** A 64-bit value needs at most 64 digits (radix 2), so the scratch buffer never gets smaller. */
+  private val MinDigits = 64
+
+  /**
+   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
+   * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
+   * unsignedLongDiv(0, 5) == 0
+   *
+   * @param x is treated as unsigned
+   * @param m is treated as signed
+   */
+  private def unsignedLongDiv(x: Long, m: Int): Long = {
+    if (x >= 0) {
+      x / m
+    } else {
+      // Let uval be the value of the unsigned long with the same bits as x
+      // Two's complement => x = uval - 2*MAX - 2
+      // => uval = x + 2*MAX + 2
+      // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
+      x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
+    }
+  }
+
+  /**
+   * Decode v into value[], one digit per byte, right-aligned and zero-padded on the left.
+   *
+   * @param v is treated as an unsigned 64-bit integer
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param value the scratch buffer that receives the digits; it is cleared first
+   */
+  private def decode(v: Long, radix: Int, value: Array[Byte]): Unit = {
+    var tmpV = v
+    java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
+    var i = value.length - 1
+    while (tmpV != 0) {
+      val q = unsignedLongDiv(tmpV, radix)
+      value(i) = (tmpV - q * radix).asInstanceOf[Byte]
+      tmpV = q
+      i -= 1
+    }
+  }
+
+  /**
+   * Convert value[] into a long. On overflow, return -1 (as MySQL does). If a
+   * negative digit is found, ignore the suffix starting there.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be considered
+   * @param value the scratch buffer holding one digit value per byte
+   * @return the result should be treated as an unsigned 64-bit integer.
+   */
+  private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
+    var v: Long = 0L
+    // Overflow is only possible once v exceeds this bound.
+    val bound = unsignedLongDiv(-1 - radix, radix)
+    var i = fromPos
+    while (i < value.length && value(i) >= 0) {
+      // A set top bit means v >= 2^63 when read as unsigned, so multiplying by the radix
+      // (always >= 2) is guaranteed to overflow. The signed comparison below cannot see this
+      // case because v is negative there.
+      if (v < 0) {
+        return -1
+      }
+      if (v >= bound && unsignedLongDiv(-1 - value(i), radix) < v) {
+        // v * radix + value(i) would not fit into 64 bits.
+        return -1
+      }
+      v = v * radix + value(i)
+      i += 1
+    }
+    v
+  }
+
+  /**
+   * Convert the bytes in value[] to the corresponding upper-case digit chars.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first nonzero element
+   * @param value the scratch buffer, converted in place
+   */
+  private def byte2char(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert the chars in value[] to the corresponding integers. Convert invalid
+   * characters to -1.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be converted
+   * @param value the scratch buffer, converted in place
+   */
+  private def char2byte(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert numbers between different number bases. If toBase>0 the result is
+   * unsigned, otherwise it is signed.
+   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
+   *
+   * @param n the digits of the number as bytes, optionally starting with '-'; parsing stops at
+   *          the first byte that is not a valid digit in fromBase
+   * @param fromBase radix of n, between Character.MIN_RADIX and Character.MAX_RADIX
+   * @param toBase radix of the result; its absolute value must be between Character.MIN_RADIX
+   *               and Character.MAX_RADIX
+   * @return the converted number, or null if n is empty or one of the bases is out of range
+   */
+  def convert(n: Array[Byte], fromBase: Int, toBase: Int): UTF8String = {
+    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
+        || Math.abs(toBase) < Character.MIN_RADIX
+        || Math.abs(toBase) > Character.MAX_RADIX) {
+      return null
+    }
+
+    if (n.length == 0) {
+      return null
+    }
+
+    val first = if (n(0) == '-') 1 else 0
+    var negative = first == 1
+
+    // Scratch buffer private to this call. It is large enough for the input digits and for the
+    // longest possible result (64 binary digits), plus one slot for a leading '-'.
+    val value = new Array[Byte](math.max(n.length, MinDigits) + 1)
+
+    // Copy the digits in the right side of the array
+    val digitsStart = value.length - n.length + first
+    System.arraycopy(n, first, value, digitsStart, n.length - first)
+    char2byte(fromBase, digitsStart, value)
+
+    // Do the conversion by going through a 64 bit integer
+    var v = encode(fromBase, digitsStart, value)
+    if (negative && toBase > 0) {
+      if (v < 0) {
+        v = -1
+      } else {
+        v = -v
+      }
+    }
+    if (toBase < 0 && v < 0) {
+      v = -v
+      negative = true
+    }
+    decode(v, Math.abs(toBase), value)
+
+    // Find the first non-zero digit or the last digits if all are zero.
+    val firstNonZeroPos = {
+      val firstNonZero = value.indexWhere(_ != 0)
+      if (firstNonZero != -1) firstNonZero else value.length - 1
+    }
+
+    byte2char(Math.abs(toBase), firstNonZeroPos, value)
+
+    var resultStartPos = firstNonZeroPos
+    if (negative && toBase < 0) {
+      // The extra buffer slot guarantees firstNonZeroPos >= 1 here.
+      resultStartPos = firstNonZeroPos - 1
+      value(resultStartPos) = '-'
+    }
+    UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 04acd5b5ff..a2b0fad7b7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -115,8 +115,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
- checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null)
+ checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
+ checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
checkEvaluation(
Conv(Literal("1234"), Literal(10), Literal(37)), null)
checkEvaluation(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
new file mode 100644
index 0000000000..13265a1ff1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.NumberConverter.convert
+import org.apache.spark.unsafe.types.UTF8String
+
+/** Unit tests that call `NumberConverter.convert` directly on raw bytes. */
+class NumberConverterSuite extends SparkFunSuite {
+
+  /**
+   * Asserts that converting the text `n` from radix `fromBase` to radix `toBase` produces
+   * the text `expected`.
+   */
+  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
+    val inputBytes = UTF8String.fromString(n).getBytes
+    val wanted = UTF8String.fromString(expected)
+    assert(convert(inputBytes, fromBase, toBase) === wanted)
+  }
+
+  test("convert") {
+    // (input, fromBase, toBase, expected result)
+    val cases = Seq(
+      ("3", 10, 2, "11"),
+      ("-15", 10, -16, "-F"),
+      ("-15", 10, 16, "FFFFFFFFFFFFFFF1"),
+      ("big", 36, 16, "3A48"),
+      ("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF"),
+      ("11abc", 10, 16, "B"))
+
+    cases.foreach { case (input, from, to, result) =>
+      checkConv(input, from, to, result)
+    }
+  }
+
+}