author     Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
commit     119f45d61d7b48d376cca05e1b4f0c7fcf65bfa8
tree       714df6362313e93bee0e9dba2f84b3ba1697e555
parent     b1b35ca2e440df40b253bf967bb93705d355c1c0
[SPARK-5097][SQL] DataFrame
This pull request redesigns the existing Spark SQL DSL, which already provides data-frame-like functionality.

TODOs: With the exception of Python support, other tasks can be done in separate, follow-up PRs.
- [ ] Audit of the API
- [ ] Documentation
- [ ] More test cases to cover the new API
- [x] Python support
- [ ] Type alias SchemaRDD

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4173 from rxin/df1 and squashes the following commits:

0a1a73b [Reynold Xin] Merge branch 'df1' of github.com:rxin/spark into df1
23b4427 [Reynold Xin] Mima.
828f70d [Reynold Xin] Merge pull request #7 from davies/df
257b9e6 [Davies Liu] add repartition
6bf2b73 [Davies Liu] fix collect with UDT and tests
e971078 [Reynold Xin] Missing quotes.
b9306b4 [Reynold Xin] Remove removeColumn/updateColumn for now.
a728bf2 [Reynold Xin] Example rename.
e8aa3d3 [Reynold Xin] groupby -> groupBy.
9662c9e [Davies Liu] improve DataFrame Python API
4ae51ea [Davies Liu] python API for dataframe
1e5e454 [Reynold Xin] Fixed a bug with symbol conversion.
2ca74db [Reynold Xin] Couple minor fixes.
ea98ea1 [Reynold Xin] Documentation & literal expressions.
2b22684 [Reynold Xin] Got rid of IntelliJ problems.
02bbfbc [Reynold Xin] Tightening imports.
ffbce66 [Reynold Xin] Fixed compilation error.
59b6d8b [Reynold Xin] Style violation.
b85edfb [Reynold Xin] ALS.
8c37f0a [Reynold Xin] Made MLlib and examples compile
6d53134 [Reynold Xin] Hive module.
d35efd5 [Reynold Xin] Fixed compilation error.
ce4a5d2 [Reynold Xin] Fixed test cases in SQL except ParquetIOSuite.
66d5ef1 [Reynold Xin] SQLContext minor patch.
c9bcdc0 [Reynold Xin] Checkpoint: SQL module compiles!
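For orientation, here is a minimal sketch of the Python API this patch introduces, pieced together from the doctests and tests in the diff below; the local SparkContext setup and the sample data are illustrative assumptions, not part of the patch::

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "dataframe-sketch")  # assumed local setup
    sqlCtx = SQLContext(sc)

    # Build a DataFrame from an RDD of Rows (inferSchema, as in the doctests).
    rdd = sc.parallelize([Row(key=1, value="row1"), Row(key=2, value="row2")])
    df = sqlCtx.inferSchema(rdd)

    # Column expressions and the grouped-aggregation DSL added by this patch.
    df.filter(df.key > 1).select(df.value).collect()
    df.groupBy().agg({"key": "max", "value": "count"}).collect()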
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/java_gateway.py |   7
-rw-r--r--  python/pyspark/sql.py          | 967
-rw-r--r--  python/pyspark/tests.py        | 155
3 files changed, 793 insertions(+), 336 deletions(-)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index a975dc19cb..a0a028446d 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -111,10 +111,9 @@ def launch_gateway():
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+ # TODO(davies): move into sql
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
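The hunk above trades four per-class imports for two package wildcards. `java_import` merely registers names on the gateway's JVM view, so the wildcard form makes every class under org.apache.spark.sql and org.apache.spark.sql.hive resolvable by simple name. A standalone sketch of the py4j mechanism, assuming a GatewayServer is already listening on the default port::

    from py4j.java_gateway import JavaGateway, java_import

    gateway = JavaGateway()  # assumes a running GatewayServer to connect to
    java_import(gateway.jvm, "org.apache.spark.sql.*")

    # After the wildcard import, classes resolve by their simple names,
    # e.g. gateway.jvm.SQLContext rather than the fully qualified path.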
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 1990323249..7d7550c854 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,15 +20,19 @@ public classes of Spark SQL:
- L{SQLContext}
Main entry point for SQL functionality.
- - L{SchemaRDD}
+ - L{DataFrame}
A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
+ addition to normal RDD operations, DataFrames also support SQL.
+ - L{GroupedDataFrame}
+ Aggregation methods, returned by DataFrame.groupBy().
+ - L{Column}
+ Column is a DataFrame with a single column.
- L{Row}
A Row of data returned by a Spark SQL query.
- L{HiveContext}
Main entry point for accessing data stored in Apache Hive.
"""
+import sys
import itertools
import decimal
import datetime
@@ -36,6 +40,9 @@ import keyword
import warnings
import json
import re
+import random
+import os
+from tempfile import NamedTemporaryFile
from array import array
from operator import itemgetter
from itertools import imap
@@ -43,6 +50,7 @@ from itertools import imap
from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
+from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
@@ -54,7 +62,8 @@ __all__ = [
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "SchemaRDD", "Row"]
+ "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+ "SchemaRDD"]
class DataType(object):
@@ -1171,7 +1180,7 @@ def _create_cls(dataType):
class Row(tuple):
- """ Row in SchemaRDD """
+ """ Row in DataFrame """
__DATATYPE__ = dataType
__FIELDS__ = tuple(f.name for f in dataType.fields)
__slots__ = ()
@@ -1198,7 +1207,7 @@ class SQLContext(object):
"""Main entry point for Spark SQL functionality.
- A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as
+ A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
tables, execute SQL over tables, cache tables, and read parquet files.
"""
@@ -1209,8 +1218,8 @@ class SQLContext(object):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
@@ -1225,12 +1234,12 @@ class SQLContext(object):
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> srdd = sqlCtx.inferSchema(allTypes)
- >>> srdd.registerTempTable("allTypes")
+ >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -1309,23 +1318,23 @@ class SQLContext(object):
... [Row(field1=1, field2="row1"),
... Row(field1=2, field2="row2"),
... Row(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
>>> NestedRow = Row("f1", "f2")
>>> nestedRdd1 = sc.parallelize([
... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> srdd = sqlCtx.inferSchema(nestedRdd1)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd1)
+ >>> df.collect()
[Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
>>> nestedRdd2 = sc.parallelize([
... NestedRow([[1, 2], [2, 3]], [1, 2]),
... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd2)
+ >>> df.collect()
[Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
>>> from collections import namedtuple
@@ -1334,13 +1343,13 @@ class SQLContext(object):
... [CustomRow(field1=1, field2="row1"),
... CustomRow(field1=2, field2="row2"),
... CustomRow(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
first = rdd.first()
if not first:
@@ -1384,10 +1393,10 @@ class SQLContext(object):
>>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
- >>> srdd = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT * from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.applySchema(rdd2, schema)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT * from table1")
+ >>> df2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
>>> from datetime import date, datetime
@@ -1410,15 +1419,15 @@ class SQLContext(object):
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
- >>> srdd = sqlCtx.applySchema(rdd, schema)
- >>> results = srdd.map(
+ >>> df = sqlCtx.applySchema(rdd, schema)
+ >>> results = df.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
- >>> srdd.registerTempTable("table2")
+ >>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
@@ -1431,13 +1440,13 @@ class SQLContext(object):
>>> abstract = "byte short float time map{} struct(b) list[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
- >>> srdd.collect()
+ >>> df = sqlCtx.applySchema(rdd, typedSchema)
+ >>> df.collect()
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -1457,8 +1466,8 @@ class SQLContext(object):
rdd = rdd.map(converter)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ return DataFrame(df, self)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1466,34 +1475,34 @@ class SQLContext(object):
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
"""
- if (rdd.__class__ is SchemaRDD):
- srdd = rdd._jschema_rdd.baseSchemaRDD()
- self._ssql_ctx.registerRDDAsTable(srdd, tableName)
+ if (rdd.__class__ is DataFrame):
+ df = rdd._jdf
+ self._ssql_ctx.registerRDDAsTable(df, tableName)
else:
- raise ValueError("Can only register SchemaRDD as table")
+ raise ValueError("Can only register DataFrame as table")
def parquetFile(self, path):
- """Loads a Parquet file, returning the result as a L{SchemaRDD}.
+ """Loads a Parquet file, returning the result as a L{DataFrame}.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path)
- return SchemaRDD(jschema_rdd, self)
+ jdf = self._ssql_ctx.parquetFile(path)
+ return DataFrame(jdf, self)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
- L{SchemaRDD}.
+ L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1508,23 +1517,23 @@ class SQLContext(object):
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
@@ -1536,23 +1545,23 @@ class SQLContext(object):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
+ df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return DataFrame(df, self)
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
- """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
+ """Loads an RDD storing one JSON object per string as a L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1560,23 +1569,23 @@ class SQLContext(object):
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
@@ -1588,12 +1597,12 @@ class SQLContext(object):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
@@ -1615,33 +1624,33 @@ class SQLContext(object):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(srdd, self)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return DataFrame(df, self)
def sql(self, sqlQuery):
- """Return a L{SchemaRDD} representing the result of the given query.
+ """Return a L{DataFrame} representing the result of the given query.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+ return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
def table(self, tableName):
- """Returns the specified table as a L{SchemaRDD}.
+ """Returns the specified table as a L{DataFrame}.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.table("table1")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName), self)
+ return DataFrame(self._ssql_ctx.table(tableName), self)
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1707,7 +1716,7 @@ def _create_row(fields, values):
class Row(tuple):
"""
- A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
Row can be used to create a row object by using named arguments,
the fields will be sorted by names.
@@ -1799,111 +1808,119 @@ def inherit_doc(cls):
return cls
-@inherit_doc
-class SchemaRDD(RDD):
+class DataFrame(object):
- """An RDD of L{Row} objects that has an associated schema.
+ """A collection of rows that have the same columns.
- The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
- utilize the relational query api exposed by Spark SQL.
+ A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+ and can be created using various functions in :class:`SQLContext`::
- For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
- L{SchemaRDD} is not operated on directly, as it's underlying
- implementation is an RDD composed of Java objects. Instead it is
- converted to a PythonRDD in the JVM, on which Python operations can
- be done.
+ people = sqlContext.parquetFile("...")
- This class receives raw tuples from Java but assigns a class to it in
- all its data-collection methods (mapPartitionsWithIndex, collect, take,
- etc) so that PySpark sees them as Row objects with named fields.
+ Once created, it can be manipulated using the various domain-specific-language
+ (DSL) functions defined in: [[DataFrame]], [[Column]].
+
+ To select a column from the data frame, access it as an attribute::
+
+ ageCol = people.age
+
+ Note that the :class:`Column` type can also be manipulated
+ through its various functions::
+
+ # The following creates a new column that increases everybody's age by 10.
+ people.age + 10
+
+
+ A more concrete example::
+
+ # To create DataFrame using SQLContext
+ people = sqlContext.parquetFile("...")
+ department = sqlContext.parquetFile("...")
+
+ people.filter(people.age > 30).join(department, people.deptId == department.id) \
+ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
"""
- def __init__(self, jschema_rdd, sql_ctx):
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
self.sql_ctx = sql_ctx
- self._sc = sql_ctx._sc
- clsName = jschema_rdd.getClass().getName()
- assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD"
- self._jschema_rdd = jschema_rdd
- self._id = None
+ self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
- self.is_checkpointed = False
- self.ctx = self.sql_ctx._sc
- # the _jrdd is created by javaToPython(), serialized by pickle
- self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer())
@property
- def _jrdd(self):
- """Lazy evaluation of PythonRDD object.
+ def rdd(self):
+ """Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row`s. """
+ if not hasattr(self, '_lazy_rdd'):
+ jrdd = self._jdf.javaToPython()
+ rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ schema = self.schema()
- Only done when a user calls methods defined by the
- L{pyspark.rdd.RDD} super class (map, filter, etc.).
- """
- if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
- return self._lazy_jrdd
+ def applySchema(it):
+ cls = _create_cls(schema)
+ return itertools.imap(cls, it)
- def id(self):
- if self._id is None:
- self._id = self._jrdd.id()
- return self._id
+ self._lazy_rdd = rdd.mapPartitions(applySchema)
+
+ return self._lazy_rdd
def limit(self, num):
"""Limit the result count to the number specified.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.limit(2).collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.limit(2).collect()
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
- >>> srdd.limit(0).collect()
+ >>> df.limit(0).collect()
[]
"""
- rdd = self._jschema_rdd.baseSchemaRDD().limit(num)
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
def toJSON(self, use_unicode=False):
- """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+ """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql( "SELECT * from table1")
- >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql( "SELECT * from table1")
+ >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
True
- >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1")
- >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+ >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+ >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
True
"""
- rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+ rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a SchemaRDD using the L{SQLContext.parquetFile} method.
+ a DataFrame using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd2.collect()) == sorted(srdd.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df2.collect()) == sorted(df.collect())
True
"""
- self._jschema_rdd.saveAsParquetFile(path)
+ self._jdf.saveAsParquetFile(path)
def registerTempTable(self, name):
"""Registers this RDD as a temporary table using the given name.
The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this SchemaRDD.
+ that was used to create this DataFrame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.registerTempTable("test")
- >>> srdd2 = sqlCtx.sql("select * from test")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.registerTempTable("test")
+ >>> df2 = sqlCtx.sql("select * from test")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- self._jschema_rdd.registerTempTable(name)
+ self._jdf.registerTempTable(name)
def registerAsTable(self, name):
"""DEPRECATED: use registerTempTable() instead"""
@@ -1911,62 +1928,61 @@ class SchemaRDD(RDD):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this SchemaRDD into the specified table.
+ """Inserts the contents of this DataFrame into the specified table.
Optionally overwriting any existing data.
"""
- self._jschema_rdd.insertInto(tableName, overwrite)
+ self._jdf.insertInto(tableName, overwrite)
def saveAsTable(self, tableName):
- """Creates a new table with the contents of this SchemaRDD."""
- self._jschema_rdd.saveAsTable(tableName)
+ """Creates a new table with the contents of this DataFrame."""
+ self._jdf.saveAsTable(tableName)
def schema(self):
- """Returns the schema of this SchemaRDD (represented by
+ """Returns the schema of this DataFrame (represented by
a L{StructType})."""
- return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
-
- def schemaString(self):
- """Returns the output schema in the tree format."""
- return self._jschema_rdd.schemaString()
+ return _parse_datatype_json_string(self._jdf.schema().json())
def printSchema(self):
"""Prints out the schema in the tree format."""
- print self.schemaString()
+ print (self._jdf.schema().treeString())
def count(self):
"""Return the number of elements in this RDD.
Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the SchemaRDD,
+ leverages the query optimizer to compute the count on the DataFrame,
which supports features such as filter pushdown.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.count()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.count()
3L
- >>> srdd.count() == srdd.map(lambda x: x).count()
+ >>> df.count() == df.map(lambda x: x).count()
True
"""
- return self._jschema_rdd.count()
+ return self._jdf.count()
def collect(self):
- """Return a list that contains all of the rows in this RDD.
+ """Return a list that contains all of the rows.
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of collect, this implementation
- leverages the query optimizer to perform a collect on the SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+ with SCCallSiteSync(self._sc) as css:
+ bytesInJava = self._jdf.javaToPython().collect().iterator()
cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+ tempFile.close()
+ self._sc._writeToFile(bytesInJava, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+ os.unlink(tempFile.name)
+ return [cls(r) for r in rs]
def take(self, num):
"""Take the first num rows of the RDD.
@@ -1974,130 +1990,555 @@ class SchemaRDD(RDD):
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of take, this implementation
- leverages the query optimizer to perform a collect on a SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.take(2)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.take(2)
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
"""
return self.limit(num).collect()
- # Convert each object in the RDD to a Row with the right class
- # for this SchemaRDD, so that fields can be accessed as attributes.
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ def map(self, f):
+ """ Return a new RDD by applying a function to each Row, it's a
+ shorthand for df.rdd.map()
"""
- Return a new RDD by applying a function to each partition of this RDD,
- while tracking the index of the original partition.
+ return self.rdd.map(f)
- >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
- >>> def f(splitIndex, iterator): yield splitIndex
- >>> rdd.mapPartitionsWithIndex(f).sum()
- 6
+ def mapPartitions(self, f, preservesPartitioning=False):
"""
- rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
-
- schema = self.schema()
+ Return a new RDD by applying a function to each partition.
- def applySchema(_, it):
- cls = _create_cls(schema)
- return itertools.imap(cls, it)
-
- objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
- return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(iterator): yield 1
+ >>> rdd.mapPartitions(f).sum()
+ 4
+ """
+ return self.rdd.mapPartitions(f, preservesPartitioning)
- # We override the default cache/persist/checkpoint behavior
- # as we want to cache the underlying SchemaRDD object in the JVM,
- # not the PythonRDD checkpointed by the super class
def cache(self):
+ """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """
self.is_cached = True
- self._jschema_rdd.cache()
+ self._jdf.cache()
return self
def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+ """ Set the storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified, it defaults to C{MEMORY_ONLY_SER}.
+ """
self.is_cached = True
- javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
- self._jschema_rdd.persist(javaStorageLevel)
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
return self
def unpersist(self, blocking=True):
+ """ Mark it as non-persistent, and remove all blocks for it from
+ memory and disk.
+ """
self.is_cached = False
- self._jschema_rdd.unpersist(blocking)
+ self._jdf.unpersist(blocking)
return self
- def checkpoint(self):
- self.is_checkpointed = True
- self._jschema_rdd.checkpoint()
+ # def coalesce(self, numPartitions, shuffle=False):
+ # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+ # return DataFrame(rdd, self.sql_ctx)
- def isCheckpointed(self):
- return self._jschema_rdd.isCheckpointed()
+ def repartition(self, numPartitions):
+ """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+ partitions.
+ """
+ rdd = self._jdf.repartition(numPartitions, None)
+ return DataFrame(rdd, self.sql_ctx)
- def getCheckpointFile(self):
- checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isDefined():
- return checkpointFile.get()
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this DataFrame.
- def coalesce(self, numPartitions, shuffle=False):
- rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.sample(False, 0.5, 97).count()
+ 2L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+ return DataFrame(rdd, self.sql_ctx)
+
+ # def takeSample(self, withReplacement, num, seed=None):
+ # """Return a fixed-size sampled subset of this DataFrame.
+ #
+ # >>> df = sqlCtx.inferSchema(rdd)
+ # >>> df.takeSample(False, 2, 97)
+ # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ # """
+ # seed = seed if seed is not None else random.randint(0, sys.maxint)
+ # with SCCallSiteSync(self.context) as css:
+ # bytesInJava = self._jdf \
+ # .takeSampleToPython(withReplacement, num, long(seed)) \
+ # .iterator()
+ # cls = _create_cls(self.schema())
+ # return map(cls, self._collect_iterator_through_file(bytesInJava))
- def distinct(self, numPartitions=None):
- if numPartitions is None:
- rdd = self._jschema_rdd.distinct()
+ @property
+ def dtypes(self):
+ """Return all column names and their data types as a list.
+ """
+ return [(f.name, str(f.dataType)) for f in self.schema().fields]
+
+ @property
+ def columns(self):
+ """ Return all column names as a list.
+ """
+ return [f.name for f in self.schema().fields]
+
+ def show(self):
+ raise NotImplementedError
+
+ def join(self, other, joinExprs=None, joinType=None):
+ """
+ Join with another DataFrame, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`::
+
+ df1.join(df2, df1.key == df2.key, "outer")
+
+ :param other: Right side of the join
+ :param joinExprs: Join expression
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
+ `semijoin`.
+ """
+ if joinType is None:
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
+ else:
+ jdf = self._jdf.join(other._jdf, joinExprs)
else:
- rdd = self._jschema_rdd.distinct(numPartitions, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def sort(self, *cols):
+ """ Return a new [[DataFrame]] sorted by the specified columns,
+ in ascending order.
+
+ :param cols: The columns or expressions used for sorting
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ cols = list(cols)
+ for i, c in enumerate(cols):
+ if isinstance(c, basestring):
+ cols[i] = Column(_create_column_from_name(c))
+ jcols = [c._jc for c in cols]
+ jdf = self._jdf.sort(*jcols)
+ return DataFrame(jdf, self.sql_ctx)
+
+ sortBy = sort
+
+ def head(self, n=None):
+ """ Return the first `n` rows or the first row if n is None. """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def tail(self):
+ raise NotImplementedError
+
+ def __getitem__(self, item):
+ if isinstance(item, basestring):
+ return Column(self._jdf.apply(item))
+
+ # TODO projection
+ raise IndexError
+
+ def __getattr__(self, name):
+ """ Return the column by given name """
+ if isinstance(name, basestring):
+ return Column(self._jdf.apply(name))
+ raise AttributeError
+
+ def As(self, name):
+ """ Alias the current DataFrame """
+ return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+
+ def select(self, *cols):
+ """ Selecting a set of expressions.::
+
+ df.select()
+ df.select('colA', 'colB')
+ df.select(df.colA, df.colB + 1)
- def intersection(self, other):
- if (other.__class__ is SchemaRDD):
- rdd = self._jschema_rdd.intersection(other._jschema_rdd)
- return SchemaRDD(rdd, self.sql_ctx)
+ """
+ if not cols:
+ cols = ["*"]
+ if isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only intersect with another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+ return DataFrame(jdf, self.sql_ctx)
- def repartition(self, numPartitions):
- rdd = self._jschema_rdd.repartition(numPartitions, None)
- return SchemaRDD(rdd, self.sql_ctx)
+ def filter(self, condition):
+ """ Filtering rows using the given condition::
- def subtract(self, other, numPartitions=None):
- if (other.__class__ is SchemaRDD):
- if numPartitions is None:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd)
- else:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd,
- numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ df.filter(df.age > 15)
+ df.where(df.age > 15)
+
+ """
+ return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+
+ where = filter
+
+ def groupBy(self, *cols):
+ """ Group the [[DataFrame]] using the specified columns,
+ so we can run aggregation on them. See :class:`GroupedDataFrame`
+ for all the available aggregate functions::
+
+ df.groupBy(df.department).avg()
+ df.groupBy("department", "gender").agg({
+ "salary": "avg",
+ "age": "max",
+ })
+ """
+ if cols and isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only subtract another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+ return GroupedDataFrame(jdf, self.sql_ctx)
- def sample(self, withReplacement, fraction, seed=None):
+ def agg(self, *exprs):
+ """ Aggregate on the entire [[DataFrame]] without groups
+ (shorthand for df.groupBy().agg())::
+
+ df.agg({"age": "max", "salary": "avg"})
"""
- Return a sampled subset of this SchemaRDD.
+ return self.groupBy().agg(*exprs)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.sample(False, 0.5, 97).count()
- 2L
+ def unionAll(self, other):
+ """ Return a new DataFrame containing union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL.
"""
- assert fraction >= 0.0, "Negative fraction value: %s" % fraction
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
- return SchemaRDD(rdd, self.sql_ctx)
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
- def takeSample(self, withReplacement, num, seed=None):
- """Return a fixed-size sampled subset of this SchemaRDD.
+ def intersect(self, other):
+ """ Return a new [[DataFrame]] containing rows only in
+ both this frame and another frame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.takeSample(False, 2, 97)
- [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ This is equivalent to `INTERSECT` in SQL.
"""
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD() \
- .takeSampleToPython(withReplacement, num, long(seed)) \
- .iterator()
- cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+ def Except(self, other):
+ """ Return a new [[DataFrame]] containing rows in this frame
+ but not in another frame.
+
+ This is equivalent to `EXCEPT` in SQL.
+ """
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+ def addColumn(self, colName, col):
+ """ Return a new [[DataFrame]] by adding a column. """
+ return self.select('*', col.As(colName))
+
+ def removeColumn(self, colName):
+ raise NotImplementedError
+
+
+# Keep SchemaRDD around for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+ """
+ SchemaRDD is deprecated; please use DataFrame instead.
+ """
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedDataFrame(object):
+
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by DataFrame.groupBy().
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ def agg(self, *exprs):
+ """ Compute aggregates by specifying a map from column name
+ to aggregate methods.
+
+ The available aggregate methods are `avg`, `max`, `min`,
+ `sum`, `count`.
+
+ :param exprs: list of aggregate columns or a map from column
+ name to aggregate methods.
+ """
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jmap = MapConverter().convert(exprs[0],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(jmap)
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
+ jdf = self._jdf.agg(*exprs)
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """ Count the number of rows for each group. """
+
+ @dfapi
+ def mean(self):
+ """Compute the average value for each numeric columns
+ for each group. This is an alias for `avg`."""
+
+ @dfapi
+ def avg(self):
+ """Compute the average value for each numeric columns
+ for each group."""
+
+ @dfapi
+ def max(self):
+ """Compute the max value for each numeric columns for
+ each group. """
+
+ @dfapi
+ def min(self):
+ """Compute the min value for each numeric column for
+ each group."""
+
+ @dfapi
+ def sum(self):
+ """Compute the sum for each numeric columns for each
+ group."""
+
+
+SCALA_METHOD_MAPPINGS = {
+ '=': '$eq',
+ '>': '$greater',
+ '<': '$less',
+ '+': '$plus',
+ '-': '$minus',
+ '*': '$times',
+ '/': '$div',
+ '!': '$bang',
+ '@': '$at',
+ '#': '$hash',
+ '%': '$percent',
+ '^': '$up',
+ '&': '$amp',
+ '~': '$tilde',
+ '?': '$qmark',
+ '|': '$bar',
+ '\\': '$bslash',
+ ':': '$colon',
+}
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Literal.apply(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Column(name)
+
+
+def _scalaMethod(name):
+ """ Translate operators into methodName in Scala
+
+ For example:
+ >>> _scalaMethod('+')
+ '$plus'
+ >>> _scalaMethod('>=')
+ '$greater$eq'
+ >>> _scalaMethod('cast')
+ 'cast'
+ """
+ return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
+
+
+def _unary_op(name):
+ """ Create a method for given unary operator """
+ def _(self):
+ return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+ return _
+
+
+def _bin_op(name):
+ """ Create a method for given binary operator """
+ def _(self, other):
+ if isinstance(other, Column):
+ jc = other._jc
+ else:
+ jc = _create_column_from_literal(other)
+ return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+ return _
+
+
+def _reverse_op(name):
+ """ Create a method for binary operator (this object is on right side)
+ """
+ def _(self, other):
+ return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
+ self._jdf, self.sql_ctx)
+ return _
+
+
+class Column(DataFrame):
+
+ """
+ A column in a DataFrame.
+
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df["colName"] + 1
+ """
+
+ def __init__(self, jc, jdf=None, sql_ctx=None):
+ self._jc = jc
+ super(Column, self).__init__(jdf, sql_ctx)
+
+ # arithmetic operators
+ __neg__ = _unary_op("unary_-")
+ __add__ = _bin_op("+")
+ __sub__ = _bin_op("-")
+ __mul__ = _bin_op("*")
+ __div__ = _bin_op("/")
+ __mod__ = _bin_op("%")
+ __radd__ = _bin_op("+")
+ __rsub__ = _reverse_op("-")
+ __rmul__ = _bin_op("*")
+ __rdiv__ = _reverse_op("/")
+ __rmod__ = _reverse_op("%")
+ __abs__ = _unary_op("abs")
+ abs = _unary_op("abs")
+ sqrt = _unary_op("sqrt")
+
+ # comparison and logical operators
+ __eq__ = _bin_op("===")
+ __ne__ = _bin_op("!==")
+ __lt__ = _bin_op("<")
+ __le__ = _bin_op("<=")
+ __ge__ = _bin_op(">=")
+ __gt__ = _bin_op(">")
+ # `and`, `or`, `not` cannot be overloaded in Python
+ And = _bin_op('&&')
+ Or = _bin_op('||')
+ Not = _unary_op('unary_!')
+
+ # bitwise operators
+ __and__ = _bin_op("&")
+ __or__ = _bin_op("|")
+ __invert__ = _unary_op("unary_~")
+ __xor__ = _bin_op("^")
+ # __lshift__ = _bin_op("<<")
+ # __rshift__ = _bin_op(">>")
+ __rand__ = _bin_op("&")
+ __ror__ = _bin_op("|")
+ __rxor__ = _bin_op("^")
+ # __rlshift__ = _reverse_op("<<")
+ # __rrshift__ = _reverse_op(">>")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("getItem")
+ # __getattr__ = _bin_op("getField")
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+ upper = _unary_op("upper")
+ lower = _unary_op("lower")
+
+ def substr(self, startPos, pos):
+ """ Return a Column which is a substring of the column """
+ if type(startPos) != type(pos):
+ raise TypeError("Can not mix the types")
+ if isinstance(startPos, (int, long)):
+ jc = self._jc.substr(startPos, pos)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, pos._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc, self._jdf, self.sql_ctx)
+
+ __getslice__ = substr
+
+ # order
+ asc = _unary_op("asc")
+ desc = _unary_op("desc")
+
+ isNull = _unary_op("isNull")
+ isNotNull = _unary_op("isNotNull")
+
+ # `as` is keyword
+ def As(self, alias):
+ return Column(getattr(self._jc, "as")(alias), self._jdf, self.sql_ctx)
+
+ def cast(self, dataType):
+ if self.sql_ctx is None:
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ else:
+ ssql_ctx = self.sql_ctx._ssql_ctx
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+
+
+def _aggregate_func(name):
+ """ Creat a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ # FIXME: can not access dsl.min/max ...
+ jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+ return Column(jc)
+ return staticmethod(_)
+
+
+class Aggregator(object):
+ """
+ A collection of builtin aggregators
+ """
+ max = _aggregate_func("max")
+ min = _aggregate_func("min")
+ avg = mean = _aggregate_func("mean")
+ sum = _aggregate_func("sum")
+ first = _aggregate_func("first")
+ last = _aggregate_func("last")
+ count = _aggregate_func("count")
def _test():
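A note on the operator tables above: Python operators on Column are routed through _scalaMethod (for example `+` becomes `$plus` and `>=` becomes `$greater$eq`), so ordinary expressions build JVM expression trees rather than evaluating eagerly. A hedged sketch of how expressions compose under this patch, assuming a DataFrame `df` with columns `age` and `name`::

    # Arithmetic and comparison operators each return a new Column expression;
    # `And` is used because Python's `and` cannot be overloaded.
    expr = ((df.age + 10) * 2 >= 70).And(df.name.like("A%"))

    # Nothing is evaluated until an action runs on the resulting DataFrame.
    df.filter(expr).select(df.name).collect()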
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b474fcf5bf..e8e207af46 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -806,6 +806,9 @@ class SQLTests(ReusedPySparkTestCase):
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
+ self.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = self.sc.parallelize(self.testData)
+ self.df = self.sqlCtx.inferSchema(rdd)
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
@@ -821,7 +824,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.inferSchema(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -839,68 +842,51 @@ class SQLTests(ReusedPySparkTestCase):
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- srdd = self.sqlCtx.jsonRDD(rdd)
- srdd.count()
- srdd.collect()
- srdd.schemaString()
- srdd.schema()
+ df = self.sqlCtx.jsonRDD(rdd)
+ df.count()
+ df.collect()
+ df.schema()
# cache and checkpoint
- self.assertFalse(srdd.is_cached)
- srdd.persist()
- srdd.unpersist()
- srdd.cache()
- self.assertTrue(srdd.is_cached)
- self.assertFalse(srdd.isCheckpointed())
- self.assertEqual(None, srdd.getCheckpointFile())
-
- srdd = srdd.coalesce(2, True)
- srdd = srdd.repartition(3)
- srdd = srdd.distinct()
- srdd.intersection(srdd)
- self.assertEqual(2, srdd.count())
-
- srdd.registerTempTable("temp")
- srdd = self.sqlCtx.sql("select foo from temp")
- srdd.count()
- srdd.collect()
-
- def test_distinct(self):
- rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
- srdd = self.sqlCtx.jsonRDD(rdd)
- self.assertEquals(srdd.getNumPartitions(), 10)
- self.assertEquals(srdd.distinct().count(), 3)
- result = srdd.distinct(5)
- self.assertEquals(result.getNumPartitions(), 5)
- self.assertEquals(result.count(), 3)
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist()
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ df.registerTempTable("temp")
+ df = self.sqlCtx.sql("select foo from temp")
+ df.count()
+ df.collect()
def test_apply_schema_to_row(self):
- srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
- self.assertEqual(srdd.collect(), srdd2.collect())
+ df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
- self.assertEqual(10, srdd3.count())
+ df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- row = srdd.first()
+ df = self.sqlCtx.inferSchema(rdd)
+ row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
self.assertEqual("2", row.d["key"].d)
- l = srdd.map(lambda x: x.l).first()
+ l = df.map(lambda x: x.l).first()
self.assertEqual(1, len(l))
self.assertEqual('s', l[0].b)
- d = srdd.map(lambda x: x.d).first()
+ d = df.map(lambda x: x.d).first()
self.assertEqual(1, len(d))
self.assertEqual(1.0, d["key"].c)
- row = srdd.map(lambda x: x.d["key"]).first()
+ row = df.map(lambda x: x.d["key"]).first()
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
@@ -908,26 +894,26 @@ class SQLTests(ReusedPySparkTestCase):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- self.assertEqual([], srdd.map(lambda r: r.l).first())
- self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
- srdd.registerTempTable("test")
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], df.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+ df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
- srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
- self.assertEqual(srdd.schema(), srdd2.schema())
- self.assertEqual({}, srdd2.map(lambda r: r.d).first())
- self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
- srdd2.registerTempTable("test2")
+ df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual({}, df2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+ df2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- k, v = srdd.first().m.items()[0]
+ df = self.sqlCtx.inferSchema(rdd)
+ k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -935,9 +921,9 @@ class SQLTests(ReusedPySparkTestCase):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- srdd.registerTempTable("test")
- row = self.sqlCtx.sql("select l, d from test").first()
+ df = self.sqlCtx.inferSchema(rdd)
+ df.registerTempTable("test")
+ row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -945,12 +931,12 @@ class SQLTests(ReusedPySparkTestCase):
from pyspark.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ df = self.sqlCtx.inferSchema(rdd)
+ schema = df.schema()
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
- srdd.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
@@ -959,21 +945,52 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- srdd = self.sqlCtx.applySchema(rdd, schema)
- point = srdd.first().point
+ df = self.sqlCtx.applySchema(rdd, schema)
+ point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sqlCtx.inferSchema(rdd)
output_dir = os.path.join(self.tempdir.name, "labeled_point")
- srdd0.saveAsParquetFile(output_dir)
- srdd1 = self.sqlCtx.parquetFile(output_dir)
- point = srdd1.first().point
+ df0.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ def test_column_operators(self):
+ from pyspark.sql import Column, LongType
+ ci = self.df.key
+ cs = self.df.value
+ c = ci == cs
+ self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ self.assertTrue(all(isinstance(c, Column) for c in rcc))
+ cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+ self.assertTrue(all(isinstance(c, Column) for c in cb))
+ cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
+ self.assertTrue(all(isinstance(c, Column) for c in cbit))
+ css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+ self.assertTrue(all(isinstance(c, Column) for c in css))
+ self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+ def test_column_select(self):
+ df = self.df
+ self.assertEqual(self.testData, df.select("*").collect())
+ self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+ self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+ # TODO(davies): fix aggregators
+ from pyspark.sql import Aggregator as Agg
+ # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+
class InputFormatTests(ReusedPySparkTestCase):