author      Davies Liu <davies@databricks.com>      2015-03-30 15:47:00 -0700
committer   Reynold Xin <rxin@databricks.com>       2015-03-30 15:47:00 -0700
commit      f76d2e55b1a67bf5576e1aa001a0b872b9b3895a (patch)
tree        e644495fdeff827221574dc23594db48a246e83c /python
parent      df3550084c9975f999ed370dd9f7c495181a68ba (diff)
[SPARK-6603] [PySpark] [SQL] add SQLContext.udf and deprecate inferSchema() and applySchema
This PR creates an alias for `registerFunction` as `udf.register`, to be consistent with the Scala API. It also deprecates inferSchema() and applySchema(), showing a warning for them.

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5273 from davies/udf and squashes the following commits:

476e947 [Davies Liu] address comments
c096fdb [Davies Liu] add SQLContext.udf and deprecate inferSchema() and applySchema
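For orientation, a minimal sketch of the new alias in use, mirroring the doctests added in the diff below; it assumes an existing SQLContext bound to `sqlCtx`:

    from pyspark.sql.types import IntegerType

    # New alias introduced by this PR, matching Scala's sqlContext.udf.register(...)
    sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
    sqlCtx.sql("SELECT stringLengthInt('test')").collect()   # [Row(c0=4)]

    # The older entry point still works; udf.register simply delegates to it,
    # and returnType defaults to StringType() when omitted.
    sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
    sqlCtx.sql("SELECT stringLengthString('test')").collect()   # [Row(c0=u'4')]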
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/context.py | 87
1 file changed, 60 insertions, 27 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 795ef0dbc4..80939a1f8a 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,7 +34,7 @@ try:
except ImportError:
    has_pandas = False
-__all__ = ["SQLContext", "HiveContext"]
+__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
def _monkey_patch_RDD(sqlCtx):
@@ -56,6 +56,31 @@ def _monkey_patch_RDD(sqlCtx):
    RDD.toDF = toDF
+class UDFRegistration(object):
+    """Wrapper for registering a UDF"""
+
+    def __init__(self, sqlCtx):
+        self.sqlCtx = sqlCtx
+
+    def register(self, name, f, returnType=StringType()):
+        """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally
+        specified. When the return type is not given, it defaults to a string and conversion
+        will automatically be done. For any other return type, the produced object must match
+        the specified type.
+
+        >>> sqlCtx.udf.register("stringLengthString", lambda x: len(x))
+        >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+        [Row(c0=u'4')]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+        [Row(c0=4)]
+        """
+        return self.sqlCtx.registerFunction(name, f, returnType)
+
+
class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -118,6 +143,11 @@ class SQLContext(object):
"""
return self._ssql_ctx.getConf(key, defaultValue)
+    @property
+    def udf(self):
+        """Wrapper for registering a Python function as a UDF"""
+        return UDFRegistration(self)
+
    def registerFunction(self, name, f, returnType=StringType()):
        """Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -198,14 +228,12 @@ class SQLContext(object):
        >>> df.collect()[0]
        Row(field1=1, field2=u'row1')
        """
+ warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
- schema = self._inferSchema(rdd, samplingRatio)
- converter = _create_converter(schema)
- rdd = rdd.map(converter)
- return self.applySchema(rdd, schema)
+ return self.createDataFrame(rdd, None, samplingRatio)
    def applySchema(self, rdd, schema):
        """
@@ -230,6 +258,7 @@ class SQLContext(object):
        >>> df.collect()
        [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
        """
+ warnings.warn("applySchema is deprecated, please use createDataFrame instead")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
@@ -237,23 +266,7 @@ class SQLContext(object):
        if not isinstance(schema, StructType):
            raise TypeError("schema should be StructType, but got %s" % schema)
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
-            rows = rdd.take(10)
-
-        for row in rows:
-            _verify_type(row, schema)
-
-        # convert python objects to sql data
-        converter = _python_to_sql_converter(schema)
-        rdd = rdd.map(converter)
-
-        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return DataFrame(df, self)
+        return self.createDataFrame(rdd, schema)
    def createDataFrame(self, data, schema=None, samplingRatio=None):
        """
@@ -323,22 +336,42 @@ class SQLContext(object):
        if not isinstance(data, RDD):
            try:
                # data could be list, tuple, generator ...
-                data = self._sc.parallelize(data)
+                rdd = self._sc.parallelize(data)
            except Exception:
                raise ValueError("cannot create an RDD from type: %s" % type(data))
+        else:
+            rdd = data
        if schema is None:
-            return self.inferSchema(data, samplingRatio)
+            schema = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(schema)
+            rdd = rdd.map(converter)
        if isinstance(schema, (list, tuple)):
-            first = data.first()
+            first = rdd.first()
            if not isinstance(first, (list, tuple)):
                raise ValueError("each row in `rdd` should be list or tuple, "
                                 "but got %r" % type(first))
            row_cls = Row(*schema)
-            schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-        return self.applySchema(data, schema)
+        # take the first few rows to verify schema
+        rows = rdd.take(10)
+        # Row() cannot be deserialized by Pyrolite
+        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+            rdd = rdd.map(tuple)
+            rows = rdd.take(10)
+
+        for row in rows:
+            _verify_type(row, schema)
+
+        # convert python objects to sql data
+        converter = _python_to_sql_converter(schema)
+        rdd = rdd.map(converter)
+
+        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        return DataFrame(df, self)
    def registerDataFrameAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.