Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 32
1 file changed, 29 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d1ea7cc3e9..ae77f72998 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}

private[r] object SQLUtils {
@@ -39,8 +39,34 @@ private[r] object SQLUtils {
    arr.toSeq
  }

-  def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = {
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+  def createStructType(fields: Seq[StructField]): StructType = {
+    StructType(fields)
+  }
+
+  def getSQLDataType(dataType: String): DataType = {
+    dataType match {
+      case "byte" => org.apache.spark.sql.types.ByteType
+      case "integer" => org.apache.spark.sql.types.IntegerType
+      case "double" => org.apache.spark.sql.types.DoubleType
+      case "numeric" => org.apache.spark.sql.types.DoubleType
+      case "character" => org.apache.spark.sql.types.StringType
+      case "string" => org.apache.spark.sql.types.StringType
+      case "binary" => org.apache.spark.sql.types.BinaryType
+      case "raw" => org.apache.spark.sql.types.BinaryType
+      case "logical" => org.apache.spark.sql.types.BooleanType
+      case "boolean" => org.apache.spark.sql.types.BooleanType
+      case "timestamp" => org.apache.spark.sql.types.TimestampType
+      case "date" => org.apache.spark.sql.types.DateType
+      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+    }
+  }
+
+  def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
+    val dtObj = getSQLDataType(dataType)
+    StructField(name, dtObj, nullable)
+  }
+
+  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
    val num = schema.fields.size
    val rowRDD = rdd.map(bytesToRow)
    sqlContext.createDataFrame(rowRDD, schema)
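
Taken together, the new helpers replace the old JSON-schema path: instead of shipping a schema string from R and parsing it with DataType.fromJson, SparkR can now build a StructType field by field and hand it to createDF directly. The following is a minimal sketch of how the pieces compose on the JVM side; the sample fields, the sqlContext, and the byte-array RDD are illustrative assumptions, not part of this commit, and because SQLUtils is private[r] the sketch is placed in the org.apache.spark.sql.api.r package (SparkR itself reaches these methods through its R-to-JVM backend).

package org.apache.spark.sql.api.r

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

object SQLUtilsExample {
  // Hypothetical driver code: assumes a live SQLContext and an RDD of
  // SerDe-encoded rows (Array[Byte]) as produced by the SparkR serializer.
  def exampleCreateDF(sqlContext: SQLContext, serializedRows: RDD[Array[Byte]]): DataFrame = {
    // getSQLDataType accepts both the R type names ("numeric", "character",
    // "logical", "raw") and the SQL names ("double", "string", "boolean",
    // "binary"), mapping each pair to the same Catalyst type.
    val fields = Seq(
      SQLUtils.createStructField("age", "integer", nullable = true),
      SQLUtils.createStructField("name", "character", nullable = true),
      SQLUtils.createStructField("active", "logical", nullable = true))

    val schema = SQLUtils.createStructType(fields)

    // createDF now takes the StructType directly, rather than parsing a
    // JSON schema string as the old signature did.
    SQLUtils.createDF(serializedRows, schema, sqlContext)
  }
}

Passing the StructType itself avoids a JSON round-trip on every DataFrame creation and lets schema construction errors (such as an unknown R type name) surface early, from getSQLDataType, rather than at parse time.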