From 08488c175f2e8532cb6aab84da2abd9ad57179cc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Feb 2015 20:49:22 -0800 Subject: [SPARK-5469] restructure pyspark.sql into multiple files All the DataTypes moved into pyspark.sql.types The changes can be tracked by `--find-copies-harder -M25` ``` davieslocalhost:~/work/spark/python$ git diff --find-copies-harder -M25 --numstat master.. 2 5 python/docs/pyspark.ml.rst 0 3 python/docs/pyspark.mllib.rst 10 2 python/docs/pyspark.sql.rst 1 1 python/pyspark/mllib/linalg.py 21 14 python/pyspark/{mllib => sql}/__init__.py 14 2108 python/pyspark/{sql.py => sql/context.py} 10 1772 python/pyspark/{sql.py => sql/dataframe.py} 7 6 python/pyspark/{sql_tests.py => sql/tests.py} 8 1465 python/pyspark/{sql.py => sql/types.py} 4 2 python/run-tests 1 1 sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala ``` Also `git blame -C -C python/pyspark/sql/context.py` to track the history. Author: Davies Liu Closes #4479 from davies/sql and squashes the following commits: 1b5f0a5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sql 2b2b983 [Davies Liu] restructure pyspark.sql --- python/pyspark/mllib/linalg.py | 2 +- python/pyspark/sql.py | 2736 --------------------------------------- python/pyspark/sql/__init__.py | 42 + python/pyspark/sql/context.py | 642 +++++++++ python/pyspark/sql/dataframe.py | 974 ++++++++++++++ python/pyspark/sql/tests.py | 300 +++++ python/pyspark/sql/types.py | 1279 ++++++++++++++++++ python/pyspark/sql_tests.py | 299 ----- 8 files changed, 3238 insertions(+), 3036 deletions(-) delete mode 100644 python/pyspark/sql.py create mode 100644 python/pyspark/sql/__init__.py create mode 100644 python/pyspark/sql/context.py create mode 100644 python/pyspark/sql/dataframe.py create mode 100644 python/pyspark/sql/tests.py create mode 100644 python/pyspark/sql/types.py delete mode 100644 python/pyspark/sql_tests.py (limited to 'python/pyspark') diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7f21190ed8..597012b1c9 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,7 @@ import copy_reg import numpy as np -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py deleted file mode 100644 index 6a6dfbc585..0000000000 --- a/python/pyspark/sql.py +++ /dev/null @@ -1,2736 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. - - L{GroupedData} - - L{Column} - Column is a DataFrame with a single column. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. -""" - -import sys -import itertools -import decimal -import datetime -import keyword -import warnings -import json -import re -import random -import os -from tempfile import NamedTemporaryFile -from array import array -from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.context import SparkContext -from pyspark.rdd import RDD, _prepare_for_python_RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync - - -__all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl", - "SchemaRDD"] - - -class DataType(object): - - """Spark SQL DataType""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -class PrimitiveTypeSingleton(type): - - """Metaclass for PrimitiveType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class PrimitiveType(DataType): - - """Spark SQL PrimitiveType""" - - __metaclass__ = PrimitiveTypeSingleton - - def __eq__(self, other): - # because they should be the same object - return self is other - - -class NullType(PrimitiveType): - - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. - """ - - -class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. - """ - - -class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. - """ - - -class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. - """ - - -class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. - """ - - -class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. - """ - - -class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. - """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. - """ - - -class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. - """ - - -class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. - """ - - -class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. - """ - - -class LongType(PrimitiveType): - - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. - """ - - -class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. - """ - - -class ArrayType(DataType): - - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - - """ - - def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - - >>> ArrayType(StringType) == ArrayType(StringType, True) - True - >>> ArrayType(StringType, False) == ArrayType(StringType) - False - """ - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. - - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) - True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) - False - """ - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) - True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) - False - """ - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - - """ - - def __init__(self, fields): - """Creates a StructType - - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) - >>> struct1 == struct2 - False - """ - self.fields = fields - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """ - .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - - -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - True - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - True - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - True - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - True - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - True - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... complex_arraytype, False) - >>> check_datatype(complex_maptype) - True - >>> check_datatype(ExamplePointUDT()) - True - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - True - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() - elif json_value == u'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: IntegerType, - long: LongType, - float: DoubleType, - str: StringType, - unicode: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - raise ValueError("Can not infer type for None") - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.iteritems(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row - else: - raise ValueError("Can't infer schema from tuple") - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise ValueError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = map(_python_to_sql_converter, types) - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) - - elif isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise ValueError("Unexpected obj: %s" % obj) - - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,None,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... - >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, None, True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) - """ - s = s.strip() - if not s: - return - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types inferred from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... - """ - if dataType is None: - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise ValueError("Unexpected dataType: %s" % dataType) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date,), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) - >>> _verify_type(range(3), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - - -_cached_cls = {} - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for primitive types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) - - def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) - - return Row - - -class SQLContext(object): - - """Main entry point for Spark SQL functionality. - - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.inferSchema(allTypes) - >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. - """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_cmd), - env, - includes, - self._sc.pythonExec, - bvars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> df = sqlCtx.inferSchema(nestedRdd1) - >>> df.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> df = sqlCtx.inferSchema(nestedRdd2) - >>> df.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT * from table1") - >>> df2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> df = sqlCtx.applySchema(rdd, schema) - >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> df.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> df = sqlCtx.applySchema(rdd, typedSchema) - >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerRDDAsTable(df, tableName) - else: - raise ValueError("Can only register DataFrame as table") - - def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a L{DataFrame}. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - gateway = self._sc._gateway - jpath = paths[0] - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) - for i in range(1, len(paths)): - jpaths[i] = paths[i] - jdf = self._ssql_ctx.parquetFile(jpath, jpaths) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> df1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) - >>> sqlCtx.registerRDDAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return DataFrame(self._ssql_ctx.sql(sqlQuery), self) - - def table(self, tableName): - """Returns the specified table as a L{DataFrame}. - - >>> df = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - return DataFrame(self._ssql_ctx.table(tableName), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - -class DataFrame(object): - - """A collection of rows that have the same columns. - - A :class:`DataFrame` is equivalent to a relational table in Spark SQL, - and can be created using various functions in :class:`SQLContext`:: - - people = sqlContext.parquetFile("...") - - Once created, it can be manipulated using the various domain-specific-language - (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. - - To select a column from the data frame, use the apply method:: - - ageCol = people.age - - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - - A more concrete example:: - - # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") - - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ - .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - self._sc = sql_ctx and sql_ctx._sc - self.is_cached = False - - @property - def rdd(self): - """ - Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row` s. - """ - if not hasattr(self, '_lazy_rdd'): - jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema() - - def applySchema(it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - - return self._lazy_rdd - - def toJSON(self, use_unicode=False): - """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - - >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( "SELECT * from table1") - >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jdf.toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a DataFrame using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True - """ - self._jdf.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. - - >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - """ - self._jdf.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this DataFrame into the specified table. - - Optionally overwriting any existing data. - """ - self._jdf.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this DataFrame.""" - self._jdf.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this DataFrame (represented by - a L{StructType}). - - >>> df.schema() - StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) - """ - return _parse_datatype_json_string(self._jdf.schema().json()) - - def printSchema(self): - """Prints out the schema in the tree format. - - >>> df.printSchema() - root - |-- age: integer (nullable = true) - |-- name: string (nullable = true) - - """ - print (self._jdf.schema().treeString()) - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. - - >>> df.count() - 2L - """ - return self._jdf.count() - - def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df.collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - """ - with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) - cls = _create_cls(self.schema()) - return [cls(r) for r in rs] - - def limit(self, num): - """Limit the result count to the number specified. - - >>> df.limit(1).collect() - [Row(age=2, name=u'Alice')] - >>> df.limit(0).collect() - [] - """ - jdf = self._jdf.limit(num) - return DataFrame(jdf, self.sql_ctx) - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - >>> df.take(2) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - """ - return self.limit(num).collect() - - def map(self, f): - """ Return a new RDD by applying a function to each Row, it's a - shorthand for df.rdd.map() - - >>> df.map(lambda p: p.name).collect() - [u'Alice', u'Bob'] - """ - return self.rdd.map(f) - - def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(iterator): yield 1 - >>> rdd.mapPartitions(f).sum() - 4 - """ - return self.rdd.mapPartitions(f, preservesPartitioning) - - def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - self._jdf.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). - """ - self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jdf.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from - memory and disk. - """ - self.is_cached = False - self._jdf.unpersist(blocking) - return self - - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) - - def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. - """ - rdd = self._jdf.repartition(numPartitions, None) - return DataFrame(rdd, self.sql_ctx) - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. - - >>> df = sqlCtx.inferSchema(rdd) - >>> df.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) - - # def takeSample(self, withReplacement, num, seed=None): - # """Return a fixed-size sampled subset of this DataFrame. - # - # >>> df = sqlCtx.inferSchema(rdd) - # >>> df.takeSample(False, 2, 97) - # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - # """ - # seed = seed if seed is not None else random.randint(0, sys.maxint) - # with SCCallSiteSync(self.context) as css: - # bytesInJava = self._jdf \ - # .takeSampleToPython(withReplacement, num, long(seed)) \ - # .iterator() - # cls = _create_cls(self.schema()) - # return map(cls, self._collect_iterator_through_file(bytesInJava)) - - @property - def dtypes(self): - """Return all column names and their data types as a list. - - >>> df.dtypes - [('age', 'integer'), ('name', 'string')] - """ - return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] - - @property - def columns(self): - """ Return all column names as a list. - - >>> df.columns - [u'age', u'name'] - """ - return [f.name for f in self.schema().fields] - - def join(self, other, joinExprs=None, joinType=None): - """ - Join with another DataFrame, using the given join expression. - The following performs a full outer join between `df1` and `df2`:: - - :param other: Right side of the join - :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. - - >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] - """ - - if joinExprs is None: - jdf = self._jdf.join(other._jdf) - else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) - else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) - return DataFrame(jdf, self.sql_ctx) - - def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column. - - :param cols: The columns or expressions used for sorting - - >>> df.sort(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] - >>> df.sortBy(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] - """ - if not cols: - raise ValueError("should sort by at least one column") - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - sortBy = sort - - def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. - - >>> df.head() - Row(age=2, name=u'Alice') - >>> df.head(1) - [Row(age=2, name=u'Alice')] - """ - if n is None: - rs = self.head(1) - return rs[0] if rs else None - return self.take(n) - - def first(self): - """ Return the first row. - - >>> df.first() - Row(age=2, name=u'Alice') - """ - return self.head() - - def __getitem__(self, item): - """ Return the column by given name - - >>> df['age'].collect() - [Row(age=2), Row(age=5)] - >>> df[ ["name", "age"]].collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] - >>> df[ df.age > 3 ].collect() - [Row(age=5, name=u'Bob')] - """ - if isinstance(item, basestring): - jc = self._jdf.apply(item) - return Column(jc, self.sql_ctx) - elif isinstance(item, Column): - return self.filter(item) - elif isinstance(item, list): - return self.select(*item) - else: - raise IndexError("unexpected index: %s" % item) - - def __getattr__(self, name): - """ Return the column by given name - - >>> df.age.collect() - [Row(age=2), Row(age=5)] - """ - if name.startswith("__"): - raise AttributeError(name) - jc = self._jdf.apply(name) - return Column(jc, self.sql_ctx) - - def select(self, *cols): - """ Selecting a set of expressions. - - >>> df.select().collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - >>> df.select('*').collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] - >>> df.select('name', 'age').collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] - >>> df.select(df.name, (df.age + 10).alias('age')).collect() - [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] - """ - if not cols: - cols = ["*"] - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - def selectExpr(self, *expr): - """ - Selects a set of SQL expressions. This is a variant of - `select` that accepts SQL expressions. - - >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] - """ - jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) - jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) - return DataFrame(jdf, self.sql_ctx) - - def filter(self, condition): - """ Filtering rows using the given condition, which could be - Column expression or string of SQL expression. - - where() is an alias for filter(). - - >>> df.filter(df.age > 3).collect() - [Row(age=5, name=u'Bob')] - >>> df.where(df.age == 2).collect() - [Row(age=2, name=u'Alice')] - - >>> df.filter("age > 3").collect() - [Row(age=5, name=u'Bob')] - >>> df.where("age = 2").collect() - [Row(age=2, name=u'Alice')] - """ - if isinstance(condition, basestring): - jdf = self._jdf.filter(condition) - elif isinstance(condition, Column): - jdf = self._jdf.filter(condition._jc) - else: - raise TypeError("condition should be string or Column") - return DataFrame(jdf, self.sql_ctx) - - where = filter - - def groupBy(self, *cols): - """ Group the :class:`DataFrame` using the specified columns, - so we can run aggregation on them. See :class:`GroupedData` - for all the available aggregate functions. - - >>> df.groupBy().avg().collect() - [Row(AVG(age#0)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] - >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] - """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return GroupedData(jdf, self.sql_ctx) - - def agg(self, *exprs): - """ Aggregate on the entire :class:`DataFrame` without groups - (shorthand for df.groupBy.agg()). - - >>> df.agg({"age": "max"}).collect() - [Row(MAX(age#0)=5)] - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.min(df.age)).collect() - [Row(MIN(age#0)=2)] - """ - return self.groupBy().agg(*exprs) - - def unionAll(self, other): - """ Return a new DataFrame containing union of rows in this - frame and another frame. - - This is equivalent to `UNION ALL` in SQL. - """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - - def intersect(self, other): - """ Return a new :class:`DataFrame` containing rows only in - both this frame and another frame. - - This is equivalent to `INTERSECT` in SQL. - """ - return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - - def subtract(self, other): - """ Return a new :class:`DataFrame` containing rows in this frame - but not in another frame. - - This is equivalent to `EXCEPT` in SQL. - """ - return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - - def addColumn(self, colName, col): - """ Return a new :class:`DataFrame` by adding a column. - - >>> df.addColumn('age2', df.age + 2).collect() - [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] - """ - return self.select('*', col.alias(colName)) - - def to_pandas(self): - """ - Collect all the rows and return a `pandas.DataFrame`. - - >>> df.to_pandas() # doctest: +SKIP - age name - 0 2 Alice - 1 5 Bob - """ - import pandas as pd - return pd.DataFrame.from_records(self.collect(), columns=self.columns) - - -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame - """ - - -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. - - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. - - :param exprs: list or aggregate columns or a map from column - name to aggregate methods. - - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"age": "max"}).collect() - [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] - >>> from pyspark.sql import Dsl - >>> gdf.agg(Dsl.min(df.age)).collect() - [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """ Count the number of rows for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @dfapi - def mean(self): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`.""" - - @dfapi - def avg(self): - """Compute the average value for each numeric columns - for each group.""" - - @dfapi - def max(self): - """Compute the max value for each numeric columns for - each group. """ - - @dfapi - def min(self): - """Compute the min value for each numeric column for - each group.""" - - @dfapi - def sum(self): - """Compute the sum for each numeric columns for each - group.""" - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.Dsl.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.Dsl.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _dsl_op(name, doc=''): - def _(self): - jc = getattr(self._sc._jvm.Dsl, name)(self._jc) - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc, self.sql_ctx) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc, self.sql_ctx) - _.__doc__ = doc - return _ - - -class Column(DataFrame): - - """ - A column in a DataFrame. - - `Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc, sql_ctx=None): - self._jc = jc - super(Column, self).__init__(jc, sql_ctx) - - # arithmetic operators - __neg__ = _dsl_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _dsl_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") - getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - def substr(self, startPos, length): - """ - Return a Column which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.name.substr(1, 3).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self.sql_ctx) - - __getslice__ = substr - - # order - asc = _unary_op("asc") - desc = _unary_op("desc") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, alias): - """Return a alias for this column - - >>> df.age.alias("age2").collect() - [Row(age2=2), Row(age2=5)] - """ - return Column(getattr(self._jc, "as")(alias), self.sql_ctx) - - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if self.sql_ctx is None: - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - else: - ssql_ctx = self.sql_ctx._ssql_ctx - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - return Column(jc, self.sql_ctx) - - def to_pandas(self): - """ - Return a pandas.Series from the column - - >>> df.age.to_pandas() # doctest: +SKIP - 0 2 - 1 5 - dtype: int64 - """ - import pandas as pd - data = [c for c, in self.collect()] - return pd.Series(data) - - -def _aggregate_func(name, doc=""): - """ Create a function for aggregator by name""" - def _(col): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) - return Column(jc) - _.__name__ = name - _.__doc__ = doc - return staticmethod(_) - - -class UserDefinedFunction(object): - def __init__(self, func, returnType): - self.func = func - self.returnType = returnType - self._broadcast = None - self._judf = self._create_judf() - - def _create_judf(self): - f = self.func # put it in closure `func` - func = lambda _, it: imap(lambda x: f(*x), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - sc = SparkContext._active_spark_context - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, - includes, sc.pythonExec, broadcast_vars, - sc._javaAccumulator, jdt) - return judf - - def __del__(self): - if self._broadcast is not None: - self._broadcast.unpersist() - self._broadcast = None - - def __call__(self, *cols): - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - -class Dsl(object): - """ - A collections of builtin aggregators - """ - DSLS = { - 'lit': 'Creates a :class:`Column` of literal value.', - 'col': 'Returns a :class:`Column` based on the given column name.', - 'column': 'Returns a :class:`Column` based on the given column name.', - 'upper': 'Converts a string expression to upper case.', - 'lower': 'Converts a string expression to upper case.', - 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', - - 'max': 'Aggregate function: returns the maximum value of the expression in a group.', - 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', - 'count': 'Aggregate function: returns the number of items in a group.', - 'sum': 'Aggregate function: returns the sum of all values in the expression.', - 'avg': 'Aggregate function: returns the average of the values in a group.', - 'mean': 'Aggregate function: returns the average of the values in a group.', - 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', - } - - for _name, _doc in DSLS.items(): - locals()[_name] = _aggregate_func(_name, _doc) - del _name, _doc - - @staticmethod - def countDistinct(col, *cols): - """ Return a new Column for distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() - [Row(c=2)] - - >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - @staticmethod - def approxCountDistinct(col, rsd=None): - """ Return a new Column for approxiate distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) - - @staticmethod - def udf(f, returnType=StringType()): - """Create a user defined function (UDF) - - >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) - >>> df.select(slen(df.name).alias('slen')).collect() - [Row(slen=5), Row(slen=3)] - """ - return UserDefinedFunction(f, returnType) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql - from pyspark.sql import Row, SQLContext - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) - rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) - globs['df'] = sqlCtx.inferSchema(rdd2) - globs['df2'] = sqlCtx.inferSchema(rdd3) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 0000000000..0a5ba00393 --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{DataFrame} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedData} + - L{Column} + Column is a DataFrame with a single column. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. +""" + +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.types import Row +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'Dsl', +] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 0000000000..49f016a9cf --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import json +from array import array +from itertools import imap + +from py4j.protocol import Py4JError + +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +__all__ = ["SQLContext", "HiveContext"] + + +class SQLContext(object): + + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + tables, execute SQL over tables, cache tables, and read parquet files. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Create a new SQLContext. + + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + + >>> bad_rdd = sc.parallelize([1,2,3]) + >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + + >>> from datetime import datetime + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def inferSchema(self, rdd, samplingRatio=None): + """Infer and apply a schema to an RDD of L{Row}. + + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. + + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. + + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. + + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + + >>> nestedRdd2 = sc.parallelize([ + ... NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... CustomRow(field1=3, field2="row3")]) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] + Row(field1=1, field2=u'row1') + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) + return self.applySchema(rdd, schema) + + def applySchema(self, rdd, schema): + """ + Applies the given schema to the given RDD of L{tuple} or L{list}. + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> from pyspark.sql.types import * + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + + >>> from datetime import date, datetime + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3], None)]) + >>> schema = StructType([ + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), + ... StructField("float", FloatType(), False), + ... StructField("date", DateType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> df.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] + + >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3])]) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + """ + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerRDDAsTable(self, rdd, tableName): + """Registers the given RDD as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of + SQLContext. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + """ + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a L{DataFrame}. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpath = paths[0] + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) + for i in range(1, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpath, jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """ + Loads a text file storing one JSON object per line as a + L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> ofn = open(jsonFile, 'w') + >>> for json in jsonStrings: + ... print>>ofn, json + >>> ofn.close() + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a L{DataFrame}. + + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. + + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in df2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in df4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> df6.collect() + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] + + >>> sqlCtx.jsonRDD(sc.parallelize(['{}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def sql(self, sqlQuery): + """Return a L{DataFrame} representing the result of the given query. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a L{DataFrame}. + + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + +class HiveContext(SQLContext): + + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from hive-site.xml on the classpath. + It supports running both SQL and HiveQL commands. + """ + + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['rdd'] = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 0000000000..cda704eea7 --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,974 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import itertools +import warnings +import random +import os +from tempfile import NamedTemporaryFile +from itertools import imap + +from py4j.java_collections import ListConverter, MapConverter + +from pyspark.context import SparkContext +from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"] + + +class DataFrame(object): + + """A collection of rows that have the same columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + + @property + def rdd(self): + """ + Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row` s. + """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() + + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + def toJSON(self, use_unicode=False): + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. + + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Save the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a DataFrame using the L{SQLContext.parquetFile} method. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the L{SQLContext} + that was used to create this DataFrame. + + >>> df.registerTempTable("people") + >>> df2 = sqlCtx.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this DataFrame into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def saveAsTable(self, tableName): + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) + + def schema(self): + """Returns the schema of this DataFrame (represented by + a L{StructType}). + + >>> df.schema() + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + return _parse_datatype_json_string(self._jdf.schema().json()) + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print (self._jdf.schema().treeString()) + + def count(self): + """Return the number of elements in this RDD. + + Unlike the base RDD implementation of count, this implementation + leverages the query optimizer to compute the count on the DataFrame, + which supports features such as filter pushdown. + + >>> df.count() + 2L + """ + return self._jdf.count() + + def collect(self): + """Return a list that contains all of the rows. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + cls = _create_cls(self.schema()) + return [cls(r) for r in rs] + + def limit(self, num): + """Limit the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df.sample(False, 0.5, 97).count() + 1L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + + >>> df.dtypes + [('age', 'integer'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] + + @property + def columns(self): + """ Return all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema().fields] + + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new :class:`DataFrame` sorted by the specified column. + + :param cols: The columns or expressions used for sorting + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sortBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def first(self): + """ Return the first row. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + def __getitem__(self, item): + """ Return the column by given name + + >>> df['age'].collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + jc = self._jdf.apply(item) + return Column(jc, self.sql_ctx) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) + + def __getattr__(self, name): + """ Return the column by given name + + >>> df.age.collect() + [Row(age=2), Row(age=5)] + """ + if name.startswith("__"): + raise AttributeError(name) + jc = self._jdf.apply(name) + return Column(jc, self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions. + + >>> df.select().collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + if not cols: + cols = ["*"] + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + return DataFrame(jdf, self.sql_ctx) + + def filter(self, condition): + """ Filtering rows using the given condition, which could be + Column expression or string of SQL expression. + + where() is an alias for filter(). + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + >>> df.groupBy().avg().collect() + [Row(AVG(age#0)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)] + """ + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return GroupedData(jdf, self.sql_ctx) + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for df.groupBy.agg()). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age#0)=5)] + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new :class:`DataFrame` by adding a column. + + >>> df.addColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + def to_pandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. + + >>> df.to_pandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to aggregate methods. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"age": "max"}).collect() + [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] + >>> from pyspark.sql import Dsl + >>> gdf.agg(Dsl.min(df.age)).collect() + [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jcols = ListConverter().convert([c._jc for c in exprs[1:]], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Dsl.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _dsl_op(name, doc=''): + def _(self): + jc = getattr(self._sc._jvm.Dsl, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jc, sql_ctx) + + # arithmetic operators + __neg__ = _dsl_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _dsl_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + def substr(self, startPos, length): + """ + Return a Column which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.name.substr(1, 3).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return a alias for this column + + >>> df.age.alias("age2").collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + return Column(jc, self.sql_ctx) + + def to_pandas(self): + """ + Return a pandas.Series from the column + + >>> df.age.to_pandas() # doctest: +SKIP + 0 2 + 1 5 + dtype: int64 + """ + import pandas as pd + data = [c for c, in self.collect()] + return pd.Series(data) + + +def _aggregate_func(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return staticmethod(_) + + +class UserDefinedFunction(object): + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +class Dsl(object): + """ + A collections of builtin aggregators + """ + DSLS = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', + } + + for _name, _doc in DSLS.items(): + locals()[_name] = _aggregate_func(_name, _doc) + del _name, _doc + + @staticmethod + def countDistinct(col, *cols): + """ Return a new Column for distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), + sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + @staticmethod + def approxCountDistinct(col, rsd=None): + """ Return a new Column for approxiate distinct count of (col, *cols) + + >>> from pyspark.sql import Dsl + >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + @staticmethod + def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) + globs['df'] = sqlCtx.inferSchema(rdd2) + globs['df2'] = sqlCtx.inferSchema(rdd3) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 0000000000..d25c6365ed --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,300 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + + +from pyspark.sql import SQLContext, Column +from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType, LongType +from pyspark.tests import ReusedPySparkTestCase + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + df0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import Dsl + self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) + self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py new file mode 100644 index 0000000000..41afefe48e --- /dev/null +++ b/python/pyspark/sql/types.py @@ -0,0 +1,1279 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import decimal +import datetime +import keyword +import warnings +import json +import re +from array import array +from operator import itemgetter + + +__all__ = [ + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ] + + +class DataType(object): + + """Spark SQL DataType""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + + +class PrimitiveTypeSingleton(type): + + """Metaclass for PrimitiveType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class PrimitiveType(DataType): + + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + def __eq__(self, other): + # because they should be the same object + return self is other + + +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + +class StringType(PrimitiveType): + + """Spark SQL StringType + + The data type representing string values. + """ + + +class BinaryType(PrimitiveType): + + """Spark SQL BinaryType + + The data type representing bytearray values. + """ + + +class BooleanType(PrimitiveType): + + """Spark SQL BooleanType + + The data type representing bool values. + """ + + +class DateType(PrimitiveType): + + """Spark SQL DateType + + The data type representing datetime.date values. + """ + + +class TimestampType(PrimitiveType): + + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + """ + + +class DecimalType(DataType): + + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. + """ + + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + + +class DoubleType(PrimitiveType): + + """Spark SQL DoubleType + + The data type representing float values. + """ + + +class FloatType(PrimitiveType): + + """Spark SQL FloatType + + The data type representing single precision floating-point values. + """ + + +class ByteType(PrimitiveType): + + """Spark SQL ByteType + + The data type representing int values with 1 singed byte. + """ + + +class IntegerType(PrimitiveType): + + """Spark SQL IntegerType + + The data type representing int values. + """ + + +class LongType(PrimitiveType): + + """Spark SQL LongType + + The data type representing long values. If the any value is + beyond the range of [-9223372036854775808, 9223372036854775807], + please use DecimalType. + """ + + +class ShortType(PrimitiveType): + + """Spark SQL ShortType + + The data type representing int values with 2 signed bytes. + """ + + +class ArrayType(DataType): + + """Spark SQL ArrayType + + The data type representing list values. An ArrayType object + comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. + + """ + + def __init__(self, elementType, containsNull=True): + """Creates an ArrayType + + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains None values. + + >>> ArrayType(StringType) == ArrayType(StringType, True) + True + >>> ArrayType(StringType, False) == ArrayType(StringType) + False + """ + self.elementType = elementType + self.containsNull = containsNull + + def __repr__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + + +class MapType(DataType): + + """Spark SQL MapType + + The data type representing dict values. A MapType object comprises + three fields, keyType (a DataType), valueType (a DataType) and + valueContainsNull (a bool). + + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this + map has None values. + + For values of a MapType column, keys are not allowed to have None values. + + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """Creates a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains + null values. + + >>> (MapType(StringType, IntegerType) + ... == MapType(StringType, IntegerType, True)) + True + >>> (MapType(StringType, IntegerType, False) + ... == MapType(StringType, FloatType)) + False + """ + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + + +class StructField(DataType): + + """Spark SQL StructField + + Represents a field in a StructType. + A StructField object comprises three fields, name (a string), + dataType (a DataType) and nullable (a bool). The field of name + is the name of a StructField. The field of dataType specifies + the data type of a StructField. + + The field of nullable specifies if values of a StructField can + contain None values. + + """ + + def __init__(self, name, dataType, nullable=True, metadata=None): + """Creates a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field + can be null. + :param metadata: metadata of this field, which is a map from string + to simple type that can be serialized to JSON + automatically + + >>> (StructField("f1", StringType, True) + ... == StructField("f1", StringType, True)) + True + >>> (StructField("f1", StringType, True) + ... == StructField("f2", StringType, True)) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + + +class StructType(DataType): + + """Spark SQL StructType + + The data type representing rows. + A StructType object comprises a list of L{StructField}. + + """ + + def __init__(self, fields): + """Creates a StructType + + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... [StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} + + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) + + +class UserDefinedType(DataType): + """ + .. note:: WARN: Spark Internal Use Only + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) + + +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) + ... return datatype == python_datatype + >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True + """ + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: IntegerType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, + datetime.time: TimestampType, +} + + +def _infer_type(obj): + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + raise ValueError("Can not infer type for None") + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) + elif isinstance(obj, (list, array)): + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) + else: + try: + return _infer_schema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, tuple): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in row): + items = row + else: + raise ValueError("Can't infer schema from tuple") + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): + """Create an converter to drop the names of fields in obj """ + if isinstance(dataType, ArrayType): + conv = _create_converter(dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) + + elif isinstance(dataType, NullType): + return lambda x: None + + elif not isinstance(dataType, StructType): + return lambda x: x + + # dataType must be StructType + names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "_fields"): + d = dict(zip(obj._fields, obj)) + elif hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % str(obj)) + + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ + else: + raise ValueError("Unexpected obj: %s" % obj) + + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + + return convert_struct + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,None,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,None,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(None,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(None,ArrayType(None,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, None, True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,None,true))),true) + """ + s = s.strip() + if not s: + return + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(None, _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types inferred from obj + + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) + >>> _infer_schema_type(row, schema) + StructType...IntegerType...DoubleType...StringType...DateType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... + """ + if dataType is None: + return _infer_type(obj) + + if not obj: + return NullType() + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_schema_type(k, dataType.keyType), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + BinaryType: (bytearray,), + DateType: (datetime.date,), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, IntegerType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + + _type = type(dataType) + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with" + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + +_cached_cls = {} + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in most cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct_or_date(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct_or_date(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] + + return List + + elif isinstance(dataType, MapType): + kcls = _create_cls(dataType.keyType) + vcls = _create_cls(dataType.valueType) + + def Dict(d): + if d is None: + return + return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) + + return Dict + + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + + elif not isinstance(dataType, StructType): + # no wrapper for primitive types + return lambda x: x + + class Row(tuple): + + """ Row in DataFrame """ + __DATATYPE__ = dataType + __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def asDict(self): + """ Return as a dict """ + return dict((n, getattr(self, n)) for n in self.__FIELDS__) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__FIELDS__)) + + def __reduce__(self): + return (_restore_object, (self.__DATATYPE__, tuple(self))) + + return Row + + +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + + +def _test(): + import doctest + from pyspark.context import SparkContext + # let doctest run in pyspark.sql.types, so DataTypes can be picklable + import pyspark.sql.types + from pyspark.sql import Row, SQLContext + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + globs = pyspark.sql.types.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT + (failure_count, test_count) = doctest.testmod( + pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py deleted file mode 100644 index d314f46e8d..0000000000 --- a/python/pyspark/sql_tests.py +++ /dev/null @@ -1,299 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for pyspark.sql; additional tests are implemented as doctests in -individual modules. -""" -import os -import sys -import pydoc -import shutil -import tempfile - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType -from pyspark.tests import ReusedPySparkTestCase - - -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.sqlCtx = SQLContext(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData) - cls.df = cls.sqlCtx.inferSchema(rdd) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - df.count() - df.collect() - df.schema() - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist() - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.applySchema(rdd, df.schema()) - self.assertEqual(10, df3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(df.schema(), df2.schema()) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.inferSchema(rdd) - k, v = df.head().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.inferSchema(rdd) - schema = df.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.sql_tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.applySchema(rdd, schema) - point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.sql_tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) - point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_column_operators(self): - from pyspark.sql import Column, LongType - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbool = (ci & ci), (ci | ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import Dsl - self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) - self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - -if __name__ == "__main__": - unittest.main() -- cgit v1.2.3