author    Davies Liu <davies@databricks.com>    2015-02-11 12:13:16 -0800
committer Reynold Xin <rxin@databricks.com>     2015-02-11 12:13:16 -0800
commit    b694eb9c2fefeaa33891d3e61f9bea369bc09984 (patch)
tree      0618924c6564e41ab27676415e79467216d4832f /python/pyspark
parent    1ac099e3e00ddb01af8e6e3a84c70f8363f04b5c (diff)
[SPARK-5677] [SPARK-5734] [SQL] [PySpark] Python DataFrame API remaining tasks
1. DataFrame.renameColumn
2. DataFrame.show() and __repr__
3. Use simpleString() rather than jsonValue in DataFrame.dtypes
4. createDataFrame from local Python data, including pandas.DataFrame

Author: Davies Liu <davies@databricks.com>

Closes #4528 from davies/df3 and squashes the following commits:

014acea [Davies Liu] fix typo
6ba526e [Davies Liu] fix tests
46f5f95 [Davies Liu] address comments
6cbc154 [Davies Liu] dataframe.show() and improve dtypes
6f94f25 [Davies Liu] create DataFrame from local Python data
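A rough end-to-end sketch of the API surface this patch adds (assuming an existing SparkContext `sc`, as in the doctests below; outputs in comments are approximate):

    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)

    # 4. create a DataFrame directly from local Python data
    df = sqlCtx.createDataFrame([('Alice', 1), ('Bob', 5)], ['name', 'age'])

    # 3. dtypes now uses simpleString(), e.g. 'int' instead of 'integer'
    df.dtypes                                # [('name', 'string'), ('age', 'int')]

    # 2. show() / __repr__ print the first 20 rows as a table
    df.show()

    # 1. rename a column with the new renameColumn()
    df.renameColumn('age', 'age2').collect()
    # [Row(name=u'Alice', age2=1), Row(name=u'Bob', age2=5)]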
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/context.py    114
-rw-r--r--  python/pyspark/sql/dataframe.py   42
-rw-r--r--  python/pyspark/sql/tests.py        2
-rw-r--r--  python/pyspark/sql/types.py       32
4 files changed, 144 insertions, 46 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9d29ef4839..db4bcbece2 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -23,12 +23,18 @@ from itertools import imap
from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
+try:
+ import pandas
+ has_pandas = True
+except ImportError:
+ has_pandas = False
+
__all__ = ["SQLContext", "HiveContext"]
@@ -116,6 +122,31 @@ class SQLContext(object):
self._sc._javaAccumulator,
returnType.json())
+ def _inferSchema(self, rdd, samplingRatio=None):
+ first = rdd.first()
+ if not first:
+ raise ValueError("The first row in RDD is empty, "
+ "can not infer schema")
+ if type(first) is dict:
+ warnings.warn("Using RDD of dict to inferSchema is deprecated,"
+ "please use pyspark.sql.Row instead")
+
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+ if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+ return schema
+
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -171,29 +202,7 @@ class SQLContext(object):
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
- first = rdd.first()
- if not first:
- raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
- if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.sql.Row instead")
-
- if samplingRatio is None:
- schema = _infer_schema(first)
- if _has_nulltype(schema):
- for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
- if not _has_nulltype(schema):
- break
- else:
- warnings.warn("Some of types cannot be determined by the "
- "first 100 rows, please try again with sampling")
- else:
- if samplingRatio < 0.99:
- rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
-
+ schema = self._inferSchema(rdd, samplingRatio)
converter = _create_converter(schema)
rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
@@ -274,7 +283,7 @@ class SQLContext(object):
raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
+ raise TypeError("schema should be StructType, but got %s" % schema)
# take the first few rows to verify schema
rows = rdd.take(10)
@@ -294,9 +303,9 @@ class SQLContext(object):
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
- def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+ def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
- Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+ Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame.
`schema` could be :class:`StructType` or a list of column names.
@@ -311,12 +320,20 @@ class SQLContext(object):
rows will be used to do referring. The first row will be used if
`samplingRatio` is None.
- :param rdd: an RDD of Row or tuple or list or dict
+ :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame
:param schema: a StructType or list of names of columns
:param samplingRatio: the sample ratio of rows used for inferring
:return: a DataFrame
- >>> rdd = sc.parallelize([('Alice', 1)])
+ >>> l = [('Alice', 1)]
+ >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> d = [{'name': 'Alice', 'age': 1}]
+ >>> sqlCtx.createDataFrame(d).collect()
+ [Row(age=1, name=u'Alice')]
+
+ >>> rdd = sc.parallelize(l)
>>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name=u'Alice', age=1)]
@@ -336,19 +353,32 @@ class SQLContext(object):
>>> df3.collect()
[Row(name=u'Alice', age=1)]
"""
- if isinstance(rdd, DataFrame):
- raise TypeError("rdd is already a DataFrame")
+ if isinstance(data, DataFrame):
+ raise TypeError("data is already a DataFrame")
- if isinstance(schema, StructType):
- return self.applySchema(rdd, schema)
- else:
- if isinstance(schema, (list, tuple)):
- first = rdd.first()
- if not isinstance(first, (list, tuple)):
- raise ValueError("each row in `rdd` should be list or tuple")
- row_cls = Row(*schema)
- rdd = rdd.map(lambda r: row_cls(*r))
- return self.inferSchema(rdd, samplingRatio)
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ data = self._sc.parallelize(data.to_records(index=False))
+ if schema is None:
+ schema = list(data.columns)
+
+ if not isinstance(data, RDD):
+ try:
+ # data could be list, tuple, generator ...
+ data = self._sc.parallelize(data)
+ except Exception:
+ raise ValueError("cannot create an RDD from type: %s" % type(data))
+
+ if schema is None:
+ return self.inferSchema(data, samplingRatio)
+
+ if isinstance(schema, (list, tuple)):
+ first = data.first()
+ if not isinstance(first, (list, tuple)):
+ raise ValueError("each row in `rdd` should be list or tuple")
+ row_cls = Row(*schema)
+ schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+
+ return self.applySchema(data, schema)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3eef0cc376..3eb56ed74c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -236,6 +236,24 @@ class DataFrame(object):
"""
print (self._jdf.schema().treeString())
+ def show(self):
+ """
+ Print the first 20 rows.
+
+ >>> df.show()
+ age name
+ 2 Alice
+ 5 Bob
+ >>> df
+ age name
+ 2 Alice
+ 5 Bob
+ """
+ print (self)
+
+ def __repr__(self):
+ return self._jdf.showString()
+
def count(self):
"""Return the number of elements in this RDD.
@@ -380,9 +398,9 @@ class DataFrame(object):
"""Return all column names and their data types as a list.
>>> df.dtypes
- [('age', 'integer'), ('name', 'string')]
+ [('age', 'int'), ('name', 'string')]
"""
- return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
@property
def columns(self):
@@ -606,6 +624,17 @@ class DataFrame(object):
"""
return self.select('*', col.alias(colName))
+ def renameColumn(self, existing, new):
+ """ Rename an existing column to a new name
+
+ >>> df.renameColumn('age', 'age2').collect()
+ [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
+ """
+ cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
+ if c == existing else c
+ for c in self.columns]
+ return self.select(*cols)
+
def to_pandas(self):
"""
Collect all the rows and return a `pandas.DataFrame`.
@@ -885,6 +914,12 @@ class Column(DataFrame):
jc = self._jc.cast(jdt)
return Column(jc, self.sql_ctx)
+ def __repr__(self):
+ if self._jdf.isComputable():
+ return self._jdf.samples()
+ else:
+ return 'Column<%s>' % self._jdf.toString()
+
def to_pandas(self):
"""
Return a pandas.Series from the column
@@ -1030,7 +1065,8 @@ def _test():
globs['df'] = sqlCtx.inferSchema(rdd2)
globs['df2'] = sqlCtx.inferSchema(rdd3)
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS)
+ pyspark.sql.dataframe, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
exit(-1)
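The new renameColumn() is built purely from select() and Column.alias(), so on the doctest DataFrame it is roughly equivalent to (a sketch):

    df.select(df.age.alias('age2'), 'name').collect()
    # [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
    df.renameColumn('age', 'age2').collect()      # same result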
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5e41e36897..43e5c3a1b0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -194,7 +194,7 @@ class SQLTests(ReusedPySparkTestCase):
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
- df2 = self.sqlCtx.createDataFrame(rdd, 1.0)
+ df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
self.assertEqual(df.schema(), df2.schema())
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
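The test tweak follows from the new createDataFrame signature: the second positional parameter is now `schema`, so a sampling ratio has to be passed by keyword. Roughly:

    self.sqlCtx.createDataFrame(rdd, 1.0)                 # 1.0 taken as schema -> TypeError
    self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)   # infer the schema by sampling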
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41afefe48e..40bd7e54a9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -52,6 +52,9 @@ class DataType(object):
def typeName(cls):
return cls.__name__[:-4].lower()
+ def simpleString(self):
+ return self.typeName()
+
def jsonValue(self):
return self.typeName()
@@ -145,6 +148,12 @@ class DecimalType(DataType):
self.scale = scale
self.hasPrecisionInfo = precision is not None
+ def simpleString(self):
+ if self.hasPrecisionInfo:
+ return "decimal(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "decimal(10,0)"
+
def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
@@ -180,6 +189,8 @@ class ByteType(PrimitiveType):
The data type representing int values with 1 singed byte.
"""
+ def simpleString(self):
+ return 'tinyint'
class IntegerType(PrimitiveType):
@@ -188,6 +199,8 @@ class IntegerType(PrimitiveType):
The data type representing int values.
"""
+ def simpleString(self):
+ return 'int'
class LongType(PrimitiveType):
@@ -198,6 +211,8 @@ class LongType(PrimitiveType):
beyond the range of [-9223372036854775808, 9223372036854775807],
please use DecimalType.
"""
+ def simpleString(self):
+ return 'bigint'
class ShortType(PrimitiveType):
@@ -206,6 +221,8 @@ class ShortType(PrimitiveType):
The data type representing int values with 2 signed bytes.
"""
+ def simpleString(self):
+ return 'smallint'
class ArrayType(DataType):
@@ -233,6 +250,9 @@ class ArrayType(DataType):
self.elementType = elementType
self.containsNull = containsNull
+ def simpleString(self):
+ return 'array<%s>' % self.elementType.simpleString()
+
def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
@@ -283,6 +303,9 @@ class MapType(DataType):
self.valueType = valueType
self.valueContainsNull = valueContainsNull
+ def simpleString(self):
+ return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())
+
def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
@@ -337,6 +360,9 @@ class StructField(DataType):
self.nullable = nullable
self.metadata = metadata or {}
+ def simpleString(self):
+ return '%s:%s' % (self.name, self.dataType.simpleString())
+
def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
@@ -379,6 +405,9 @@ class StructType(DataType):
"""
self.fields = fields
+ def simpleString(self):
+ return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
+
def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
@@ -435,6 +464,9 @@ class UserDefinedType(DataType):
"""
raise NotImplementedError("UDT must implement deserialize().")
+ def simpleString(self):
+ return 'null'
+
def json(self):
return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
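The simpleString() additions are what DataFrame.dtypes now reports; for container types the names compose. A quick sketch against the new methods:

    from pyspark.sql.types import (ArrayType, DecimalType, IntegerType, MapType,
                                   StringType, StructField, StructType)

    IntegerType().simpleString()                          # 'int'
    DecimalType().simpleString()                          # 'decimal(10,0)'
    ArrayType(IntegerType()).simpleString()               # 'array<int>'
    MapType(StringType(), IntegerType()).simpleString()   # 'map<string,int>'
    StructType([StructField('age', IntegerType(), True),
                StructField('name', StringType(), True)]).simpleString()
    # 'struct<age:int,name:string>'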