diff options
author | Tarek Auel <tarek.auel@googlemail.com> | 2015-07-20 22:08:12 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-20 22:08:12 -0700 |
commit | a3c7a3ce32697ad293b8bcaf29f9384c8255b37f (patch) | |
tree | 23745a78e41dc23fe2afa9bc6cdfb5d48dd1abef | |
parent | 1cbdd8991898912a8471a7070c472a0edb92487c (diff) | |
download | spark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.tar.gz spark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.tar.bz2 spark-a3c7a3ce32697ad293b8bcaf29f9384c8255b37f.zip |
[SPARK-9132][SPARK-9163][SQL] codegen conv
Jira: https://issues.apache.org/jira/browse/SPARK-9132
https://issues.apache.org/jira/browse/SPARK-9163
rxin, as you proposed in the Jira ticket, I just moved the logic to a separate object. I haven't changed any of the logic of `NumberConverter`.
Author: Tarek Auel <tarek.auel@googlemail.com>
Closes #7552 from tarekauel/SPARK-9163 and squashes the following commits:
40dcde9 [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] style fix
fa985bd [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] codegen conv
4 files changed, 263 insertions, 161 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7a9be02ba4..68cca0ad3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -164,7 +165,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable @@ -179,169 +180,54 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre /** Returns the result of evaluating this expression on a given input Row */ override def eval(input: InternalRow): Any = { val num = numExpr.eval(input) - val fromBase = fromBaseExpr.eval(input) - val toBase = toBaseExpr.eval(input) - if (num == null || fromBase == null || toBase == null) { - null - } else { - conv( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } - } - - private val value = new Array[Byte](64) - - /** - * Divide x by m as if x is an unsigned 64-bit integer. 
Examples: - * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 - * unsignedLongDiv(0, 5) == 0 - * - * @param x is treated as unsigned - * @param m is treated as signed - */ - private def unsignedLongDiv(x: Long, m: Int): Long = { - if (x >= 0) { - x / m - } else { - // Let uval be the value of the unsigned long with the same bits as x - // Two's complement => x = uval - 2*MAX - 2 - // => uval = x + 2*MAX + 2 - // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c - x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m - } - } - - /** - * Decode v into value[]. - * - * @param v is treated as an unsigned 64-bit integer - * @param radix must be between MIN_RADIX and MAX_RADIX - */ - private def decode(v: Long, radix: Int): Unit = { - var tmpV = v - java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) - var i = value.length - 1 - while (tmpV != 0) { - val q = unsignedLongDiv(tmpV, radix) - value(i) = (tmpV - q * radix).asInstanceOf[Byte] - tmpV = q - i -= 1 - } - } - - /** - * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a - * negative digit is found, ignore the suffix starting there. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first element that should be conisdered - * @return the result should be treated as an unsigned 64-bit integer. 
- */ - private def encode(radix: Int, fromPos: Int): Long = { - var v: Long = 0L - val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once - // val - // exceeds this value - var i = fromPos - while (i < value.length && value(i) >= 0) { - if (v >= bound) { - // Check for overflow - if (unsignedLongDiv(-1 - value(i), radix) < v) { - return -1 + if (num != null) { + val fromBase = fromBaseExpr.eval(input) + if (fromBase != null) { + val toBase = toBaseExpr.eval(input) + if (toBase != null) { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) + } else { + null } - } - v = v * radix + value(i) - i += 1 - } - v - } - - /** - * Convert the bytes in value[] to the corresponding chars. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first nonzero element - */ - private def byte2char(radix: Int, fromPos: Int): Unit = { - var i = fromPos - while (i < value.length) { - value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] - i += 1 - } - } - - /** - * Convert the chars in value[] to the corresponding integers. Convert invalid - * characters to -1. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first nonzero element - */ - private def char2byte(radix: Int, fromPos: Int): Unit = { - var i = fromPos - while ( i < value.length) { - value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] - i += 1 - } - } - - /** - * Convert numbers between different number bases. If toBase>0 the result is - * unsigned, otherwise it is signed. 
- * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv - */ - private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { - if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX - || Math.abs(toBase) < Character.MIN_RADIX - || Math.abs(toBase) > Character.MAX_RADIX) { - return null - } - - if (n.length == 0) { - return null - } - - var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) - - // Copy the digits in the right side of the array - var i = 1 - while (i <= n.length - first) { - value(value.length - i) = n(n.length - i) - i += 1 - } - char2byte(fromBase, value.length - n.length + first) - - // Do the conversion by going through a 64 bit integer - var v = encode(fromBase, value.length - n.length + first) - if (negative && toBase > 0) { - if (v < 0) { - v = -1 } else { - v = -v + null } + } else { + null } - if (toBase < 0 && v < 0) { - v = -v - negative = true - } - decode(v, Math.abs(toBase)) - - // Find the first non-zero digit or the last digits if all are zero. 
- val firstNonZeroPos = { - val firstNonZero = value.indexWhere( _ != 0) - if (firstNonZero != -1) firstNonZero else value.length - 1 - } - - byte2char(Math.abs(toBase), firstNonZeroPos) + } - var resultStartPos = firstNonZeroPos - if (negative && toBase < 0) { - resultStartPos = firstNonZeroPos - 1 - value(resultStartPos) = '-' - } - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numGen = numExpr.gen(ctx) + val from = fromBaseExpr.gen(ctx) + val to = toBaseExpr.gen(ctx) + + val numconv = NumberConverter.getClass.getName.stripSuffix("$") + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${numGen.code} + boolean ${ev.isNull} = ${numGen.isNull}; + if (!${ev.isNull}) { + ${from.code} + if (!${from.isNull}) { + ${to.code} + if (!${to.isNull}) { + ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), + ${from.primitive}, ${to.primitive}); + if (${ev.primitive} == null) { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala new file mode 100644 index 0000000000..9fefc5656a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.unsafe.types.UTF8String + +object NumberConverter { + + private val value = new Array[Byte](64) + + /** + * Divide x by m as if x is an unsigned 64-bit integer. Examples: + * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 + * unsignedLongDiv(0, 5) == 0 + * + * @param x is treated as unsigned + * @param m is treated as signed + */ + private def unsignedLongDiv(x: Long, m: Int): Long = { + if (x >= 0) { + x / m + } else { + // Let uval be the value of the unsigned long with the same bits as x + // Two's complement => x = uval - 2*MAX - 2 + // => uval = x + 2*MAX + 2 + // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c + x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m + } + } + + /** + * Decode v into value[]. + * + * @param v is treated as an unsigned 64-bit integer + * @param radix must be between MIN_RADIX and MAX_RADIX + */ + private def decode(v: Long, radix: Int): Unit = { + var tmpV = v + java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) + var i = value.length - 1 + while (tmpV != 0) { + val q = unsignedLongDiv(tmpV, radix) + value(i) = (tmpV - q * radix).asInstanceOf[Byte] + tmpV = q + i -= 1 + } + } + + /** + * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a + * negative digit is found, ignore the suffix starting there. 
+ * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first element that should be conisdered + * @return the result should be treated as an unsigned 64-bit integer. + */ + private def encode(radix: Int, fromPos: Int): Long = { + var v: Long = 0L + val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once + // val + // exceeds this value + var i = fromPos + while (i < value.length && value(i) >= 0) { + if (v >= bound) { + // Check for overflow + if (unsignedLongDiv(-1 - value(i), radix) < v) { + return -1 + } + } + v = v * radix + value(i) + i += 1 + } + v + } + + /** + * Convert the bytes in value[] to the corresponding chars. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def byte2char(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while (i < value.length) { + value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert the chars in value[] to the corresponding integers. Convert invalid + * characters to -1. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def char2byte(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while ( i < value.length) { + value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert numbers between different number bases. If toBase>0 the result is + * unsigned, otherwise it is signed. 
+ * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv + */ + def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX + || Math.abs(toBase) < Character.MIN_RADIX + || Math.abs(toBase) > Character.MAX_RADIX) { + return null + } + + if (n.length == 0) { + return null + } + + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) + + // Copy the digits in the right side of the array + var i = 1 + while (i <= n.length - first) { + value(value.length - i) = n(n.length - i) + i += 1 + } + char2byte(fromBase, value.length - n.length + first) + + // Do the conversion by going through a 64 bit integer + var v = encode(fromBase, value.length - n.length + first) + if (negative && toBase > 0) { + if (v < 0) { + v = -1 + } else { + v = -v + } + } + if (toBase < 0 && v < 0) { + v = -v + negative = true + } + decode(v, Math.abs(toBase)) + + // Find the first non-zero digit or the last digits if all are zero. 
+ val firstNonZeroPos = { + val firstNonZero = value.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else value.length - 1 + } + + byte2char(Math.abs(toBase), firstNonZeroPos) + + var resultStartPos = firstNonZeroPos + if (negative && toBase < 0) { + resultStartPos = firstNonZeroPos - 1 + value(resultStartPos) = '-' + } + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 04acd5b5ff..a2b0fad7b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -115,8 +115,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null) + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) checkEvaluation( Conv(Literal("1234"), Literal(10), Literal(37)), null) checkEvaluation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala new file mode 100644 index 0000000000..13265a1ff1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.NumberConverter.convert +import org.apache.spark.unsafe.types.UTF8String + +class NumberConverterSuite extends SparkFunSuite { + + private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { + assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === + UTF8String.fromString(expected)) + } + + test("convert") { + checkConv("3", 10, 2, "11") + checkConv("-15", 10, -16, "-F") + checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") + checkConv("big", 36, 16, "3A48") + checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") + checkConv("11abc", 10, 16, "B") + } + +} |