diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql.py | 38 |
1 file changed, 22 insertions, 16 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c3a6938f56..fdd8034c98 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -931,7 +931,7 @@ def _parse_schema_abstract(s): def _infer_schema_type(obj, dataType): """ - Fill the dataType with types infered from obj + Fill the dataType with types inferred from obj >>> schema = _parse_schema_abstract("a b c d") >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) @@ -2140,7 +2140,7 @@ class DataFrame(object): return Column(self._jdf.apply(name)) raise AttributeError - def As(self, name): + def alias(self, name): """ Alias the current DataFrame """ return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) @@ -2216,7 +2216,7 @@ class DataFrame(object): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - def Except(self, other): + def subtract(self, other): """ Return a new [[DataFrame]] containing rows in this frame but not in another frame. @@ -2234,7 +2234,7 @@ class DataFrame(object): def addColumn(self, colName, col): """ Return a new [[DataFrame]] by adding a column. """ - return self.select('*', col.As(colName)) + return self.select('*', col.alias(colName)) def removeColumn(self, colName): raise NotImplemented @@ -2342,7 +2342,7 @@ SCALA_METHOD_MAPPINGS = { def _create_column_from_literal(literal): sc = SparkContext._active_spark_context - return sc._jvm.Literal.apply(literal) + return sc._jvm.org.apache.spark.sql.api.java.dsl.lit(literal) def _create_column_from_name(name): @@ -2371,13 +2371,20 @@ def _unary_op(name): return _ -def _bin_op(name): - """ Create a method for given binary operator """ +def _bin_op(name, pass_literal_through=False): + """ Create a method for given binary operator + + Keyword arguments: + pass_literal_through -- whether to pass literal value directly through to the JVM. 
+ """ def _(self, other): if isinstance(other, Column): jc = other._jc else: - jc = _create_column_from_literal(other) + if pass_literal_through: + jc = other + else: + jc = _create_column_from_literal(other) return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) return _ @@ -2458,10 +2465,10 @@ class Column(DataFrame): # __getattr__ = _bin_op("getField") # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") + rlike = _bin_op("rlike", pass_literal_through=True) + like = _bin_op("like", pass_literal_through=True) + startswith = _bin_op("startsWith", pass_literal_through=True) + endswith = _bin_op("endsWith", pass_literal_through=True) upper = _unary_op("upper") lower = _unary_op("lower") @@ -2487,7 +2494,7 @@ class Column(DataFrame): isNotNull = _unary_op("isNotNull") # `as` is keyword - def As(self, alias): + def alias(self, alias): return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) def cast(self, dataType): @@ -2501,15 +2508,14 @@ class Column(DataFrame): def _aggregate_func(name): - """ Creat a function for aggregator by name""" + """ Create a function for aggregator by name""" def _(col): sc = SparkContext._active_spark_context if isinstance(col, Column): jcol = col._jc else: jcol = _create_column_from_name(col) - # FIXME: can not access dsl.min/max ... - jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol) + jc = getattr(sc._jvm.org.apache.spark.sql.api.java.dsl, name)(jcol) return Column(jc) return staticmethod(_) |