Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala    | 27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala     | 23
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala |  3
-rw-r--r--  sql/hive/src/test/resources/log4j.properties                     |  3
4 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d3d4c56baf..24d60ea074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -241,4 +242,30 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def debugExec() = DebugQuery(executedPlan).execute().collect()
}
+
+ /**
+ * Peek at the first row of the RDD and infer its schema.
+ * TODO: We only support primitive types, add support for nested types.
+ */
+ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+ val schema = rdd.first.map { case (fieldName, obj) =>
+ val dataType = obj.getClass match {
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+ case c => throw new Exception(s"Object of type $c cannot be used")
+ }
+ AttributeReference(fieldName, dataType, true)()
+ }.toSeq
+
+ val rowRdd = rdd.mapPartitions { iter =>
+ iter.map { map =>
+ new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+ }
+ }
+ new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
+ }
+
}
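
A minimal usage sketch of the new inferSchema (illustrative only, not applied by this patch): it assumes an existing SparkContext `sc` and SQLContext `sqlContext`, and a call site inside the org.apache.spark.sql package, since the method is private[sql]. The variable and table names are hypothetical.

    // Assumed: `sc` is a SparkContext, `sqlContext` a SQLContext, and the call
    // site is inside org.apache.spark.sql (inferSchema is private[sql]).
    val maps = sc.parallelize(Seq(
      Map("name" -> "Alice", "age" -> 30, "active" -> true),
      Map("name" -> "Bob",   "age" -> 25, "active" -> false)))
    // Only the first map is inspected, yielding name: StringType,
    // age: IntegerType, active: BooleanType; nested types are not yet
    // supported (see the TODO above).
    val people = sqlContext.inferSchema(maps)
    people.registerAsTable("people")  // hypothetical table name
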
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 91500416ee..a771147f90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import net.razorvine.pickle.Pickler
+
import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
import org.apache.spark.annotation.{AlphaComponent, Experimental}
import org.apache.spark.rdd.RDD
@@ -25,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.types.BooleanType
+import org.apache.spark.api.java.JavaRDD
+import java.util.{Map => JMap}
/**
* :: AlphaComponent ::
@@ -308,4 +312,23 @@ class SchemaRDD(
/** FOR INTERNAL USE ONLY */
def analyze = sqlContext.analyzer(logicalPlan)
+
+ private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
+ this.mapPartitions { iter =>
+ val pickle = new Pickler
+ iter.map { row =>
+ val map: JMap[String, Any] = new java.util.HashMap
+ // TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict].
+ // Ideally we should be able to pickle an object directly into a Python collection so we
+ // don't have to create an ArrayList every time.
+ val arr: java.util.ArrayList[Any] = new java.util.ArrayList
+ row.zip(fieldNames).foreach { case (obj, name) =>
+ map.put(name, obj)
+ }
+ arr.add(map)
+ pickle.dumps(arr)
+ }
+ }
+ }
}
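
A sketch of what javaToPython yields (illustrative only, not applied by this patch), assuming a SchemaRDD `people` such as the one built above and, again, a call site inside org.apache.spark.sql because the method is private[sql]:

    // Assumed: `people` is a SchemaRDD; call site is inside org.apache.spark.sql.
    val pickled: JavaRDD[Array[Byte]] = people.javaToPython
    // Each element is a pickled java.util.ArrayList containing a single
    // java.util.HashMap keyed by the output attribute names, so PySpark can
    // unpickle each row into roughly [{'name': 'Alice', 'age': 30, ...}].
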
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 465e5f146f..444bbfb4dd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -261,8 +261,9 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
createCmds.foreach(_())
- if (cacheTables)
+ if (cacheTables) {
cacheTable(name)
+ }
}
}
diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties
index 5e17e3b596..c07d8fedf1 100644
--- a/sql/hive/src/test/resources/log4j.properties
+++ b/sql/hive/src/test/resources/log4j.properties
@@ -45,3 +45,6 @@ log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
log4j.additivity.hive.ql.metadata.Hive=false
log4j.logger.hive.ql.metadata.Hive=OFF
+log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false
+log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR
+