diff options
Diffstat (limited to 'sql')
4 files changed, 55 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d3d4c56baf..24d60ea074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.dsl
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -241,4 +242,30 @@ class SQLContext(@transient val sparkContext: SparkContext)
      */
     def debugExec() = DebugQuery(executedPlan).execute().collect()
   }
+
+  /**
+   * Peek at the first row of the RDD and infer its schema.
+   * TODO: We only support primitive types, add support for nested types.
+   */
+  private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    val schema = rdd.first.map { case (fieldName, obj) =>
+      val dataType = obj.getClass match {
+        case c: Class[_] if c == classOf[java.lang.String] => StringType
+        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+        case c: Class[_] if c == classOf[java.lang.Long] => LongType
+        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+        case c => throw new Exception(s"Object of type $c cannot be used")
+      }
+      AttributeReference(fieldName, dataType, true)()
+    }.toSeq
+
+    val rowRdd = rdd.mapPartitions { iter =>
+      iter.map { map =>
+        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+      }
+    }
+    new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 91500416ee..a771147f90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import net.razorvine.pickle.Pickler
+
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
 import org.apache.spark.annotation.{AlphaComponent, Experimental}
 import org.apache.spark.rdd.RDD
@@ -25,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.types.BooleanType
+import org.apache.spark.api.java.JavaRDD
+import java.util.{Map => JMap}
 
 /**
  * :: AlphaComponent ::
@@ -308,4 +312,23 @@ class SchemaRDD(
 
   /** FOR INTERNAL USE ONLY */
   def analyze = sqlContext.analyzer(logicalPlan)
+
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
+    this.mapPartitions { iter =>
+      val pickle = new Pickler
+      iter.map { row =>
+        val map: JMap[String, Any] = new java.util.HashMap
+        // TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict].
+        // Ideally we should be able to pickle an object directly into a Python collection so we
+        // don't have to create an ArrayList every time.
+        val arr: java.util.ArrayList[Any] = new java.util.ArrayList
+        row.zip(fieldNames).foreach { case (obj, name) =>
+          map.put(name, obj)
+        }
+        arr.add(map)
+        pickle.dumps(arr)
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 465e5f146f..444bbfb4dd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -261,8 +261,9 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
         testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
       createCmds.foreach(_())
 
-      if (cacheTables)
+      if (cacheTables) {
         cacheTable(name)
+      }
     }
   }
 
diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties
index 5e17e3b596..c07d8fedf1 100644
--- a/sql/hive/src/test/resources/log4j.properties
+++ b/sql/hive/src/test/resources/log4j.properties
@@ -45,3 +45,6 @@ log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
 log4j.additivity.hive.ql.metadata.Hive=false
 log4j.logger.hive.ql.metadata.Hive=OFF
 
+log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false
+log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR
+