diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/tests.py | 13 | ||||
-rw-r--r-- | python/pyspark/sql/types.py | 4 | ||||
-rwxr-xr-x | python/run-tests.py | 3 |
3 files changed, 19 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 333378c7f1..66827d4885 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -700,6 +700,19 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(now - now1 < datetime.timedelta(0.001)) self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.sqlCtx.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 160df40d65..7e64cb0b54 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType): if obj is None: return + # StringType can work with any types + if isinstance(dataType, StringType): + return + if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) diff --git a/python/run-tests.py b/python/run-tests.py index 7638854def..cc56077937 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -72,7 +72,8 @@ LOGGER = logging.getLogger() def run_individual_python_test(test_name, pyspark_python): - env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + env = dict(os.environ) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: |