Diffstat (limited to 'python/pyspark/sql.py')
 python/pyspark/sql.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index c3a6938f56..fdd8034c98 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -931,7 +931,7 @@ def _parse_schema_abstract(s):
 
 def _infer_schema_type(obj, dataType):
     """
-    Fill the dataType with types infered from obj
+    Fill the dataType with types inferred from obj
 
     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
@@ -2140,7 +2140,7 @@ class DataFrame(object):
             return Column(self._jdf.apply(name))
         raise AttributeError
 
-    def As(self, name):
+    def alias(self, name):
        """ Alias the current DataFrame """
        return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
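Note: the rename drops the awkward capitalized As, which only existed because `as` is a Python keyword (hence the getattr(self._jdf, "as") call on the Java side). A minimal sketch of the renamed method, assuming a hypothetical DataFrame df (not part of this patch):

    >>> df2 = df.alias("df2")   # same data, addressable under a second name in joins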
@@ -2216,7 +2216,7 @@ class DataFrame(object):
         """
         return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
 
-    def Except(self, other):
+    def subtract(self, other):
         """ Return a new [[DataFrame]] containing rows in this frame
         but not in another frame.
 
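Note: Except was capitalized only to dodge the `except` keyword; subtract also lines up with the existing RDD.subtract. Illustrative use, assuming two DataFrames df1 and df2 with the same schema:

    >>> df1.subtract(df2)   # rows of df1 that do not appear in df2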
@@ -2234,7 +2234,7 @@ class DataFrame(object):
 
     def addColumn(self, colName, col):
         """ Return a new [[DataFrame]] by adding a column. """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
     def removeColumn(self, colName):
         raise NotImplemented
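Note: addColumn now routes through col.alias(colName). A sketch, assuming df.age resolves to a Column (via the __getattr__ shown above) and that + on a Column is wired through _bin_op, neither of which this hunk shows:

    >>> df.addColumn("age2", df.age + 1)   # new frame with an extra "age2" column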
@@ -2342,7 +2342,7 @@ SCALA_METHOD_MAPPINGS = {
 
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Literal.apply(literal)
+    return sc._jvm.org.apache.spark.sql.api.java.dsl.lit(literal)
 
 
 def _create_column_from_name(name):
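Note: literals are now lifted into Columns through the Java API's dsl.lit entry point instead of Literal.apply. This helper runs whenever a plain Python value meets a Column and the pass-through flag introduced below is not set, e.g. (operator wiring assumed, not shown in this hunk):

    >>> df.age + 1   # the 1 is wrapped on the JVM side via dsl.lit(1)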
@@ -2371,13 +2371,20 @@ def _unary_op(name):
     return _
 
 
-def _bin_op(name):
-    """ Create a method for given binary operator """
+def _bin_op(name, pass_literal_through=False):
+    """ Create a method for the given binary operator
+
+    Keyword arguments:
+    pass_literal_through -- whether to pass the literal value directly through to the JVM.
+    """
     def _(self, other):
         if isinstance(other, Column):
             jc = other._jc
         else:
-            jc = _create_column_from_literal(other)
+            if pass_literal_through:
+                jc = other
+            else:
+                jc = _create_column_from_literal(other)
         return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
     return _
 
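Note: the new flag chooses between two paths. By default a non-Column operand is wrapped into a literal Column; with pass_literal_through=True the raw Python value is handed to the JVM method unchanged. A sketch of both behaviors on a hypothetical df (the + wiring is assumed):

    >>> df.age + 1            # default path: 1 -> dsl.lit(1) -> Column
    >>> df.name.like("Al%")   # pass-through path: "Al%" stays a plain string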
@@ -2458,10 +2465,10 @@ class Column(DataFrame):
     # __getattr__ = _bin_op("getField")
 
     # string methods
-    rlike = _bin_op("rlike")
-    like = _bin_op("like")
-    startswith = _bin_op("startsWith")
-    endswith = _bin_op("endsWith")
+    rlike = _bin_op("rlike", pass_literal_through=True)
+    like = _bin_op("like", pass_literal_through=True)
+    startswith = _bin_op("startsWith", pass_literal_through=True)
+    endswith = _bin_op("endsWith", pass_literal_through=True)
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
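Note: the pass-through flag fits these four methods because their JVM counterparts take the pattern as a plain string rather than a Column, so wrapping it in a literal Column would not match the expected signature. Illustrative predicates on a hypothetical string column:

    >>> df.filter(df.name.startswith("Al"))
    >>> df.filter(df.name.rlike("^Al.*a$"))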
@@ -2487,7 +2494,7 @@ class Column(DataFrame):
     isNotNull = _unary_op("isNotNull")
 
     # `as` is keyword
-    def As(self, alias):
+    def alias(self, alias):
         return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx)
 
     def cast(self, dataType):
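Note: the same As -> alias rename, applied to Column. For example, with a hypothetical df:

    >>> df.select(df.age.alias("age_in_years"))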
@@ -2501,15 +2508,14 @@ class Column(DataFrame):
 
 
 def _aggregate_func(name):
-    """ Creat a function for aggregator by name"""
+    """ Create a function for an aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
         if isinstance(col, Column):
             jcol = col._jc
         else:
             jcol = _create_column_from_name(col)
-        # FIXME: can not access dsl.min/max ...
-        jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+        jc = getattr(sc._jvm.org.apache.spark.sql.api.java.dsl, name)(jcol)
         return Column(jc)
     return staticmethod(_)
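Note: since _aggregate_func returns a staticmethod, the generated functions are presumably attached in a class body elsewhere in the file; a sketch of that pattern (class name and choice of aggregates are illustrative, not taken from this patch):

    >>> class Aggregator(object):
    ...     max = _aggregate_func("max")
    ...     min = _aggregate_func("min")
    ...     sum = _aggregate_func("sum")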