author      Davies Liu <davies.liu@gmail.com>    2014-09-12 19:05:39 -0700
committer   Josh Rosen <joshrosen@apache.org>    2014-09-12 19:28:45 -0700
commit      9c06c723018d4ef96ff31eb947226a6273ed8080 (patch)
tree        3833a3fe3bc848c0ed88c2017c1e21c02d8f780c
parent      6cbf83c05c7a073d4df81b59a1663fea38ce65f6 (diff)
[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd
Currently, SchemaRDD._jschema_rdd is a Scala SchemaRDD, so Scala APIs such as coalesce() and repartition() cannot easily be called from Python: there is no way to pass the implicit parameter `ord`. Since _jrdd is a JavaRDD, _jschema_rdd should also be a JavaSchemaRDD.

This patch changes _jschema_rdd to a JavaSchemaRDD and adds an assert for it. Any method missing from JavaSchemaRDD is reached via _jschema_rdd.baseSchemaRDD().xxx().

BTW, do we need JavaSQLContext?

Author: Davies Liu <davies.liu@gmail.com>

Closes #2369 from davies/fix_schemardd and squashes the following commits:

abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd

(cherry picked from commit 885d1621bc06bc1f009c9707c3452eac26baf828)
Signed-off-by: Josh Rosen <joshrosen@apache.org>

Conflicts:
	python/pyspark/tests.py
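A minimal sketch of the user-visible effect (PySpark 1.1-era API; assumes a running SparkContext `sc`, and the variable names are illustrative only, not part of the patch):

    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)
    srdd = sqlCtx.jsonRDD(sc.parallelize(['{"foo": "bar"}', '{"foo": "baz"}']))

    # With _jschema_rdd now a JavaSchemaRDD, the Java-friendly signatures of
    # coalesce()/repartition() can be forwarded directly (no implicit `ord`):
    srdd = srdd.coalesce(2, True).repartition(3)

    # Anything missing from JavaSchemaRDD is reached through baseSchemaRDD(),
    # e.g. the schema string used by SchemaRDD.schema():
    print srdd._jschema_rdd.baseSchemaRDD().schema().toString()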
-rw-r--r--  python/pyspark/sql.py    38
-rw-r--r--  python/pyspark/tests.py  37
2 files changed, 55 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0ff6a548a8..07b39c92b8 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1121,7 +1121,7 @@ class SQLContext:
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
- return SchemaRDD(srdd, self)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1133,8 +1133,8 @@ class SQLContext:
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
"""
if (rdd.__class__ is SchemaRDD):
- jschema_rdd = rdd._jschema_rdd
- self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+ srdd = rdd._jschema_rdd.baseSchemaRDD()
+ self._ssql_ctx.registerRDDAsTable(srdd, tableName)
else:
raise ValueError("Can only register SchemaRDD as table")
@@ -1150,7 +1150,7 @@ class SQLContext:
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path)
+ jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)
def jsonFile(self, path, schema=None):
@@ -1206,11 +1206,11 @@ class SQLContext:
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- jschema_rdd = self._ssql_ctx.jsonFile(path)
+ srdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
- jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(jschema_rdd, self)
+ srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def jsonRDD(self, rdd, schema=None):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1274,11 +1274,11 @@ class SQLContext:
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
- jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(jschema_rdd, self)
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def sql(self, sqlQuery):
"""Return a L{SchemaRDD} representing the result of the given query.
@@ -1289,7 +1289,7 @@ class SQLContext:
>>> srdd2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+ return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
def table(self, tableName):
"""Returns the specified table as a L{SchemaRDD}.
@@ -1300,7 +1300,7 @@ class SQLContext:
>>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName), self)
+ return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1352,7 +1352,7 @@ class HiveContext(SQLContext):
warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
"default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
DeprecationWarning)
- return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+ return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
def hql(self, hqlQuery):
"""
@@ -1508,6 +1508,8 @@ class SchemaRDD(RDD):
def __init__(self, jschema_rdd, sql_ctx):
self.sql_ctx = sql_ctx
self._sc = sql_ctx._sc
+ clsName = jschema_rdd.getClass().getName()
+ assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
self._jschema_rdd = jschema_rdd
self.is_cached = False
@@ -1524,7 +1526,7 @@ class SchemaRDD(RDD):
L{pyspark.rdd.RDD} super class (map, filter, etc.).
"""
if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.javaToPython()
+ self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
return self._lazy_jrdd
@property
@@ -1580,7 +1582,7 @@ class SchemaRDD(RDD):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.schema().toString())
+ return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
def schemaString(self):
"""Returns the output schema in the tree format."""
@@ -1631,8 +1633,6 @@ class SchemaRDD(RDD):
rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
schema = self.schema()
- import pickle
- pickle.loads(pickle.dumps(schema))
def applySchema(_, it):
cls = _create_cls(schema)
@@ -1669,10 +1669,8 @@ class SchemaRDD(RDD):
def getCheckpointFile(self):
checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isDefined():
+ if checkpointFile.isPresent():
return checkpointFile.get()
- else:
- return None
def coalesce(self, numPartitions, shuffle=False):
rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1db922f513..8f0a351b6b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -41,6 +41,8 @@ from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
+from pyspark.storagelevel import StorageLevel
+from pyspark.sql import SQLContext
_have_scipy = False
_have_numpy = False
@@ -469,6 +471,41 @@ class TestRDDFunctions(PySparkTestCase):
self.assertRaises(TypeError, lambda: rdd.histogram(2))
+class TestSQL(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.sqlCtx = SQLContext(self.sc)
+
+ def test_basic_functions(self):
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ srdd = self.sqlCtx.jsonRDD(rdd)
+ srdd.count()
+ srdd.collect()
+ srdd.schemaString()
+ srdd.schema()
+
+ # cache and checkpoint
+ self.assertFalse(srdd.is_cached)
+ srdd.persist(StorageLevel.MEMORY_ONLY_SER)
+ srdd.unpersist()
+ srdd.cache()
+ self.assertTrue(srdd.is_cached)
+ self.assertFalse(srdd.isCheckpointed())
+ self.assertEqual(None, srdd.getCheckpointFile())
+
+ srdd = srdd.coalesce(2, True)
+ srdd = srdd.repartition(3)
+ srdd = srdd.distinct()
+ srdd.intersection(srdd)
+ self.assertEqual(2, srdd.count())
+
+ srdd.registerTempTable("temp")
+ srdd = self.sqlCtx.sql("select foo from temp")
+ srdd.count()
+ srdd.collect()
+
+
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):