aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala57
1 files changed, 57 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d29c27c14b..fa09f821fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -271,6 +271,63 @@ class CodegenContext {
}
/**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`.
+ */
+ def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
+ value: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) =>
+ s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
+ case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
+ case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
+ case _ =>
+ throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+ }
+ }
+
+ /**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`
+ * that could potentially be nullable.
+ */
+ def updateColumn(
+ batch: String,
+ row: String,
+ dataType: DataType,
+ ordinal: Int,
+ ev: ExprCode,
+ nullable: Boolean): String = {
+ if (nullable) {
+ s"""
+ if (!${ev.isNull}) {
+ ${setValue(batch, row, dataType, ordinal, ev.value)}
+ } else {
+ $batch.column($ordinal).putNull($row);
+ }
+ """
+ } else {
+ s"""${setValue(batch, row, dataType, ordinal, ev.value)};"""
+ }
+ }
+
+ /**
+ * Returns the specialized code to access a value from a column vector for a given `DataType`.
+ */
+ def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) =>
+ s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)"
+ case t: DecimalType =>
+ s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})"
+ case StringType =>
+ s"$batch.column($ordinal).getUTF8String($row)"
+ case _ =>
+ throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+ }
+ }
+
+ /**
* Returns the name used in accessor and setter for a Java primitive type.
*/
def primitiveTypeName(jt: String): String = jt match {